Show More
@@ -0,0 +1,25 b'' | |||
|
1 | """Grab the global logger instance.""" | |
|
2 | ||
|
3 | # Copyright (c) IPython Development Team. | |
|
4 | # Distributed under the terms of the Modified BSD License. | |
|
5 | ||
|
6 | import logging | |
|
7 | ||
|
8 | _logger = None | |
|
9 | ||
|
10 | def get_logger(): | |
|
11 | """Grab the global logger instance. | |
|
12 | ||
|
13 | If a global IPython Application is instantiated, grab its logger. | |
|
14 | Otherwise, grab the root logger. | |
|
15 | """ | |
|
16 | global _logger | |
|
17 | ||
|
18 | if _logger is None: | |
|
19 | from IPython.config import Application | |
|
20 | if Application.initialized(): | |
|
21 | _logger = Application.instance().log | |
|
22 | else: | |
|
23 | logging.basicConfig() | |
|
24 | _logger = logging.getLogger() | |
|
25 | return _logger |
@@ -1,390 +1,366 b'' | |||
|
1 | 1 | # encoding: utf-8 |
|
2 | """ | |
|
3 | A base class for objects that are configurable. | |
|
2 | """A base class for objects that are configurable.""" | |
|
4 | 3 | |
|
5 | Inheritance diagram: | |
|
4 | # Copyright (c) IPython Development Team. | |
|
5 | # Distributed under the terms of the Modified BSD License. | |
|
6 | 6 | |
|
7 | .. inheritance-diagram:: IPython.config.configurable | |
|
8 | :parts: 3 | |
|
9 | ||
|
10 | Authors: | |
|
11 | ||
|
12 | * Brian Granger | |
|
13 | * Fernando Perez | |
|
14 | * Min RK | |
|
15 | """ | |
|
16 | 7 | from __future__ import print_function |
|
17 | 8 | |
|
18 | #----------------------------------------------------------------------------- | |
|
19 | # Copyright (C) 2008-2011 The IPython Development Team | |
|
20 | # | |
|
21 | # Distributed under the terms of the BSD License. The full license is in | |
|
22 | # the file COPYING, distributed as part of this software. | |
|
23 | #----------------------------------------------------------------------------- | |
|
24 | ||
|
25 | #----------------------------------------------------------------------------- | |
|
26 | # Imports | |
|
27 | #----------------------------------------------------------------------------- | |
|
28 | ||
|
29 | 9 | import logging |
|
30 | 10 | from copy import deepcopy |
|
31 | 11 | |
|
32 | 12 | from .loader import Config, LazyConfigValue |
|
33 | 13 | from IPython.utils.traitlets import HasTraits, Instance |
|
34 | 14 | from IPython.utils.text import indent, wrap_paragraphs |
|
35 | 15 | from IPython.utils.py3compat import iteritems |
|
36 | 16 | |
|
37 | 17 | |
|
38 | 18 | #----------------------------------------------------------------------------- |
|
39 | 19 | # Helper classes for Configurables |
|
40 | 20 | #----------------------------------------------------------------------------- |
|
41 | 21 | |
|
42 | 22 | |
|
43 | 23 | class ConfigurableError(Exception): |
|
44 | 24 | pass |
|
45 | 25 | |
|
46 | 26 | |
|
47 | 27 | class MultipleInstanceError(ConfigurableError): |
|
48 | 28 | pass |
|
49 | 29 | |
|
50 | 30 | #----------------------------------------------------------------------------- |
|
51 | 31 | # Configurable implementation |
|
52 | 32 | #----------------------------------------------------------------------------- |
|
53 | 33 | |
|
54 | 34 | class Configurable(HasTraits): |
|
55 | 35 | |
|
56 | 36 | config = Instance(Config, (), {}) |
|
57 | 37 | parent = Instance('IPython.config.configurable.Configurable') |
|
58 | 38 | |
|
59 | 39 | def __init__(self, **kwargs): |
|
60 | 40 | """Create a configurable given a config config. |
|
61 | 41 | |
|
62 | 42 | Parameters |
|
63 | 43 | ---------- |
|
64 | 44 | config : Config |
|
65 | 45 | If this is empty, default values are used. If config is a |
|
66 | 46 | :class:`Config` instance, it will be used to configure the |
|
67 | 47 | instance. |
|
68 | 48 | parent : Configurable instance, optional |
|
69 | 49 | The parent Configurable instance of this object. |
|
70 | 50 | |
|
71 | 51 | Notes |
|
72 | 52 | ----- |
|
73 | 53 | Subclasses of Configurable must call the :meth:`__init__` method of |
|
74 | 54 | :class:`Configurable` *before* doing anything else and using |
|
75 | 55 | :func:`super`:: |
|
76 | 56 | |
|
77 | 57 | class MyConfigurable(Configurable): |
|
78 | 58 | def __init__(self, config=None): |
|
79 | 59 | super(MyConfigurable, self).__init__(config=config) |
|
80 | 60 | # Then any other code you need to finish initialization. |
|
81 | 61 | |
|
82 | 62 | This ensures that instances will be configured properly. |
|
83 | 63 | """ |
|
84 | 64 | parent = kwargs.pop('parent', None) |
|
85 | 65 | if parent is not None: |
|
86 | 66 | # config is implied from parent |
|
87 | 67 | if kwargs.get('config', None) is None: |
|
88 | 68 | kwargs['config'] = parent.config |
|
89 | 69 | self.parent = parent |
|
90 | 70 | |
|
91 | 71 | config = kwargs.pop('config', None) |
|
92 | 72 | if config is not None: |
|
93 | 73 | # We used to deepcopy, but for now we are trying to just save |
|
94 | 74 | # by reference. This *could* have side effects as all components |
|
95 | 75 | # will share config. In fact, I did find such a side effect in |
|
96 | 76 | # _config_changed below. If a config attribute value was a mutable type |
|
97 | 77 | # all instances of a component were getting the same copy, effectively |
|
98 | 78 | # making that a class attribute. |
|
99 | 79 | # self.config = deepcopy(config) |
|
100 | 80 | self.config = config |
|
101 | 81 | # This should go second so individual keyword arguments override |
|
102 | 82 | # the values in config. |
|
103 | 83 | super(Configurable, self).__init__(**kwargs) |
|
104 | 84 | |
|
105 | 85 | #------------------------------------------------------------------------- |
|
106 | 86 | # Static trait notifiations |
|
107 | 87 | #------------------------------------------------------------------------- |
|
108 | 88 | |
|
109 | 89 | @classmethod |
|
110 | 90 | def section_names(cls): |
|
111 | 91 | """return section names as a list""" |
|
112 | 92 | return [c.__name__ for c in reversed(cls.__mro__) if |
|
113 | 93 | issubclass(c, Configurable) and issubclass(cls, c) |
|
114 | 94 | ] |
|
115 | 95 | |
|
116 | 96 | def _find_my_config(self, cfg): |
|
117 | 97 | """extract my config from a global Config object |
|
118 | 98 | |
|
119 | 99 | will construct a Config object of only the config values that apply to me |
|
120 | 100 | based on my mro(), as well as those of my parent(s) if they exist. |
|
121 | 101 | |
|
122 | 102 | If I am Bar and my parent is Foo, and their parent is Tim, |
|
123 | 103 | this will return merge following config sections, in this order:: |
|
124 | 104 | |
|
125 | 105 | [Bar, Foo.bar, Tim.Foo.Bar] |
|
126 | 106 | |
|
127 | 107 | With the last item being the highest priority. |
|
128 | 108 | """ |
|
129 | 109 | cfgs = [cfg] |
|
130 | 110 | if self.parent: |
|
131 | 111 | cfgs.append(self.parent._find_my_config(cfg)) |
|
132 | 112 | my_config = Config() |
|
133 | 113 | for c in cfgs: |
|
134 | 114 | for sname in self.section_names(): |
|
135 | 115 | # Don't do a blind getattr as that would cause the config to |
|
136 | 116 | # dynamically create the section with name Class.__name__. |
|
137 | 117 | if c._has_section(sname): |
|
138 | 118 | my_config.merge(c[sname]) |
|
139 | 119 | return my_config |
|
140 | 120 | |
|
141 | 121 | def _load_config(self, cfg, section_names=None, traits=None): |
|
142 | 122 | """load traits from a Config object""" |
|
143 | 123 | |
|
144 | 124 | if traits is None: |
|
145 | 125 | traits = self.traits(config=True) |
|
146 | 126 | if section_names is None: |
|
147 | 127 | section_names = self.section_names() |
|
148 | 128 | |
|
149 | 129 | my_config = self._find_my_config(cfg) |
|
150 | 130 | for name, config_value in iteritems(my_config): |
|
151 | 131 | if name in traits: |
|
152 | 132 | if isinstance(config_value, LazyConfigValue): |
|
153 | 133 | # ConfigValue is a wrapper for using append / update on containers |
|
154 | 134 | # without having to copy the |
|
155 | 135 | initial = getattr(self, name) |
|
156 | 136 | config_value = config_value.get_value(initial) |
|
157 | 137 | # We have to do a deepcopy here if we don't deepcopy the entire |
|
158 | 138 | # config object. If we don't, a mutable config_value will be |
|
159 | 139 | # shared by all instances, effectively making it a class attribute. |
|
160 | 140 | setattr(self, name, deepcopy(config_value)) |
|
161 | 141 | |
|
162 | 142 | def _config_changed(self, name, old, new): |
|
163 | 143 | """Update all the class traits having ``config=True`` as metadata. |
|
164 | 144 | |
|
165 | 145 | For any class trait with a ``config`` metadata attribute that is |
|
166 | 146 | ``True``, we update the trait with the value of the corresponding |
|
167 | 147 | config entry. |
|
168 | 148 | """ |
|
169 | 149 | # Get all traits with a config metadata entry that is True |
|
170 | 150 | traits = self.traits(config=True) |
|
171 | 151 | |
|
172 | 152 | # We auto-load config section for this class as well as any parent |
|
173 | 153 | # classes that are Configurable subclasses. This starts with Configurable |
|
174 | 154 | # and works down the mro loading the config for each section. |
|
175 | 155 | section_names = self.section_names() |
|
176 | 156 | self._load_config(new, traits=traits, section_names=section_names) |
|
177 | 157 | |
|
178 | 158 | def update_config(self, config): |
|
179 | 159 | """Fire the traits events when the config is updated.""" |
|
180 | 160 | # Save a copy of the current config. |
|
181 | 161 | newconfig = deepcopy(self.config) |
|
182 | 162 | # Merge the new config into the current one. |
|
183 | 163 | newconfig.merge(config) |
|
184 | 164 | # Save the combined config as self.config, which triggers the traits |
|
185 | 165 | # events. |
|
186 | 166 | self.config = newconfig |
|
187 | 167 | |
|
188 | 168 | @classmethod |
|
189 | 169 | def class_get_help(cls, inst=None): |
|
190 | 170 | """Get the help string for this class in ReST format. |
|
191 | 171 | |
|
192 | 172 | If `inst` is given, it's current trait values will be used in place of |
|
193 | 173 | class defaults. |
|
194 | 174 | """ |
|
195 | 175 | assert inst is None or isinstance(inst, cls) |
|
196 | 176 | final_help = [] |
|
197 | 177 | final_help.append(u'%s options' % cls.__name__) |
|
198 | 178 | final_help.append(len(final_help[0])*u'-') |
|
199 | 179 | for k, v in sorted(cls.class_traits(config=True).items()): |
|
200 | 180 | help = cls.class_get_trait_help(v, inst) |
|
201 | 181 | final_help.append(help) |
|
202 | 182 | return '\n'.join(final_help) |
|
203 | 183 | |
|
204 | 184 | @classmethod |
|
205 | 185 | def class_get_trait_help(cls, trait, inst=None): |
|
206 | 186 | """Get the help string for a single trait. |
|
207 | 187 | |
|
208 | 188 | If `inst` is given, it's current trait values will be used in place of |
|
209 | 189 | the class default. |
|
210 | 190 | """ |
|
211 | 191 | assert inst is None or isinstance(inst, cls) |
|
212 | 192 | lines = [] |
|
213 | 193 | header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__) |
|
214 | 194 | lines.append(header) |
|
215 | 195 | if inst is not None: |
|
216 | 196 | lines.append(indent('Current: %r' % getattr(inst, trait.name), 4)) |
|
217 | 197 | else: |
|
218 | 198 | try: |
|
219 | 199 | dvr = repr(trait.get_default_value()) |
|
220 | 200 | except Exception: |
|
221 | 201 | dvr = None # ignore defaults we can't construct |
|
222 | 202 | if dvr is not None: |
|
223 | 203 | if len(dvr) > 64: |
|
224 | 204 | dvr = dvr[:61]+'...' |
|
225 | 205 | lines.append(indent('Default: %s' % dvr, 4)) |
|
226 | 206 | if 'Enum' in trait.__class__.__name__: |
|
227 | 207 | # include Enum choices |
|
228 | 208 | lines.append(indent('Choices: %r' % (trait.values,))) |
|
229 | 209 | |
|
230 | 210 | help = trait.get_metadata('help') |
|
231 | 211 | if help is not None: |
|
232 | 212 | help = '\n'.join(wrap_paragraphs(help, 76)) |
|
233 | 213 | lines.append(indent(help, 4)) |
|
234 | 214 | return '\n'.join(lines) |
|
235 | 215 | |
|
236 | 216 | @classmethod |
|
237 | 217 | def class_print_help(cls, inst=None): |
|
238 | 218 | """Get the help string for a single trait and print it.""" |
|
239 | 219 | print(cls.class_get_help(inst)) |
|
240 | 220 | |
|
241 | 221 | @classmethod |
|
242 | 222 | def class_config_section(cls): |
|
243 | 223 | """Get the config class config section""" |
|
244 | 224 | def c(s): |
|
245 | 225 | """return a commented, wrapped block.""" |
|
246 | 226 | s = '\n\n'.join(wrap_paragraphs(s, 78)) |
|
247 | 227 | |
|
248 | 228 | return '# ' + s.replace('\n', '\n# ') |
|
249 | 229 | |
|
250 | 230 | # section header |
|
251 | 231 | breaker = '#' + '-'*78 |
|
252 | 232 | s = "# %s configuration" % cls.__name__ |
|
253 | 233 | lines = [breaker, s, breaker, ''] |
|
254 | 234 | # get the description trait |
|
255 | 235 | desc = cls.class_traits().get('description') |
|
256 | 236 | if desc: |
|
257 | 237 | desc = desc.default_value |
|
258 | 238 | else: |
|
259 | 239 | # no description trait, use __doc__ |
|
260 | 240 | desc = getattr(cls, '__doc__', '') |
|
261 | 241 | if desc: |
|
262 | 242 | lines.append(c(desc)) |
|
263 | 243 | lines.append('') |
|
264 | 244 | |
|
265 | 245 | parents = [] |
|
266 | 246 | for parent in cls.mro(): |
|
267 | 247 | # only include parents that are not base classes |
|
268 | 248 | # and are not the class itself |
|
269 | 249 | # and have some configurable traits to inherit |
|
270 | 250 | if parent is not cls and issubclass(parent, Configurable) and \ |
|
271 | 251 | parent.class_traits(config=True): |
|
272 | 252 | parents.append(parent) |
|
273 | 253 | |
|
274 | 254 | if parents: |
|
275 | 255 | pstr = ', '.join([ p.__name__ for p in parents ]) |
|
276 | 256 | lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr))) |
|
277 | 257 | lines.append('') |
|
278 | 258 | |
|
279 | 259 | for name, trait in iteritems(cls.class_traits(config=True)): |
|
280 | 260 | help = trait.get_metadata('help') or '' |
|
281 | 261 | lines.append(c(help)) |
|
282 | 262 | lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value())) |
|
283 | 263 | lines.append('') |
|
284 | 264 | return '\n'.join(lines) |
|
285 | 265 | |
|
286 | 266 | |
|
287 | 267 | |
|
288 | 268 | class SingletonConfigurable(Configurable): |
|
289 | 269 | """A configurable that only allows one instance. |
|
290 | 270 | |
|
291 | 271 | This class is for classes that should only have one instance of itself |
|
292 | 272 | or *any* subclass. To create and retrieve such a class use the |
|
293 | 273 | :meth:`SingletonConfigurable.instance` method. |
|
294 | 274 | """ |
|
295 | 275 | |
|
296 | 276 | _instance = None |
|
297 | 277 | |
|
298 | 278 | @classmethod |
|
299 | 279 | def _walk_mro(cls): |
|
300 | 280 | """Walk the cls.mro() for parent classes that are also singletons |
|
301 | 281 | |
|
302 | 282 | For use in instance() |
|
303 | 283 | """ |
|
304 | 284 | |
|
305 | 285 | for subclass in cls.mro(): |
|
306 | 286 | if issubclass(cls, subclass) and \ |
|
307 | 287 | issubclass(subclass, SingletonConfigurable) and \ |
|
308 | 288 | subclass != SingletonConfigurable: |
|
309 | 289 | yield subclass |
|
310 | 290 | |
|
311 | 291 | @classmethod |
|
312 | 292 | def clear_instance(cls): |
|
313 | 293 | """unset _instance for this class and singleton parents. |
|
314 | 294 | """ |
|
315 | 295 | if not cls.initialized(): |
|
316 | 296 | return |
|
317 | 297 | for subclass in cls._walk_mro(): |
|
318 | 298 | if isinstance(subclass._instance, cls): |
|
319 | 299 | # only clear instances that are instances |
|
320 | 300 | # of the calling class |
|
321 | 301 | subclass._instance = None |
|
322 | 302 | |
|
323 | 303 | @classmethod |
|
324 | 304 | def instance(cls, *args, **kwargs): |
|
325 | 305 | """Returns a global instance of this class. |
|
326 | 306 | |
|
327 | 307 | This method create a new instance if none have previously been created |
|
328 | 308 | and returns a previously created instance is one already exists. |
|
329 | 309 | |
|
330 | 310 | The arguments and keyword arguments passed to this method are passed |
|
331 | 311 | on to the :meth:`__init__` method of the class upon instantiation. |
|
332 | 312 | |
|
333 | 313 | Examples |
|
334 | 314 | -------- |
|
335 | 315 | |
|
336 | 316 | Create a singleton class using instance, and retrieve it:: |
|
337 | 317 | |
|
338 | 318 | >>> from IPython.config.configurable import SingletonConfigurable |
|
339 | 319 | >>> class Foo(SingletonConfigurable): pass |
|
340 | 320 | >>> foo = Foo.instance() |
|
341 | 321 | >>> foo == Foo.instance() |
|
342 | 322 | True |
|
343 | 323 | |
|
344 | 324 | Create a subclass that is retrived using the base class instance:: |
|
345 | 325 | |
|
346 | 326 | >>> class Bar(SingletonConfigurable): pass |
|
347 | 327 | >>> class Bam(Bar): pass |
|
348 | 328 | >>> bam = Bam.instance() |
|
349 | 329 | >>> bam == Bar.instance() |
|
350 | 330 | True |
|
351 | 331 | """ |
|
352 | 332 | # Create and save the instance |
|
353 | 333 | if cls._instance is None: |
|
354 | 334 | inst = cls(*args, **kwargs) |
|
355 | 335 | # Now make sure that the instance will also be returned by |
|
356 | 336 | # parent classes' _instance attribute. |
|
357 | 337 | for subclass in cls._walk_mro(): |
|
358 | 338 | subclass._instance = inst |
|
359 | 339 | |
|
360 | 340 | if isinstance(cls._instance, cls): |
|
361 | 341 | return cls._instance |
|
362 | 342 | else: |
|
363 | 343 | raise MultipleInstanceError( |
|
364 | 344 | 'Multiple incompatible subclass instances of ' |
|
365 | 345 | '%s are being created.' % cls.__name__ |
|
366 | 346 | ) |
|
367 | 347 | |
|
368 | 348 | @classmethod |
|
369 | 349 | def initialized(cls): |
|
370 | 350 | """Has an instance been created?""" |
|
371 | 351 | return hasattr(cls, "_instance") and cls._instance is not None |
|
372 | 352 | |
|
373 | 353 | |
|
374 | 354 | class LoggingConfigurable(Configurable): |
|
375 | 355 | """A parent class for Configurables that log. |
|
376 | 356 | |
|
377 | 357 | Subclasses have a log trait, and the default behavior |
|
378 | is to get the logger from the currently running Application | |
|
379 | via Application.instance().log. | |
|
358 | is to get the logger from the currently running Application. | |
|
380 | 359 | """ |
|
381 | 360 | |
|
382 | 361 | log = Instance('logging.Logger') |
|
383 | 362 | def _log_default(self): |
|
384 |
from IPython. |
|
|
385 | if Application.initialized(): | |
|
386 | return Application.instance().log | |
|
387 | else: | |
|
388 | return logging.getLogger() | |
|
363 | from IPython.utils import log | |
|
364 | return log.get_logger() | |
|
389 | 365 | |
|
390 | 366 |
@@ -1,846 +1,824 b'' | |||
|
1 | """A simple configuration system. | |
|
1 | # encoding: utf-8 | |
|
2 | """A simple configuration system.""" | |
|
2 | 3 | |
|
3 | Inheritance diagram: | |
|
4 | ||
|
5 | .. inheritance-diagram:: IPython.config.loader | |
|
6 | :parts: 3 | |
|
7 | ||
|
8 | Authors | |
|
9 | ------- | |
|
10 | * Brian Granger | |
|
11 | * Fernando Perez | |
|
12 | * Min RK | |
|
13 | """ | |
|
14 | ||
|
15 | #----------------------------------------------------------------------------- | |
|
16 | # Copyright (C) 2008-2011 The IPython Development Team | |
|
17 | # | |
|
18 | # Distributed under the terms of the BSD License. The full license is in | |
|
19 | # the file COPYING, distributed as part of this software. | |
|
20 | #----------------------------------------------------------------------------- | |
|
21 | ||
|
22 | #----------------------------------------------------------------------------- | |
|
23 | # Imports | |
|
24 | #----------------------------------------------------------------------------- | |
|
4 | # Copyright (c) IPython Development Team. | |
|
5 | # Distributed under the terms of the Modified BSD License. | |
|
25 | 6 | |
|
26 | 7 | import argparse |
|
27 | 8 | import copy |
|
28 | 9 | import logging |
|
29 | 10 | import os |
|
30 | 11 | import re |
|
31 | 12 | import sys |
|
32 | 13 | import json |
|
33 | 14 | |
|
34 | 15 | from IPython.utils.path import filefind, get_ipython_dir |
|
35 | 16 | from IPython.utils import py3compat |
|
36 | 17 | from IPython.utils.encoding import DEFAULT_ENCODING |
|
37 | 18 | from IPython.utils.py3compat import unicode_type, iteritems |
|
38 | 19 | from IPython.utils.traitlets import HasTraits, List, Any |
|
39 | 20 | |
|
40 | 21 | #----------------------------------------------------------------------------- |
|
41 | 22 | # Exceptions |
|
42 | 23 | #----------------------------------------------------------------------------- |
|
43 | 24 | |
|
44 | 25 | |
|
45 | 26 | class ConfigError(Exception): |
|
46 | 27 | pass |
|
47 | 28 | |
|
48 | 29 | class ConfigLoaderError(ConfigError): |
|
49 | 30 | pass |
|
50 | 31 | |
|
51 | 32 | class ConfigFileNotFound(ConfigError): |
|
52 | 33 | pass |
|
53 | 34 | |
|
54 | 35 | class ArgumentError(ConfigLoaderError): |
|
55 | 36 | pass |
|
56 | 37 | |
|
57 | 38 | #----------------------------------------------------------------------------- |
|
58 | 39 | # Argparse fix |
|
59 | 40 | #----------------------------------------------------------------------------- |
|
60 | 41 | |
|
61 | 42 | # Unfortunately argparse by default prints help messages to stderr instead of |
|
62 | 43 | # stdout. This makes it annoying to capture long help screens at the command |
|
63 | 44 | # line, since one must know how to pipe stderr, which many users don't know how |
|
64 | 45 | # to do. So we override the print_help method with one that defaults to |
|
65 | 46 | # stdout and use our class instead. |
|
66 | 47 | |
|
67 | 48 | class ArgumentParser(argparse.ArgumentParser): |
|
68 | 49 | """Simple argparse subclass that prints help to stdout by default.""" |
|
69 | 50 | |
|
70 | 51 | def print_help(self, file=None): |
|
71 | 52 | if file is None: |
|
72 | 53 | file = sys.stdout |
|
73 | 54 | return super(ArgumentParser, self).print_help(file) |
|
74 | 55 | |
|
75 | 56 | print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__ |
|
76 | 57 | |
|
77 | 58 | #----------------------------------------------------------------------------- |
|
78 | 59 | # Config class for holding config information |
|
79 | 60 | #----------------------------------------------------------------------------- |
|
80 | 61 | |
|
81 | 62 | class LazyConfigValue(HasTraits): |
|
82 | 63 | """Proxy object for exposing methods on configurable containers |
|
83 | 64 | |
|
84 | 65 | Exposes: |
|
85 | 66 | |
|
86 | 67 | - append, extend, insert on lists |
|
87 | 68 | - update on dicts |
|
88 | 69 | - update, add on sets |
|
89 | 70 | """ |
|
90 | 71 | |
|
91 | 72 | _value = None |
|
92 | 73 | |
|
93 | 74 | # list methods |
|
94 | 75 | _extend = List() |
|
95 | 76 | _prepend = List() |
|
96 | 77 | |
|
97 | 78 | def append(self, obj): |
|
98 | 79 | self._extend.append(obj) |
|
99 | 80 | |
|
100 | 81 | def extend(self, other): |
|
101 | 82 | self._extend.extend(other) |
|
102 | 83 | |
|
103 | 84 | def prepend(self, other): |
|
104 | 85 | """like list.extend, but for the front""" |
|
105 | 86 | self._prepend[:0] = other |
|
106 | 87 | |
|
107 | 88 | _inserts = List() |
|
108 | 89 | def insert(self, index, other): |
|
109 | 90 | if not isinstance(index, int): |
|
110 | 91 | raise TypeError("An integer is required") |
|
111 | 92 | self._inserts.append((index, other)) |
|
112 | 93 | |
|
113 | 94 | # dict methods |
|
114 | 95 | # update is used for both dict and set |
|
115 | 96 | _update = Any() |
|
116 | 97 | def update(self, other): |
|
117 | 98 | if self._update is None: |
|
118 | 99 | if isinstance(other, dict): |
|
119 | 100 | self._update = {} |
|
120 | 101 | else: |
|
121 | 102 | self._update = set() |
|
122 | 103 | self._update.update(other) |
|
123 | 104 | |
|
124 | 105 | # set methods |
|
125 | 106 | def add(self, obj): |
|
126 | 107 | self.update({obj}) |
|
127 | 108 | |
|
128 | 109 | def get_value(self, initial): |
|
129 | 110 | """construct the value from the initial one |
|
130 | 111 | |
|
131 | 112 | after applying any insert / extend / update changes |
|
132 | 113 | """ |
|
133 | 114 | if self._value is not None: |
|
134 | 115 | return self._value |
|
135 | 116 | value = copy.deepcopy(initial) |
|
136 | 117 | if isinstance(value, list): |
|
137 | 118 | for idx, obj in self._inserts: |
|
138 | 119 | value.insert(idx, obj) |
|
139 | 120 | value[:0] = self._prepend |
|
140 | 121 | value.extend(self._extend) |
|
141 | 122 | |
|
142 | 123 | elif isinstance(value, dict): |
|
143 | 124 | if self._update: |
|
144 | 125 | value.update(self._update) |
|
145 | 126 | elif isinstance(value, set): |
|
146 | 127 | if self._update: |
|
147 | 128 | value.update(self._update) |
|
148 | 129 | self._value = value |
|
149 | 130 | return value |
|
150 | 131 | |
|
151 | 132 | def to_dict(self): |
|
152 | 133 | """return JSONable dict form of my data |
|
153 | 134 | |
|
154 | 135 | Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples. |
|
155 | 136 | """ |
|
156 | 137 | d = {} |
|
157 | 138 | if self._update: |
|
158 | 139 | d['update'] = self._update |
|
159 | 140 | if self._extend: |
|
160 | 141 | d['extend'] = self._extend |
|
161 | 142 | if self._prepend: |
|
162 | 143 | d['prepend'] = self._prepend |
|
163 | 144 | elif self._inserts: |
|
164 | 145 | d['inserts'] = self._inserts |
|
165 | 146 | return d |
|
166 | 147 | |
|
167 | 148 | |
|
168 | 149 | def _is_section_key(key): |
|
169 | 150 | """Is a Config key a section name (does it start with a capital)?""" |
|
170 | 151 | if key and key[0].upper()==key[0] and not key.startswith('_'): |
|
171 | 152 | return True |
|
172 | 153 | else: |
|
173 | 154 | return False |
|
174 | 155 | |
|
175 | 156 | |
|
176 | 157 | class Config(dict): |
|
177 | 158 | """An attribute based dict that can do smart merges.""" |
|
178 | 159 | |
|
179 | 160 | def __init__(self, *args, **kwds): |
|
180 | 161 | dict.__init__(self, *args, **kwds) |
|
181 | 162 | self._ensure_subconfig() |
|
182 | 163 | |
|
183 | 164 | def _ensure_subconfig(self): |
|
184 | 165 | """ensure that sub-dicts that should be Config objects are |
|
185 | 166 | |
|
186 | 167 | casts dicts that are under section keys to Config objects, |
|
187 | 168 | which is necessary for constructing Config objects from dict literals. |
|
188 | 169 | """ |
|
189 | 170 | for key in self: |
|
190 | 171 | obj = self[key] |
|
191 | 172 | if _is_section_key(key) \ |
|
192 | 173 | and isinstance(obj, dict) \ |
|
193 | 174 | and not isinstance(obj, Config): |
|
194 | 175 | setattr(self, key, Config(obj)) |
|
195 | 176 | |
|
196 | 177 | def _merge(self, other): |
|
197 | 178 | """deprecated alias, use Config.merge()""" |
|
198 | 179 | self.merge(other) |
|
199 | 180 | |
|
200 | 181 | def merge(self, other): |
|
201 | 182 | """merge another config object into this one""" |
|
202 | 183 | to_update = {} |
|
203 | 184 | for k, v in iteritems(other): |
|
204 | 185 | if k not in self: |
|
205 | 186 | to_update[k] = copy.deepcopy(v) |
|
206 | 187 | else: # I have this key |
|
207 | 188 | if isinstance(v, Config) and isinstance(self[k], Config): |
|
208 | 189 | # Recursively merge common sub Configs |
|
209 | 190 | self[k].merge(v) |
|
210 | 191 | else: |
|
211 | 192 | # Plain updates for non-Configs |
|
212 | 193 | to_update[k] = copy.deepcopy(v) |
|
213 | 194 | |
|
214 | 195 | self.update(to_update) |
|
215 | 196 | |
|
216 | 197 | def __contains__(self, key): |
|
217 | 198 | # allow nested contains of the form `"Section.key" in config` |
|
218 | 199 | if '.' in key: |
|
219 | 200 | first, remainder = key.split('.', 1) |
|
220 | 201 | if first not in self: |
|
221 | 202 | return False |
|
222 | 203 | return remainder in self[first] |
|
223 | 204 | |
|
224 | 205 | return super(Config, self).__contains__(key) |
|
225 | 206 | |
|
226 | 207 | # .has_key is deprecated for dictionaries. |
|
227 | 208 | has_key = __contains__ |
|
228 | 209 | |
|
229 | 210 | def _has_section(self, key): |
|
230 | 211 | return _is_section_key(key) and key in self |
|
231 | 212 | |
|
232 | 213 | def copy(self): |
|
233 | 214 | return type(self)(dict.copy(self)) |
|
234 | 215 | |
|
235 | 216 | def __copy__(self): |
|
236 | 217 | return self.copy() |
|
237 | 218 | |
|
238 | 219 | def __deepcopy__(self, memo): |
|
239 | 220 | import copy |
|
240 | 221 | return type(self)(copy.deepcopy(list(self.items()))) |
|
241 | 222 | |
|
242 | 223 | def __getitem__(self, key): |
|
243 | 224 | try: |
|
244 | 225 | return dict.__getitem__(self, key) |
|
245 | 226 | except KeyError: |
|
246 | 227 | if _is_section_key(key): |
|
247 | 228 | c = Config() |
|
248 | 229 | dict.__setitem__(self, key, c) |
|
249 | 230 | return c |
|
250 | 231 | elif not key.startswith('_'): |
|
251 | 232 | # undefined, create lazy value, used for container methods |
|
252 | 233 | v = LazyConfigValue() |
|
253 | 234 | dict.__setitem__(self, key, v) |
|
254 | 235 | return v |
|
255 | 236 | else: |
|
256 | 237 | raise KeyError |
|
257 | 238 | |
|
258 | 239 | def __setitem__(self, key, value): |
|
259 | 240 | if _is_section_key(key): |
|
260 | 241 | if not isinstance(value, Config): |
|
261 | 242 | raise ValueError('values whose keys begin with an uppercase ' |
|
262 | 243 | 'char must be Config instances: %r, %r' % (key, value)) |
|
263 | 244 | dict.__setitem__(self, key, value) |
|
264 | 245 | |
|
265 | 246 | def __getattr__(self, key): |
|
266 | 247 | if key.startswith('__'): |
|
267 | 248 | return dict.__getattr__(self, key) |
|
268 | 249 | try: |
|
269 | 250 | return self.__getitem__(key) |
|
270 | 251 | except KeyError as e: |
|
271 | 252 | raise AttributeError(e) |
|
272 | 253 | |
|
273 | 254 | def __setattr__(self, key, value): |
|
274 | 255 | if key.startswith('__'): |
|
275 | 256 | return dict.__setattr__(self, key, value) |
|
276 | 257 | try: |
|
277 | 258 | self.__setitem__(key, value) |
|
278 | 259 | except KeyError as e: |
|
279 | 260 | raise AttributeError(e) |
|
280 | 261 | |
|
281 | 262 | def __delattr__(self, key): |
|
282 | 263 | if key.startswith('__'): |
|
283 | 264 | return dict.__delattr__(self, key) |
|
284 | 265 | try: |
|
285 | 266 | dict.__delitem__(self, key) |
|
286 | 267 | except KeyError as e: |
|
287 | 268 | raise AttributeError(e) |
|
288 | 269 | |
|
289 | 270 | |
|
290 | 271 | #----------------------------------------------------------------------------- |
|
291 | 272 | # Config loading classes |
|
292 | 273 | #----------------------------------------------------------------------------- |
|
293 | 274 | |
|
294 | 275 | |
|
295 | 276 | class ConfigLoader(object): |
|
296 | 277 | """A object for loading configurations from just about anywhere. |
|
297 | 278 | |
|
298 | 279 | The resulting configuration is packaged as a :class:`Config`. |
|
299 | 280 | |
|
300 | 281 | Notes |
|
301 | 282 | ----- |
|
302 | 283 | A :class:`ConfigLoader` does one thing: load a config from a source |
|
303 | 284 | (file, command line arguments) and returns the data as a :class:`Config` object. |
|
304 | 285 | There are lots of things that :class:`ConfigLoader` does not do. It does |
|
305 | 286 | not implement complex logic for finding config files. It does not handle |
|
306 | 287 | default values or merge multiple configs. These things need to be |
|
307 | 288 | handled elsewhere. |
|
308 | 289 | """ |
|
309 | 290 | |
|
310 | 291 | def _log_default(self): |
|
311 |
from IPython. |
|
|
312 | if Application.initialized(): | |
|
313 | return Application.instance().log | |
|
314 | else: | |
|
315 | return logging.getLogger() | |
|
292 | from IPython.utils.log import get_logger | |
|
293 | return get_logger() | |
|
316 | 294 | |
|
317 | 295 | def __init__(self, log=None): |
|
318 | 296 | """A base class for config loaders. |
|
319 | 297 | |
|
320 | 298 | log : instance of :class:`logging.Logger` to use. |
|
321 | 299 | By default loger of :meth:`IPython.config.application.Application.instance()` |
|
322 | 300 | will be used |
|
323 | 301 | |
|
324 | 302 | Examples |
|
325 | 303 | -------- |
|
326 | 304 | |
|
327 | 305 | >>> cl = ConfigLoader() |
|
328 | 306 | >>> config = cl.load_config() |
|
329 | 307 | >>> config |
|
330 | 308 | {} |
|
331 | 309 | """ |
|
332 | 310 | self.clear() |
|
333 | 311 | if log is None: |
|
334 | 312 | self.log = self._log_default() |
|
335 | 313 | self.log.debug('Using default logger') |
|
336 | 314 | else: |
|
337 | 315 | self.log = log |
|
338 | 316 | |
|
339 | 317 | def clear(self): |
|
340 | 318 | self.config = Config() |
|
341 | 319 | |
|
342 | 320 | def load_config(self): |
|
343 | 321 | """Load a config from somewhere, return a :class:`Config` instance. |
|
344 | 322 | |
|
345 | 323 | Usually, this will cause self.config to be set and then returned. |
|
346 | 324 | However, in most cases, :meth:`ConfigLoader.clear` should be called |
|
347 | 325 | to erase any previous state. |
|
348 | 326 | """ |
|
349 | 327 | self.clear() |
|
350 | 328 | return self.config |
|
351 | 329 | |
|
352 | 330 | |
|
353 | 331 | class FileConfigLoader(ConfigLoader): |
|
354 | 332 | """A base class for file based configurations. |
|
355 | 333 | |
|
356 | 334 | As we add more file based config loaders, the common logic should go |
|
357 | 335 | here. |
|
358 | 336 | """ |
|
359 | 337 | |
|
360 | 338 | def __init__(self, filename, path=None, **kw): |
|
361 | 339 | """Build a config loader for a filename and path. |
|
362 | 340 | |
|
363 | 341 | Parameters |
|
364 | 342 | ---------- |
|
365 | 343 | filename : str |
|
366 | 344 | The file name of the config file. |
|
367 | 345 | path : str, list, tuple |
|
368 | 346 | The path to search for the config file on, or a sequence of |
|
369 | 347 | paths to try in order. |
|
370 | 348 | """ |
|
371 | 349 | super(FileConfigLoader, self).__init__(**kw) |
|
372 | 350 | self.filename = filename |
|
373 | 351 | self.path = path |
|
374 | 352 | self.full_filename = '' |
|
375 | 353 | |
|
376 | 354 | def _find_file(self): |
|
377 | 355 | """Try to find the file by searching the paths.""" |
|
378 | 356 | self.full_filename = filefind(self.filename, self.path) |
|
379 | 357 | |
|
380 | 358 | class JSONFileConfigLoader(FileConfigLoader): |
|
381 | 359 | """A Json file loader for config""" |
|
382 | 360 | |
|
383 | 361 | def load_config(self): |
|
384 | 362 | """Load the config from a file and return it as a Config object.""" |
|
385 | 363 | self.clear() |
|
386 | 364 | try: |
|
387 | 365 | self._find_file() |
|
388 | 366 | except IOError as e: |
|
389 | 367 | raise ConfigFileNotFound(str(e)) |
|
390 | 368 | dct = self._read_file_as_dict() |
|
391 | 369 | self.config = self._convert_to_config(dct) |
|
392 | 370 | return self.config |
|
393 | 371 | |
|
394 | 372 | def _read_file_as_dict(self): |
|
395 | 373 | with open(self.full_filename) as f: |
|
396 | 374 | return json.load(f) |
|
397 | 375 | |
|
398 | 376 | def _convert_to_config(self, dictionary): |
|
399 | 377 | if 'version' in dictionary: |
|
400 | 378 | version = dictionary.pop('version') |
|
401 | 379 | else: |
|
402 | 380 | version = 1 |
|
403 | 381 | self.log.warn("Unrecognized JSON config file version, assuming version {}".format(version)) |
|
404 | 382 | |
|
405 | 383 | if version == 1: |
|
406 | 384 | return Config(dictionary) |
|
407 | 385 | else: |
|
408 | 386 | raise ValueError('Unknown version of JSON config file: {version}'.format(version=version)) |
|
409 | 387 | |
|
410 | 388 | |
|
411 | 389 | class PyFileConfigLoader(FileConfigLoader): |
|
412 | 390 | """A config loader for pure python files. |
|
413 | 391 | |
|
414 | 392 | This is responsible for locating a Python config file by filename and |
|
415 | 393 | path, then executing it to construct a Config object. |
|
416 | 394 | """ |
|
417 | 395 | |
|
418 | 396 | def load_config(self): |
|
419 | 397 | """Load the config from a file and return it as a Config object.""" |
|
420 | 398 | self.clear() |
|
421 | 399 | try: |
|
422 | 400 | self._find_file() |
|
423 | 401 | except IOError as e: |
|
424 | 402 | raise ConfigFileNotFound(str(e)) |
|
425 | 403 | self._read_file_as_dict() |
|
426 | 404 | return self.config |
|
427 | 405 | |
|
428 | 406 | |
|
429 | 407 | def _read_file_as_dict(self): |
|
430 | 408 | """Load the config file into self.config, with recursive loading.""" |
|
431 | 409 | # This closure is made available in the namespace that is used |
|
432 | 410 | # to exec the config file. It allows users to call |
|
433 | 411 | # load_subconfig('myconfig.py') to load config files recursively. |
|
434 | 412 | # It needs to be a closure because it has references to self.path |
|
435 | 413 | # and self.config. The sub-config is loaded with the same path |
|
436 | 414 | # as the parent, but it uses an empty config which is then merged |
|
437 | 415 | # with the parents. |
|
438 | 416 | |
|
439 | 417 | # If a profile is specified, the config file will be loaded |
|
440 | 418 | # from that profile |
|
441 | 419 | |
|
442 | 420 | def load_subconfig(fname, profile=None): |
|
443 | 421 | # import here to prevent circular imports |
|
444 | 422 | from IPython.core.profiledir import ProfileDir, ProfileDirError |
|
445 | 423 | if profile is not None: |
|
446 | 424 | try: |
|
447 | 425 | profile_dir = ProfileDir.find_profile_dir_by_name( |
|
448 | 426 | get_ipython_dir(), |
|
449 | 427 | profile, |
|
450 | 428 | ) |
|
451 | 429 | except ProfileDirError: |
|
452 | 430 | return |
|
453 | 431 | path = profile_dir.location |
|
454 | 432 | else: |
|
455 | 433 | path = self.path |
|
456 | 434 | loader = PyFileConfigLoader(fname, path) |
|
457 | 435 | try: |
|
458 | 436 | sub_config = loader.load_config() |
|
459 | 437 | except ConfigFileNotFound: |
|
460 | 438 | # Pass silently if the sub config is not there. This happens |
|
461 | 439 | # when a user s using a profile, but not the default config. |
|
462 | 440 | pass |
|
463 | 441 | else: |
|
464 | 442 | self.config.merge(sub_config) |
|
465 | 443 | |
|
466 | 444 | # Again, this needs to be a closure and should be used in config |
|
467 | 445 | # files to get the config being loaded. |
|
468 | 446 | def get_config(): |
|
469 | 447 | return self.config |
|
470 | 448 | |
|
471 | 449 | namespace = dict( |
|
472 | 450 | load_subconfig=load_subconfig, |
|
473 | 451 | get_config=get_config, |
|
474 | 452 | __file__=self.full_filename, |
|
475 | 453 | ) |
|
476 | 454 | fs_encoding = sys.getfilesystemencoding() or 'ascii' |
|
477 | 455 | conf_filename = self.full_filename.encode(fs_encoding) |
|
478 | 456 | py3compat.execfile(conf_filename, namespace) |
|
479 | 457 | |
|
480 | 458 | |
|
481 | 459 | class CommandLineConfigLoader(ConfigLoader): |
|
482 | 460 | """A config loader for command line arguments. |
|
483 | 461 | |
|
484 | 462 | As we add more command line based loaders, the common logic should go |
|
485 | 463 | here. |
|
486 | 464 | """ |
|
487 | 465 | |
|
488 | 466 | def _exec_config_str(self, lhs, rhs): |
|
489 | 467 | """execute self.config.<lhs> = <rhs> |
|
490 | 468 | |
|
491 | 469 | * expands ~ with expanduser |
|
492 | 470 | * tries to assign with raw eval, otherwise assigns with just the string, |
|
493 | 471 | allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not* |
|
494 | 472 | equivalent are `--C.a=4` and `--C.a='4'`. |
|
495 | 473 | """ |
|
496 | 474 | rhs = os.path.expanduser(rhs) |
|
497 | 475 | try: |
|
498 | 476 | # Try to see if regular Python syntax will work. This |
|
499 | 477 | # won't handle strings as the quote marks are removed |
|
500 | 478 | # by the system shell. |
|
501 | 479 | value = eval(rhs) |
|
502 | 480 | except (NameError, SyntaxError): |
|
503 | 481 | # This case happens if the rhs is a string. |
|
504 | 482 | value = rhs |
|
505 | 483 | |
|
506 | 484 | exec(u'self.config.%s = value' % lhs) |
|
507 | 485 | |
|
508 | 486 | def _load_flag(self, cfg): |
|
509 | 487 | """update self.config from a flag, which can be a dict or Config""" |
|
510 | 488 | if isinstance(cfg, (dict, Config)): |
|
511 | 489 | # don't clobber whole config sections, update |
|
512 | 490 | # each section from config: |
|
513 | 491 | for sec,c in iteritems(cfg): |
|
514 | 492 | self.config[sec].update(c) |
|
515 | 493 | else: |
|
516 | 494 | raise TypeError("Invalid flag: %r" % cfg) |
|
517 | 495 | |
|
518 | 496 | # raw --identifier=value pattern |
|
519 | 497 | # but *also* accept '-' as wordsep, for aliases |
|
520 | 498 | # accepts: --foo=a |
|
521 | 499 | # --Class.trait=value |
|
522 | 500 | # --alias-name=value |
|
523 | 501 | # rejects: -foo=value |
|
524 | 502 | # --foo |
|
525 | 503 | # --Class.trait |
|
526 | 504 | kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*') |
|
527 | 505 | |
|
528 | 506 | # just flags, no assignments, with two *or one* leading '-' |
|
529 | 507 | # accepts: --foo |
|
530 | 508 | # -foo-bar-again |
|
531 | 509 | # rejects: --anything=anything |
|
532 | 510 | # --two.word |
|
533 | 511 | |
|
534 | 512 | flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$') |
|
535 | 513 | |
|
536 | 514 | class KeyValueConfigLoader(CommandLineConfigLoader): |
|
537 | 515 | """A config loader that loads key value pairs from the command line. |
|
538 | 516 | |
|
539 | 517 | This allows command line options to be gives in the following form:: |
|
540 | 518 | |
|
541 | 519 | ipython --profile="foo" --InteractiveShell.autocall=False |
|
542 | 520 | """ |
|
543 | 521 | |
|
544 | 522 | def __init__(self, argv=None, aliases=None, flags=None, **kw): |
|
545 | 523 | """Create a key value pair config loader. |
|
546 | 524 | |
|
547 | 525 | Parameters |
|
548 | 526 | ---------- |
|
549 | 527 | argv : list |
|
550 | 528 | A list that has the form of sys.argv[1:] which has unicode |
|
551 | 529 | elements of the form u"key=value". If this is None (default), |
|
552 | 530 | then sys.argv[1:] will be used. |
|
553 | 531 | aliases : dict |
|
554 | 532 | A dict of aliases for configurable traits. |
|
555 | 533 | Keys are the short aliases, Values are the resolved trait. |
|
556 | 534 | Of the form: `{'alias' : 'Configurable.trait'}` |
|
557 | 535 | flags : dict |
|
558 | 536 | A dict of flags, keyed by str name. Vaues can be Config objects, |
|
559 | 537 | dicts, or "key=value" strings. If Config or dict, when the flag |
|
560 | 538 | is triggered, The flag is loaded as `self.config.update(m)`. |
|
561 | 539 | |
|
562 | 540 | Returns |
|
563 | 541 | ------- |
|
564 | 542 | config : Config |
|
565 | 543 | The resulting Config object. |
|
566 | 544 | |
|
567 | 545 | Examples |
|
568 | 546 | -------- |
|
569 | 547 | |
|
570 | 548 | >>> from IPython.config.loader import KeyValueConfigLoader |
|
571 | 549 | >>> cl = KeyValueConfigLoader() |
|
572 | 550 | >>> d = cl.load_config(["--A.name='brian'","--B.number=0"]) |
|
573 | 551 | >>> sorted(d.items()) |
|
574 | 552 | [('A', {'name': 'brian'}), ('B', {'number': 0})] |
|
575 | 553 | """ |
|
576 | 554 | super(KeyValueConfigLoader, self).__init__(**kw) |
|
577 | 555 | if argv is None: |
|
578 | 556 | argv = sys.argv[1:] |
|
579 | 557 | self.argv = argv |
|
580 | 558 | self.aliases = aliases or {} |
|
581 | 559 | self.flags = flags or {} |
|
582 | 560 | |
|
583 | 561 | |
|
584 | 562 | def clear(self): |
|
585 | 563 | super(KeyValueConfigLoader, self).clear() |
|
586 | 564 | self.extra_args = [] |
|
587 | 565 | |
|
588 | 566 | |
|
589 | 567 | def _decode_argv(self, argv, enc=None): |
|
590 | 568 | """decode argv if bytes, using stin.encoding, falling back on default enc""" |
|
591 | 569 | uargv = [] |
|
592 | 570 | if enc is None: |
|
593 | 571 | enc = DEFAULT_ENCODING |
|
594 | 572 | for arg in argv: |
|
595 | 573 | if not isinstance(arg, unicode_type): |
|
596 | 574 | # only decode if not already decoded |
|
597 | 575 | arg = arg.decode(enc) |
|
598 | 576 | uargv.append(arg) |
|
599 | 577 | return uargv |
|
600 | 578 | |
|
601 | 579 | |
|
602 | 580 | def load_config(self, argv=None, aliases=None, flags=None): |
|
603 | 581 | """Parse the configuration and generate the Config object. |
|
604 | 582 | |
|
605 | 583 | After loading, any arguments that are not key-value or |
|
606 | 584 | flags will be stored in self.extra_args - a list of |
|
607 | 585 | unparsed command-line arguments. This is used for |
|
608 | 586 | arguments such as input files or subcommands. |
|
609 | 587 | |
|
610 | 588 | Parameters |
|
611 | 589 | ---------- |
|
612 | 590 | argv : list, optional |
|
613 | 591 | A list that has the form of sys.argv[1:] which has unicode |
|
614 | 592 | elements of the form u"key=value". If this is None (default), |
|
615 | 593 | then self.argv will be used. |
|
616 | 594 | aliases : dict |
|
617 | 595 | A dict of aliases for configurable traits. |
|
618 | 596 | Keys are the short aliases, Values are the resolved trait. |
|
619 | 597 | Of the form: `{'alias' : 'Configurable.trait'}` |
|
620 | 598 | flags : dict |
|
621 | 599 | A dict of flags, keyed by str name. Values can be Config objects |
|
622 | 600 | or dicts. When the flag is triggered, The config is loaded as |
|
623 | 601 | `self.config.update(cfg)`. |
|
624 | 602 | """ |
|
625 | 603 | self.clear() |
|
626 | 604 | if argv is None: |
|
627 | 605 | argv = self.argv |
|
628 | 606 | if aliases is None: |
|
629 | 607 | aliases = self.aliases |
|
630 | 608 | if flags is None: |
|
631 | 609 | flags = self.flags |
|
632 | 610 | |
|
633 | 611 | # ensure argv is a list of unicode strings: |
|
634 | 612 | uargv = self._decode_argv(argv) |
|
635 | 613 | for idx,raw in enumerate(uargv): |
|
636 | 614 | # strip leading '-' |
|
637 | 615 | item = raw.lstrip('-') |
|
638 | 616 | |
|
639 | 617 | if raw == '--': |
|
640 | 618 | # don't parse arguments after '--' |
|
641 | 619 | # this is useful for relaying arguments to scripts, e.g. |
|
642 | 620 | # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py |
|
643 | 621 | self.extra_args.extend(uargv[idx+1:]) |
|
644 | 622 | break |
|
645 | 623 | |
|
646 | 624 | if kv_pattern.match(raw): |
|
647 | 625 | lhs,rhs = item.split('=',1) |
|
648 | 626 | # Substitute longnames for aliases. |
|
649 | 627 | if lhs in aliases: |
|
650 | 628 | lhs = aliases[lhs] |
|
651 | 629 | if '.' not in lhs: |
|
652 | 630 | # probably a mistyped alias, but not technically illegal |
|
653 | 631 | self.log.warn("Unrecognized alias: '%s', it will probably have no effect.", raw) |
|
654 | 632 | try: |
|
655 | 633 | self._exec_config_str(lhs, rhs) |
|
656 | 634 | except Exception: |
|
657 | 635 | raise ArgumentError("Invalid argument: '%s'" % raw) |
|
658 | 636 | |
|
659 | 637 | elif flag_pattern.match(raw): |
|
660 | 638 | if item in flags: |
|
661 | 639 | cfg,help = flags[item] |
|
662 | 640 | self._load_flag(cfg) |
|
663 | 641 | else: |
|
664 | 642 | raise ArgumentError("Unrecognized flag: '%s'"%raw) |
|
665 | 643 | elif raw.startswith('-'): |
|
666 | 644 | kv = '--'+item |
|
667 | 645 | if kv_pattern.match(kv): |
|
668 | 646 | raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv)) |
|
669 | 647 | else: |
|
670 | 648 | raise ArgumentError("Invalid argument: '%s'"%raw) |
|
671 | 649 | else: |
|
672 | 650 | # keep all args that aren't valid in a list, |
|
673 | 651 | # in case our parent knows what to do with them. |
|
674 | 652 | self.extra_args.append(item) |
|
675 | 653 | return self.config |
|
676 | 654 | |
|
677 | 655 | class ArgParseConfigLoader(CommandLineConfigLoader): |
|
678 | 656 | """A loader that uses the argparse module to load from the command line.""" |
|
679 | 657 | |
|
680 | 658 | def __init__(self, argv=None, aliases=None, flags=None, log=None, *parser_args, **parser_kw): |
|
681 | 659 | """Create a config loader for use with argparse. |
|
682 | 660 | |
|
683 | 661 | Parameters |
|
684 | 662 | ---------- |
|
685 | 663 | |
|
686 | 664 | argv : optional, list |
|
687 | 665 | If given, used to read command-line arguments from, otherwise |
|
688 | 666 | sys.argv[1:] is used. |
|
689 | 667 | |
|
690 | 668 | parser_args : tuple |
|
691 | 669 | A tuple of positional arguments that will be passed to the |
|
692 | 670 | constructor of :class:`argparse.ArgumentParser`. |
|
693 | 671 | |
|
694 | 672 | parser_kw : dict |
|
695 | 673 | A tuple of keyword arguments that will be passed to the |
|
696 | 674 | constructor of :class:`argparse.ArgumentParser`. |
|
697 | 675 | |
|
698 | 676 | Returns |
|
699 | 677 | ------- |
|
700 | 678 | config : Config |
|
701 | 679 | The resulting Config object. |
|
702 | 680 | """ |
|
703 | 681 | super(CommandLineConfigLoader, self).__init__(log=log) |
|
704 | 682 | self.clear() |
|
705 | 683 | if argv is None: |
|
706 | 684 | argv = sys.argv[1:] |
|
707 | 685 | self.argv = argv |
|
708 | 686 | self.aliases = aliases or {} |
|
709 | 687 | self.flags = flags or {} |
|
710 | 688 | |
|
711 | 689 | self.parser_args = parser_args |
|
712 | 690 | self.version = parser_kw.pop("version", None) |
|
713 | 691 | kwargs = dict(argument_default=argparse.SUPPRESS) |
|
714 | 692 | kwargs.update(parser_kw) |
|
715 | 693 | self.parser_kw = kwargs |
|
716 | 694 | |
|
717 | 695 | def load_config(self, argv=None, aliases=None, flags=None): |
|
718 | 696 | """Parse command line arguments and return as a Config object. |
|
719 | 697 | |
|
720 | 698 | Parameters |
|
721 | 699 | ---------- |
|
722 | 700 | |
|
723 | 701 | args : optional, list |
|
724 | 702 | If given, a list with the structure of sys.argv[1:] to parse |
|
725 | 703 | arguments from. If not given, the instance's self.argv attribute |
|
726 | 704 | (given at construction time) is used.""" |
|
727 | 705 | self.clear() |
|
728 | 706 | if argv is None: |
|
729 | 707 | argv = self.argv |
|
730 | 708 | if aliases is None: |
|
731 | 709 | aliases = self.aliases |
|
732 | 710 | if flags is None: |
|
733 | 711 | flags = self.flags |
|
734 | 712 | self._create_parser(aliases, flags) |
|
735 | 713 | self._parse_args(argv) |
|
736 | 714 | self._convert_to_config() |
|
737 | 715 | return self.config |
|
738 | 716 | |
|
739 | 717 | def get_extra_args(self): |
|
740 | 718 | if hasattr(self, 'extra_args'): |
|
741 | 719 | return self.extra_args |
|
742 | 720 | else: |
|
743 | 721 | return [] |
|
744 | 722 | |
|
745 | 723 | def _create_parser(self, aliases=None, flags=None): |
|
746 | 724 | self.parser = ArgumentParser(*self.parser_args, **self.parser_kw) |
|
747 | 725 | self._add_arguments(aliases, flags) |
|
748 | 726 | |
|
749 | 727 | def _add_arguments(self, aliases=None, flags=None): |
|
750 | 728 | raise NotImplementedError("subclasses must implement _add_arguments") |
|
751 | 729 | |
|
752 | 730 | def _parse_args(self, args): |
|
753 | 731 | """self.parser->self.parsed_data""" |
|
754 | 732 | # decode sys.argv to support unicode command-line options |
|
755 | 733 | enc = DEFAULT_ENCODING |
|
756 | 734 | uargs = [py3compat.cast_unicode(a, enc) for a in args] |
|
757 | 735 | self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs) |
|
758 | 736 | |
|
759 | 737 | def _convert_to_config(self): |
|
760 | 738 | """self.parsed_data->self.config""" |
|
761 | 739 | for k, v in iteritems(vars(self.parsed_data)): |
|
762 | 740 | exec("self.config.%s = v"%k, locals(), globals()) |
|
763 | 741 | |
|
764 | 742 | class KVArgParseConfigLoader(ArgParseConfigLoader): |
|
765 | 743 | """A config loader that loads aliases and flags with argparse, |
|
766 | 744 | but will use KVLoader for the rest. This allows better parsing |
|
767 | 745 | of common args, such as `ipython -c 'print 5'`, but still gets |
|
768 | 746 | arbitrary config with `ipython --InteractiveShell.use_readline=False`""" |
|
769 | 747 | |
|
770 | 748 | def _add_arguments(self, aliases=None, flags=None): |
|
771 | 749 | self.alias_flags = {} |
|
772 | 750 | # print aliases, flags |
|
773 | 751 | if aliases is None: |
|
774 | 752 | aliases = self.aliases |
|
775 | 753 | if flags is None: |
|
776 | 754 | flags = self.flags |
|
777 | 755 | paa = self.parser.add_argument |
|
778 | 756 | for key,value in iteritems(aliases): |
|
779 | 757 | if key in flags: |
|
780 | 758 | # flags |
|
781 | 759 | nargs = '?' |
|
782 | 760 | else: |
|
783 | 761 | nargs = None |
|
784 | 762 | if len(key) is 1: |
|
785 | 763 | paa('-'+key, '--'+key, type=unicode_type, dest=value, nargs=nargs) |
|
786 | 764 | else: |
|
787 | 765 | paa('--'+key, type=unicode_type, dest=value, nargs=nargs) |
|
788 | 766 | for key, (value, help) in iteritems(flags): |
|
789 | 767 | if key in self.aliases: |
|
790 | 768 | # |
|
791 | 769 | self.alias_flags[self.aliases[key]] = value |
|
792 | 770 | continue |
|
793 | 771 | if len(key) is 1: |
|
794 | 772 | paa('-'+key, '--'+key, action='append_const', dest='_flags', const=value) |
|
795 | 773 | else: |
|
796 | 774 | paa('--'+key, action='append_const', dest='_flags', const=value) |
|
797 | 775 | |
|
798 | 776 | def _convert_to_config(self): |
|
799 | 777 | """self.parsed_data->self.config, parse unrecognized extra args via KVLoader.""" |
|
800 | 778 | # remove subconfigs list from namespace before transforming the Namespace |
|
801 | 779 | if '_flags' in self.parsed_data: |
|
802 | 780 | subcs = self.parsed_data._flags |
|
803 | 781 | del self.parsed_data._flags |
|
804 | 782 | else: |
|
805 | 783 | subcs = [] |
|
806 | 784 | |
|
807 | 785 | for k, v in iteritems(vars(self.parsed_data)): |
|
808 | 786 | if v is None: |
|
809 | 787 | # it was a flag that shares the name of an alias |
|
810 | 788 | subcs.append(self.alias_flags[k]) |
|
811 | 789 | else: |
|
812 | 790 | # eval the KV assignment |
|
813 | 791 | self._exec_config_str(k, v) |
|
814 | 792 | |
|
815 | 793 | for subc in subcs: |
|
816 | 794 | self._load_flag(subc) |
|
817 | 795 | |
|
818 | 796 | if self.extra_args: |
|
819 | 797 | sub_parser = KeyValueConfigLoader(log=self.log) |
|
820 | 798 | sub_parser.load_config(self.extra_args) |
|
821 | 799 | self.config.merge(sub_parser.config) |
|
822 | 800 | self.extra_args = sub_parser.extra_args |
|
823 | 801 | |
|
824 | 802 | |
|
825 | 803 | def load_pyconfig_files(config_files, path): |
|
826 | 804 | """Load multiple Python config files, merging each of them in turn. |
|
827 | 805 | |
|
828 | 806 | Parameters |
|
829 | 807 | ========== |
|
830 | 808 | config_files : list of str |
|
831 | 809 | List of config files names to load and merge into the config. |
|
832 | 810 | path : unicode |
|
833 | 811 | The full path to the location of the config files. |
|
834 | 812 | """ |
|
835 | 813 | config = Config() |
|
836 | 814 | for cf in config_files: |
|
837 | 815 | loader = PyFileConfigLoader(cf, path=path) |
|
838 | 816 | try: |
|
839 | 817 | next_config = loader.load_config() |
|
840 | 818 | except ConfigFileNotFound: |
|
841 | 819 | pass |
|
842 | 820 | except: |
|
843 | 821 | raise |
|
844 | 822 | else: |
|
845 | 823 | config.merge(next_config) |
|
846 | 824 | return config |
@@ -1,390 +1,376 b'' | |||
|
1 | """Base Tornado handlers for the notebook. | |
|
2 | ||
|
3 | Authors: | |
|
4 | ||
|
5 | * Brian Granger | |
|
6 | """ | |
|
7 | ||
|
8 | #----------------------------------------------------------------------------- | |
|
9 | # Copyright (C) 2011 The IPython Development Team | |
|
10 | # | |
|
11 | # Distributed under the terms of the BSD License. The full license is in | |
|
12 | # the file COPYING, distributed as part of this software. | |
|
13 | #----------------------------------------------------------------------------- | |
|
14 | ||
|
15 | #----------------------------------------------------------------------------- | |
|
16 | # Imports | |
|
17 | #----------------------------------------------------------------------------- | |
|
1 | """Base Tornado handlers for the notebook.""" | |
|
18 | 2 | |
|
3 | # Copyright (c) IPython Development Team. | |
|
4 | # Distributed under the terms of the Modified BSD License. | |
|
19 | 5 | |
|
20 | 6 | import functools |
|
21 | 7 | import json |
|
22 | 8 | import logging |
|
23 | 9 | import os |
|
24 | 10 | import re |
|
25 | 11 | import sys |
|
26 | 12 | import traceback |
|
27 | 13 | try: |
|
28 | 14 | # py3 |
|
29 | 15 | from http.client import responses |
|
30 | 16 | except ImportError: |
|
31 | 17 | from httplib import responses |
|
32 | 18 | |
|
33 | 19 | from jinja2 import TemplateNotFound |
|
34 | 20 | from tornado import web |
|
35 | 21 | |
|
36 | 22 | try: |
|
37 | 23 | from tornado.log import app_log |
|
38 | 24 | except ImportError: |
|
39 | 25 | app_log = logging.getLogger() |
|
40 | 26 | |
|
41 | 27 | from IPython.config import Application |
|
42 | 28 | from IPython.utils.path import filefind |
|
43 | 29 | from IPython.utils.py3compat import string_types |
|
44 | 30 | from IPython.html.utils import is_hidden |
|
45 | 31 | |
|
46 | 32 | #----------------------------------------------------------------------------- |
|
47 | 33 | # Top-level handlers |
|
48 | 34 | #----------------------------------------------------------------------------- |
|
49 | 35 | non_alphanum = re.compile(r'[^A-Za-z0-9]') |
|
50 | 36 | |
|
51 | 37 | class AuthenticatedHandler(web.RequestHandler): |
|
52 | 38 | """A RequestHandler with an authenticated user.""" |
|
53 | 39 | |
|
54 | 40 | def set_default_headers(self): |
|
55 | 41 | headers = self.settings.get('headers', {}) |
|
56 | 42 | for header_name,value in headers.items() : |
|
57 | 43 | try: |
|
58 | 44 | self.set_header(header_name, value) |
|
59 | 45 | except Exception: |
|
60 | 46 | # tornado raise Exception (not a subclass) |
|
61 | 47 | # if method is unsupported (websocket and Access-Control-Allow-Origin |
|
62 | 48 | # for example, so just ignore) |
|
63 | 49 | pass |
|
64 | 50 | |
|
65 | 51 | def clear_login_cookie(self): |
|
66 | 52 | self.clear_cookie(self.cookie_name) |
|
67 | 53 | |
|
68 | 54 | def get_current_user(self): |
|
69 | 55 | user_id = self.get_secure_cookie(self.cookie_name) |
|
70 | 56 | # For now the user_id should not return empty, but it could eventually |
|
71 | 57 | if user_id == '': |
|
72 | 58 | user_id = 'anonymous' |
|
73 | 59 | if user_id is None: |
|
74 | 60 | # prevent extra Invalid cookie sig warnings: |
|
75 | 61 | self.clear_login_cookie() |
|
76 | 62 | if not self.login_available: |
|
77 | 63 | user_id = 'anonymous' |
|
78 | 64 | return user_id |
|
79 | 65 | |
|
80 | 66 | @property |
|
81 | 67 | def cookie_name(self): |
|
82 | 68 | default_cookie_name = non_alphanum.sub('-', 'username-{}'.format( |
|
83 | 69 | self.request.host |
|
84 | 70 | )) |
|
85 | 71 | return self.settings.get('cookie_name', default_cookie_name) |
|
86 | 72 | |
|
87 | 73 | @property |
|
88 | 74 | def password(self): |
|
89 | 75 | """our password""" |
|
90 | 76 | return self.settings.get('password', '') |
|
91 | 77 | |
|
92 | 78 | @property |
|
93 | 79 | def logged_in(self): |
|
94 | 80 | """Is a user currently logged in? |
|
95 | 81 | |
|
96 | 82 | """ |
|
97 | 83 | user = self.get_current_user() |
|
98 | 84 | return (user and not user == 'anonymous') |
|
99 | 85 | |
|
100 | 86 | @property |
|
101 | 87 | def login_available(self): |
|
102 | 88 | """May a user proceed to log in? |
|
103 | 89 | |
|
104 | 90 | This returns True if login capability is available, irrespective of |
|
105 | 91 | whether the user is already logged in or not. |
|
106 | 92 | |
|
107 | 93 | """ |
|
108 | 94 | return bool(self.settings.get('password', '')) |
|
109 | 95 | |
|
110 | 96 | |
|
111 | 97 | class IPythonHandler(AuthenticatedHandler): |
|
112 | 98 | """IPython-specific extensions to authenticated handling |
|
113 | 99 | |
|
114 | 100 | Mostly property shortcuts to IPython-specific settings. |
|
115 | 101 | """ |
|
116 | 102 | |
|
117 | 103 | @property |
|
118 | 104 | def config(self): |
|
119 | 105 | return self.settings.get('config', None) |
|
120 | 106 | |
|
121 | 107 | @property |
|
122 | 108 | def log(self): |
|
123 | 109 | """use the IPython log by default, falling back on tornado's logger""" |
|
124 | 110 | if Application.initialized(): |
|
125 | 111 | return Application.instance().log |
|
126 | 112 | else: |
|
127 | 113 | return app_log |
|
128 | 114 | |
|
129 | 115 | #--------------------------------------------------------------- |
|
130 | 116 | # URLs |
|
131 | 117 | #--------------------------------------------------------------- |
|
132 | 118 | |
|
133 | 119 | @property |
|
134 | 120 | def mathjax_url(self): |
|
135 | 121 | return self.settings.get('mathjax_url', '') |
|
136 | 122 | |
|
137 | 123 | @property |
|
138 | 124 | def base_url(self): |
|
139 | 125 | return self.settings.get('base_url', '/') |
|
140 | 126 | |
|
141 | 127 | #--------------------------------------------------------------- |
|
142 | 128 | # Manager objects |
|
143 | 129 | #--------------------------------------------------------------- |
|
144 | 130 | |
|
145 | 131 | @property |
|
146 | 132 | def kernel_manager(self): |
|
147 | 133 | return self.settings['kernel_manager'] |
|
148 | 134 | |
|
149 | 135 | @property |
|
150 | 136 | def notebook_manager(self): |
|
151 | 137 | return self.settings['notebook_manager'] |
|
152 | 138 | |
|
153 | 139 | @property |
|
154 | 140 | def cluster_manager(self): |
|
155 | 141 | return self.settings['cluster_manager'] |
|
156 | 142 | |
|
157 | 143 | @property |
|
158 | 144 | def session_manager(self): |
|
159 | 145 | return self.settings['session_manager'] |
|
160 | 146 | |
|
161 | 147 | @property |
|
162 | 148 | def kernel_spec_manager(self): |
|
163 | 149 | return self.settings['kernel_spec_manager'] |
|
164 | 150 | |
|
165 | 151 | @property |
|
166 | 152 | def project_dir(self): |
|
167 | 153 | return self.notebook_manager.notebook_dir |
|
168 | 154 | |
|
169 | 155 | #--------------------------------------------------------------- |
|
170 | 156 | # template rendering |
|
171 | 157 | #--------------------------------------------------------------- |
|
172 | 158 | |
|
173 | 159 | def get_template(self, name): |
|
174 | 160 | """Return the jinja template object for a given name""" |
|
175 | 161 | return self.settings['jinja2_env'].get_template(name) |
|
176 | 162 | |
|
177 | 163 | def render_template(self, name, **ns): |
|
178 | 164 | ns.update(self.template_namespace) |
|
179 | 165 | template = self.get_template(name) |
|
180 | 166 | return template.render(**ns) |
|
181 | 167 | |
|
182 | 168 | @property |
|
183 | 169 | def template_namespace(self): |
|
184 | 170 | return dict( |
|
185 | 171 | base_url=self.base_url, |
|
186 | 172 | logged_in=self.logged_in, |
|
187 | 173 | login_available=self.login_available, |
|
188 | 174 | static_url=self.static_url, |
|
189 | 175 | ) |
|
190 | 176 | |
|
191 | 177 | def get_json_body(self): |
|
192 | 178 | """Return the body of the request as JSON data.""" |
|
193 | 179 | if not self.request.body: |
|
194 | 180 | return None |
|
195 | 181 | # Do we need to call body.decode('utf-8') here? |
|
196 | 182 | body = self.request.body.strip().decode(u'utf-8') |
|
197 | 183 | try: |
|
198 | 184 | model = json.loads(body) |
|
199 | 185 | except Exception: |
|
200 | 186 | self.log.debug("Bad JSON: %r", body) |
|
201 | 187 | self.log.error("Couldn't parse JSON", exc_info=True) |
|
202 | 188 | raise web.HTTPError(400, u'Invalid JSON in body of request') |
|
203 | 189 | return model |
|
204 | 190 | |
|
205 | 191 | def get_error_html(self, status_code, **kwargs): |
|
206 | 192 | """render custom error pages""" |
|
207 | 193 | exception = kwargs.get('exception') |
|
208 | 194 | message = '' |
|
209 | 195 | status_message = responses.get(status_code, 'Unknown HTTP Error') |
|
210 | 196 | if exception: |
|
211 | 197 | # get the custom message, if defined |
|
212 | 198 | try: |
|
213 | 199 | message = exception.log_message % exception.args |
|
214 | 200 | except Exception: |
|
215 | 201 | pass |
|
216 | 202 | |
|
217 | 203 | # construct the custom reason, if defined |
|
218 | 204 | reason = getattr(exception, 'reason', '') |
|
219 | 205 | if reason: |
|
220 | 206 | status_message = reason |
|
221 | 207 | |
|
222 | 208 | # build template namespace |
|
223 | 209 | ns = dict( |
|
224 | 210 | status_code=status_code, |
|
225 | 211 | status_message=status_message, |
|
226 | 212 | message=message, |
|
227 | 213 | exception=exception, |
|
228 | 214 | ) |
|
229 | 215 | |
|
230 | 216 | # render the template |
|
231 | 217 | try: |
|
232 | 218 | html = self.render_template('%s.html' % status_code, **ns) |
|
233 | 219 | except TemplateNotFound: |
|
234 | 220 | self.log.debug("No template for %d", status_code) |
|
235 | 221 | html = self.render_template('error.html', **ns) |
|
236 | 222 | return html |
|
237 | 223 | |
|
238 | 224 | |
|
239 | 225 | class Template404(IPythonHandler): |
|
240 | 226 | """Render our 404 template""" |
|
241 | 227 | def prepare(self): |
|
242 | 228 | raise web.HTTPError(404) |
|
243 | 229 | |
|
244 | 230 | |
|
245 | 231 | class AuthenticatedFileHandler(IPythonHandler, web.StaticFileHandler): |
|
246 | 232 | """static files should only be accessible when logged in""" |
|
247 | 233 | |
|
248 | 234 | @web.authenticated |
|
249 | 235 | def get(self, path): |
|
250 | 236 | if os.path.splitext(path)[1] == '.ipynb': |
|
251 | 237 | name = os.path.basename(path) |
|
252 | 238 | self.set_header('Content-Type', 'application/json') |
|
253 | 239 | self.set_header('Content-Disposition','attachment; filename="%s"' % name) |
|
254 | 240 | |
|
255 | 241 | return web.StaticFileHandler.get(self, path) |
|
256 | 242 | |
|
257 | 243 | def compute_etag(self): |
|
258 | 244 | return None |
|
259 | 245 | |
|
260 | 246 | def validate_absolute_path(self, root, absolute_path): |
|
261 | 247 | """Validate and return the absolute path. |
|
262 | 248 | |
|
263 | 249 | Requires tornado 3.1 |
|
264 | 250 | |
|
265 | 251 | Adding to tornado's own handling, forbids the serving of hidden files. |
|
266 | 252 | """ |
|
267 | 253 | abs_path = super(AuthenticatedFileHandler, self).validate_absolute_path(root, absolute_path) |
|
268 | 254 | abs_root = os.path.abspath(root) |
|
269 | 255 | if is_hidden(abs_path, abs_root): |
|
270 | 256 | self.log.info("Refusing to serve hidden file, via 404 Error") |
|
271 | 257 | raise web.HTTPError(404) |
|
272 | 258 | return abs_path |
|
273 | 259 | |
|
274 | 260 | |
|
275 | 261 | def json_errors(method): |
|
276 | 262 | """Decorate methods with this to return GitHub style JSON errors. |
|
277 | 263 | |
|
278 | 264 | This should be used on any JSON API on any handler method that can raise HTTPErrors. |
|
279 | 265 | |
|
280 | 266 | This will grab the latest HTTPError exception using sys.exc_info |
|
281 | 267 | and then: |
|
282 | 268 | |
|
283 | 269 | 1. Set the HTTP status code based on the HTTPError |
|
284 | 270 | 2. Create and return a JSON body with a message field describing |
|
285 | 271 | the error in a human readable form. |
|
286 | 272 | """ |
|
287 | 273 | @functools.wraps(method) |
|
288 | 274 | def wrapper(self, *args, **kwargs): |
|
289 | 275 | try: |
|
290 | 276 | result = method(self, *args, **kwargs) |
|
291 | 277 | except web.HTTPError as e: |
|
292 | 278 | status = e.status_code |
|
293 | 279 | message = e.log_message |
|
294 | 280 | self.log.warn(message) |
|
295 | 281 | self.set_status(e.status_code) |
|
296 | 282 | self.finish(json.dumps(dict(message=message))) |
|
297 | 283 | except Exception: |
|
298 | 284 | self.log.error("Unhandled error in API request", exc_info=True) |
|
299 | 285 | status = 500 |
|
300 | 286 | message = "Unknown server error" |
|
301 | 287 | t, value, tb = sys.exc_info() |
|
302 | 288 | self.set_status(status) |
|
303 | 289 | tb_text = ''.join(traceback.format_exception(t, value, tb)) |
|
304 | 290 | reply = dict(message=message, traceback=tb_text) |
|
305 | 291 | self.finish(json.dumps(reply)) |
|
306 | 292 | else: |
|
307 | 293 | return result |
|
308 | 294 | return wrapper |
|
309 | 295 | |
|
310 | 296 | |
|
311 | 297 | |
|
312 | 298 | #----------------------------------------------------------------------------- |
|
313 | 299 | # File handler |
|
314 | 300 | #----------------------------------------------------------------------------- |
|
315 | 301 | |
|
316 | 302 | # to minimize subclass changes: |
|
317 | 303 | HTTPError = web.HTTPError |
|
318 | 304 | |
|
319 | 305 | class FileFindHandler(web.StaticFileHandler): |
|
320 | 306 | """subclass of StaticFileHandler for serving files from a search path""" |
|
321 | 307 | |
|
322 | 308 | # cache search results, don't search for files more than once |
|
323 | 309 | _static_paths = {} |
|
324 | 310 | |
|
325 | 311 | def initialize(self, path, default_filename=None): |
|
326 | 312 | if isinstance(path, string_types): |
|
327 | 313 | path = [path] |
|
328 | 314 | |
|
329 | 315 | self.root = tuple( |
|
330 | 316 | os.path.abspath(os.path.expanduser(p)) + os.sep for p in path |
|
331 | 317 | ) |
|
332 | 318 | self.default_filename = default_filename |
|
333 | 319 | |
|
334 | 320 | def compute_etag(self): |
|
335 | 321 | return None |
|
336 | 322 | |
|
337 | 323 | @classmethod |
|
338 | 324 | def get_absolute_path(cls, roots, path): |
|
339 | 325 | """locate a file to serve on our static file search path""" |
|
340 | 326 | with cls._lock: |
|
341 | 327 | if path in cls._static_paths: |
|
342 | 328 | return cls._static_paths[path] |
|
343 | 329 | try: |
|
344 | 330 | abspath = os.path.abspath(filefind(path, roots)) |
|
345 | 331 | except IOError: |
|
346 | 332 | # IOError means not found |
|
347 | 333 | return '' |
|
348 | 334 | |
|
349 | 335 | cls._static_paths[path] = abspath |
|
350 | 336 | return abspath |
|
351 | 337 | |
|
352 | 338 | def validate_absolute_path(self, root, absolute_path): |
|
353 | 339 | """check if the file should be served (raises 404, 403, etc.)""" |
|
354 | 340 | if absolute_path == '': |
|
355 | 341 | raise web.HTTPError(404) |
|
356 | 342 | |
|
357 | 343 | for root in self.root: |
|
358 | 344 | if (absolute_path + os.sep).startswith(root): |
|
359 | 345 | break |
|
360 | 346 | |
|
361 | 347 | return super(FileFindHandler, self).validate_absolute_path(root, absolute_path) |
|
362 | 348 | |
|
363 | 349 | |
|
364 | 350 | class TrailingSlashHandler(web.RequestHandler): |
|
365 | 351 | """Simple redirect handler that strips trailing slashes |
|
366 | 352 | |
|
367 | 353 | This should be the first, highest priority handler. |
|
368 | 354 | """ |
|
369 | 355 | |
|
370 | 356 | SUPPORTED_METHODS = ['GET'] |
|
371 | 357 | |
|
372 | 358 | def get(self): |
|
373 | 359 | self.redirect(self.request.uri.rstrip('/')) |
|
374 | 360 | |
|
375 | 361 | #----------------------------------------------------------------------------- |
|
376 | 362 | # URL pattern fragments for re-use |
|
377 | 363 | #----------------------------------------------------------------------------- |
|
378 | 364 | |
|
379 | 365 | path_regex = r"(?P<path>(?:/.*)*)" |
|
380 | 366 | notebook_name_regex = r"(?P<name>[^/]+\.ipynb)" |
|
381 | 367 | notebook_path_regex = "%s/%s" % (path_regex, notebook_name_regex) |
|
382 | 368 | |
|
383 | 369 | #----------------------------------------------------------------------------- |
|
384 | 370 | # URL to handler mappings |
|
385 | 371 | #----------------------------------------------------------------------------- |
|
386 | 372 | |
|
387 | 373 | |
|
388 | 374 | default_handlers = [ |
|
389 | 375 | (r".*/", TrailingSlashHandler) |
|
390 | 376 | ] |
@@ -1,238 +1,217 b'' | |||
|
1 | """The official API for working with notebooks in the current format version. | |
|
2 | ||
|
3 | Authors: | |
|
4 | ||
|
5 | * Brian Granger | |
|
6 | * Jonathan Frederic | |
|
7 | """ | |
|
8 | ||
|
9 | #----------------------------------------------------------------------------- | |
|
10 | # Copyright (C) 2008-2011 The IPython Development Team | |
|
11 | # | |
|
12 | # Distributed under the terms of the BSD License. The full license is in | |
|
13 | # the file COPYING, distributed as part of this software. | |
|
14 | #----------------------------------------------------------------------------- | |
|
15 | ||
|
16 | #----------------------------------------------------------------------------- | |
|
17 | # Imports | |
|
18 | #----------------------------------------------------------------------------- | |
|
1 | """The official API for working with notebooks in the current format version.""" | |
|
19 | 2 | |
|
20 | 3 | from __future__ import print_function |
|
21 | 4 | |
|
22 | 5 | from xml.etree import ElementTree as ET |
|
23 | 6 | import re |
|
24 | 7 | |
|
25 | 8 | from IPython.utils.py3compat import unicode_type |
|
26 | 9 | |
|
27 | 10 | from IPython.nbformat.v3 import ( |
|
28 | 11 | NotebookNode, |
|
29 | 12 | new_code_cell, new_text_cell, new_notebook, new_output, new_worksheet, |
|
30 | 13 | parse_filename, new_metadata, new_author, new_heading_cell, nbformat, |
|
31 | 14 | nbformat_minor, nbformat_schema, to_notebook_json |
|
32 | 15 | ) |
|
33 | 16 | from IPython.nbformat import v3 as _v_latest |
|
34 | 17 | |
|
35 | 18 | from .reader import reads as reader_reads |
|
36 | 19 | from .reader import versions |
|
37 | 20 | from .convert import convert |
|
38 | 21 | from .validator import validate |
|
39 | 22 | |
|
40 | import logging | |
|
41 | logger = logging.getLogger('NotebookApp') | |
|
23 | from IPython.utils.log import get_logger | |
|
42 | 24 | |
|
43 | #----------------------------------------------------------------------------- | |
|
44 | # Code | |
|
45 | #----------------------------------------------------------------------------- | |
|
46 | 25 | |
|
47 | 26 | current_nbformat = nbformat |
|
48 | 27 | current_nbformat_minor = nbformat_minor |
|
49 | 28 | current_nbformat_module = _v_latest.__name__ |
|
50 | 29 | |
|
51 | 30 | |
|
52 | 31 | def docstring_nbformat_mod(func): |
|
53 | 32 | """Decorator for docstrings referring to classes/functions accessed through |
|
54 | 33 | nbformat.current. |
|
55 | 34 | |
|
56 | 35 | Put {nbformat_mod} in the docstring in place of 'IPython.nbformat.v3'. |
|
57 | 36 | """ |
|
58 | 37 | func.__doc__ = func.__doc__.format(nbformat_mod=current_nbformat_module) |
|
59 | 38 | return func |
|
60 | 39 | |
|
61 | 40 | |
|
62 | 41 | class NBFormatError(ValueError): |
|
63 | 42 | pass |
|
64 | 43 | |
|
65 | 44 | |
|
66 | 45 | def parse_py(s, **kwargs): |
|
67 | 46 | """Parse a string into a (nbformat, string) tuple.""" |
|
68 | 47 | nbf = current_nbformat |
|
69 | 48 | nbm = current_nbformat_minor |
|
70 | 49 | |
|
71 | 50 | pattern = r'# <nbformat>(?P<nbformat>\d+[\.\d+]*)</nbformat>' |
|
72 | 51 | m = re.search(pattern,s) |
|
73 | 52 | if m is not None: |
|
74 | 53 | digits = m.group('nbformat').split('.') |
|
75 | 54 | nbf = int(digits[0]) |
|
76 | 55 | if len(digits) > 1: |
|
77 | 56 | nbm = int(digits[1]) |
|
78 | 57 | |
|
79 | 58 | return nbf, nbm, s |
|
80 | 59 | |
|
81 | 60 | |
|
82 | 61 | def reads_json(nbjson, **kwargs): |
|
83 | 62 | """Read a JSON notebook from a string and return the NotebookNode |
|
84 | 63 | object. Report if any JSON format errors are detected. |
|
85 | 64 | |
|
86 | 65 | """ |
|
87 | 66 | nb = reader_reads(nbjson, **kwargs) |
|
88 | 67 | nb_current = convert(nb, current_nbformat) |
|
89 | 68 | errors = validate(nb_current) |
|
90 | 69 | if errors: |
|
91 | logger.error( | |
|
70 | get_logger().error( | |
|
92 | 71 | "Notebook JSON is invalid (%d errors detected during read)", |
|
93 | 72 | len(errors)) |
|
94 | 73 | return nb_current |
|
95 | 74 | |
|
96 | 75 | |
|
97 | 76 | def writes_json(nb, **kwargs): |
|
98 | 77 | """Take a NotebookNode object and write out a JSON string. Report if |
|
99 | 78 | any JSON format errors are detected. |
|
100 | 79 | |
|
101 | 80 | """ |
|
102 | 81 | errors = validate(nb) |
|
103 | 82 | if errors: |
|
104 | logger.error( | |
|
83 | get_logger().error( | |
|
105 | 84 | "Notebook JSON is invalid (%d errors detected during write)", |
|
106 | 85 | len(errors)) |
|
107 | 86 | nbjson = versions[current_nbformat].writes_json(nb, **kwargs) |
|
108 | 87 | return nbjson |
|
109 | 88 | |
|
110 | 89 | |
|
111 | 90 | def reads_py(s, **kwargs): |
|
112 | 91 | """Read a .py notebook from a string and return the NotebookNode object.""" |
|
113 | 92 | nbf, nbm, s = parse_py(s, **kwargs) |
|
114 | 93 | if nbf in (2, 3): |
|
115 | 94 | nb = versions[nbf].to_notebook_py(s, **kwargs) |
|
116 | 95 | else: |
|
117 | 96 | raise NBFormatError('Unsupported PY nbformat version: %i' % nbf) |
|
118 | 97 | return nb |
|
119 | 98 | |
|
120 | 99 | |
|
121 | 100 | def writes_py(nb, **kwargs): |
|
122 | 101 | # nbformat 3 is the latest format that supports py |
|
123 | 102 | return versions[3].writes_py(nb, **kwargs) |
|
124 | 103 | |
|
125 | 104 | |
|
126 | 105 | # High level API |
|
127 | 106 | |
|
128 | 107 | |
|
129 | 108 | def reads(s, format, **kwargs): |
|
130 | 109 | """Read a notebook from a string and return the NotebookNode object. |
|
131 | 110 | |
|
132 | 111 | This function properly handles notebooks of any version. The notebook |
|
133 | 112 | returned will always be in the current version's format. |
|
134 | 113 | |
|
135 | 114 | Parameters |
|
136 | 115 | ---------- |
|
137 | 116 | s : unicode |
|
138 | 117 | The raw unicode string to read the notebook from. |
|
139 | 118 | format : (u'json', u'ipynb', u'py') |
|
140 | 119 | The format that the string is in. |
|
141 | 120 | |
|
142 | 121 | Returns |
|
143 | 122 | ------- |
|
144 | 123 | nb : NotebookNode |
|
145 | 124 | The notebook that was read. |
|
146 | 125 | """ |
|
147 | 126 | format = unicode_type(format) |
|
148 | 127 | if format == u'json' or format == u'ipynb': |
|
149 | 128 | return reads_json(s, **kwargs) |
|
150 | 129 | elif format == u'py': |
|
151 | 130 | return reads_py(s, **kwargs) |
|
152 | 131 | else: |
|
153 | 132 | raise NBFormatError('Unsupported format: %s' % format) |
|
154 | 133 | |
|
155 | 134 | |
|
156 | 135 | def writes(nb, format, **kwargs): |
|
157 | 136 | """Write a notebook to a string in a given format in the current nbformat version. |
|
158 | 137 | |
|
159 | 138 | This function always writes the notebook in the current nbformat version. |
|
160 | 139 | |
|
161 | 140 | Parameters |
|
162 | 141 | ---------- |
|
163 | 142 | nb : NotebookNode |
|
164 | 143 | The notebook to write. |
|
165 | 144 | format : (u'json', u'ipynb', u'py') |
|
166 | 145 | The format to write the notebook in. |
|
167 | 146 | |
|
168 | 147 | Returns |
|
169 | 148 | ------- |
|
170 | 149 | s : unicode |
|
171 | 150 | The notebook string. |
|
172 | 151 | """ |
|
173 | 152 | format = unicode_type(format) |
|
174 | 153 | if format == u'json' or format == u'ipynb': |
|
175 | 154 | return writes_json(nb, **kwargs) |
|
176 | 155 | elif format == u'py': |
|
177 | 156 | return writes_py(nb, **kwargs) |
|
178 | 157 | else: |
|
179 | 158 | raise NBFormatError('Unsupported format: %s' % format) |
|
180 | 159 | |
|
181 | 160 | |
|
182 | 161 | def read(fp, format, **kwargs): |
|
183 | 162 | """Read a notebook from a file and return the NotebookNode object. |
|
184 | 163 | |
|
185 | 164 | This function properly handles notebooks of any version. The notebook |
|
186 | 165 | returned will always be in the current version's format. |
|
187 | 166 | |
|
188 | 167 | Parameters |
|
189 | 168 | ---------- |
|
190 | 169 | fp : file |
|
191 | 170 | Any file-like object with a read method. |
|
192 | 171 | format : (u'json', u'ipynb', u'py') |
|
193 | 172 | The format that the string is in. |
|
194 | 173 | |
|
195 | 174 | Returns |
|
196 | 175 | ------- |
|
197 | 176 | nb : NotebookNode |
|
198 | 177 | The notebook that was read. |
|
199 | 178 | """ |
|
200 | 179 | return reads(fp.read(), format, **kwargs) |
|
201 | 180 | |
|
202 | 181 | |
|
203 | 182 | def write(nb, fp, format, **kwargs): |
|
204 | 183 | """Write a notebook to a file in a given format in the current nbformat version. |
|
205 | 184 | |
|
206 | 185 | This function always writes the notebook in the current nbformat version. |
|
207 | 186 | |
|
208 | 187 | Parameters |
|
209 | 188 | ---------- |
|
210 | 189 | nb : NotebookNode |
|
211 | 190 | The notebook to write. |
|
212 | 191 | fp : file |
|
213 | 192 | Any file-like object with a write method. |
|
214 | 193 | format : (u'json', u'ipynb', u'py') |
|
215 | 194 | The format to write the notebook in. |
|
216 | 195 | |
|
217 | 196 | Returns |
|
218 | 197 | ------- |
|
219 | 198 | s : unicode |
|
220 | 199 | The notebook string. |
|
221 | 200 | """ |
|
222 | 201 | return fp.write(writes(nb, format, **kwargs)) |
|
223 | 202 | |
|
224 | 203 | def _convert_to_metadata(): |
|
225 | 204 | """Convert to a notebook having notebook metadata.""" |
|
226 | 205 | import glob |
|
227 | 206 | for fname in glob.glob('*.ipynb'): |
|
228 | 207 | print('Converting file:',fname) |
|
229 | 208 | with open(fname,'r') as f: |
|
230 | 209 | nb = read(f,u'json') |
|
231 | 210 | md = new_metadata() |
|
232 | 211 | if u'name' in nb: |
|
233 | 212 | md.name = nb.name |
|
234 | 213 | del nb[u'name'] |
|
235 | 214 | nb.metadata = md |
|
236 | 215 | with open(fname,'w') as f: |
|
237 | 216 | write(nb, f, u'json') |
|
238 | 217 |
@@ -1,859 +1,848 b'' | |||
|
1 | 1 | """The Python scheduler for rich scheduling. |
|
2 | 2 | |
|
3 | 3 | The Pure ZMQ scheduler does not allow routing schemes other than LRU, |
|
4 | 4 | nor does it check msg_id DAG dependencies. For those, a slightly slower |
|
5 | 5 | Python Scheduler exists. |
|
6 | ||
|
7 | Authors: | |
|
8 | ||
|
9 | * Min RK | |
|
10 | 6 | """ |
|
11 | #----------------------------------------------------------------------------- | |
|
12 | # Copyright (C) 2010-2011 The IPython Development Team | |
|
13 | # | |
|
14 | # Distributed under the terms of the BSD License. The full license is in | |
|
15 | # the file COPYING, distributed as part of this software. | |
|
16 | #----------------------------------------------------------------------------- | |
|
17 | 7 | |
|
18 | #---------------------------------------------------------------------- | |
|
19 | # Imports | |
|
20 | #---------------------------------------------------------------------- | |
|
8 | # Copyright (c) IPython Development Team. | |
|
9 | # Distributed under the terms of the Modified BSD License. | |
|
21 | 10 | |
|
22 | 11 | import logging |
|
23 | 12 | import sys |
|
24 | 13 | import time |
|
25 | 14 | |
|
26 | 15 | from collections import deque |
|
27 | 16 | from datetime import datetime |
|
28 | 17 | from random import randint, random |
|
29 | 18 | from types import FunctionType |
|
30 | 19 | |
|
31 | 20 | try: |
|
32 | 21 | import numpy |
|
33 | 22 | except ImportError: |
|
34 | 23 | numpy = None |
|
35 | 24 | |
|
36 | 25 | import zmq |
|
37 | 26 | from zmq.eventloop import ioloop, zmqstream |
|
38 | 27 | |
|
39 | 28 | # local imports |
|
40 | 29 | from IPython.external.decorator import decorator |
|
41 | 30 | from IPython.config.application import Application |
|
42 | 31 | from IPython.config.loader import Config |
|
43 | 32 | from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes |
|
44 | 33 | from IPython.utils.py3compat import cast_bytes |
|
45 | 34 | |
|
46 | 35 | from IPython.parallel import error, util |
|
47 | 36 | from IPython.parallel.factory import SessionFactory |
|
48 | 37 | from IPython.parallel.util import connect_logger, local_logger |
|
49 | 38 | |
|
50 | 39 | from .dependency import Dependency |
|
51 | 40 | |
|
52 | 41 | @decorator |
|
53 | 42 | def logged(f,self,*args,**kwargs): |
|
54 | 43 | # print ("#--------------------") |
|
55 | 44 | self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs) |
|
56 | 45 | # print ("#--") |
|
57 | 46 | return f(self,*args, **kwargs) |
|
58 | 47 | |
|
59 | 48 | #---------------------------------------------------------------------- |
|
60 | 49 | # Chooser functions |
|
61 | 50 | #---------------------------------------------------------------------- |
|
62 | 51 | |
|
63 | 52 | def plainrandom(loads): |
|
64 | 53 | """Plain random pick.""" |
|
65 | 54 | n = len(loads) |
|
66 | 55 | return randint(0,n-1) |
|
67 | 56 | |
|
68 | 57 | def lru(loads): |
|
69 | 58 | """Always pick the front of the line. |
|
70 | 59 | |
|
71 | 60 | The content of `loads` is ignored. |
|
72 | 61 | |
|
73 | 62 | Assumes LRU ordering of loads, with oldest first. |
|
74 | 63 | """ |
|
75 | 64 | return 0 |
|
76 | 65 | |
|
77 | 66 | def twobin(loads): |
|
78 | 67 | """Pick two at random, use the LRU of the two. |
|
79 | 68 | |
|
80 | 69 | The content of loads is ignored. |
|
81 | 70 | |
|
82 | 71 | Assumes LRU ordering of loads, with oldest first. |
|
83 | 72 | """ |
|
84 | 73 | n = len(loads) |
|
85 | 74 | a = randint(0,n-1) |
|
86 | 75 | b = randint(0,n-1) |
|
87 | 76 | return min(a,b) |
|
88 | 77 | |
|
89 | 78 | def weighted(loads): |
|
90 | 79 | """Pick two at random using inverse load as weight. |
|
91 | 80 | |
|
92 | 81 | Return the less loaded of the two. |
|
93 | 82 | """ |
|
94 | 83 | # weight 0 a million times more than 1: |
|
95 | 84 | weights = 1./(1e-6+numpy.array(loads)) |
|
96 | 85 | sums = weights.cumsum() |
|
97 | 86 | t = sums[-1] |
|
98 | 87 | x = random()*t |
|
99 | 88 | y = random()*t |
|
100 | 89 | idx = 0 |
|
101 | 90 | idy = 0 |
|
102 | 91 | while sums[idx] < x: |
|
103 | 92 | idx += 1 |
|
104 | 93 | while sums[idy] < y: |
|
105 | 94 | idy += 1 |
|
106 | 95 | if weights[idy] > weights[idx]: |
|
107 | 96 | return idy |
|
108 | 97 | else: |
|
109 | 98 | return idx |
|
110 | 99 | |
|
111 | 100 | def leastload(loads): |
|
112 | 101 | """Always choose the lowest load. |
|
113 | 102 | |
|
114 | 103 | If the lowest load occurs more than once, the first |
|
115 | 104 | occurance will be used. If loads has LRU ordering, this means |
|
116 | 105 | the LRU of those with the lowest load is chosen. |
|
117 | 106 | """ |
|
118 | 107 | return loads.index(min(loads)) |
|
119 | 108 | |
|
120 | 109 | #--------------------------------------------------------------------- |
|
121 | 110 | # Classes |
|
122 | 111 | #--------------------------------------------------------------------- |
|
123 | 112 | |
|
124 | 113 | |
|
125 | 114 | # store empty default dependency: |
|
126 | 115 | MET = Dependency([]) |
|
127 | 116 | |
|
128 | 117 | |
|
129 | 118 | class Job(object): |
|
130 | 119 | """Simple container for a job""" |
|
131 | 120 | def __init__(self, msg_id, raw_msg, idents, msg, header, metadata, |
|
132 | 121 | targets, after, follow, timeout): |
|
133 | 122 | self.msg_id = msg_id |
|
134 | 123 | self.raw_msg = raw_msg |
|
135 | 124 | self.idents = idents |
|
136 | 125 | self.msg = msg |
|
137 | 126 | self.header = header |
|
138 | 127 | self.metadata = metadata |
|
139 | 128 | self.targets = targets |
|
140 | 129 | self.after = after |
|
141 | 130 | self.follow = follow |
|
142 | 131 | self.timeout = timeout |
|
143 | 132 | |
|
144 | 133 | self.removed = False # used for lazy-delete from sorted queue |
|
145 | 134 | self.timestamp = time.time() |
|
146 | 135 | self.timeout_id = 0 |
|
147 | 136 | self.blacklist = set() |
|
148 | 137 | |
|
149 | 138 | def __lt__(self, other): |
|
150 | 139 | return self.timestamp < other.timestamp |
|
151 | 140 | |
|
152 | 141 | def __cmp__(self, other): |
|
153 | 142 | return cmp(self.timestamp, other.timestamp) |
|
154 | 143 | |
|
155 | 144 | @property |
|
156 | 145 | def dependents(self): |
|
157 | 146 | return self.follow.union(self.after) |
|
158 | 147 | |
|
159 | 148 | |
|
160 | 149 | class TaskScheduler(SessionFactory): |
|
161 | 150 | """Python TaskScheduler object. |
|
162 | 151 | |
|
163 | 152 | This is the simplest object that supports msg_id based |
|
164 | 153 | DAG dependencies. *Only* task msg_ids are checked, not |
|
165 | 154 | msg_ids of jobs submitted via the MUX queue. |
|
166 | 155 | |
|
167 | 156 | """ |
|
168 | 157 | |
|
169 | 158 | hwm = Integer(1, config=True, |
|
170 | 159 | help="""specify the High Water Mark (HWM) for the downstream |
|
171 | 160 | socket in the Task scheduler. This is the maximum number |
|
172 | 161 | of allowed outstanding tasks on each engine. |
|
173 | 162 | |
|
174 | 163 | The default (1) means that only one task can be outstanding on each |
|
175 | 164 | engine. Setting TaskScheduler.hwm=0 means there is no limit, and the |
|
176 | 165 | engines continue to be assigned tasks while they are working, |
|
177 | 166 | effectively hiding network latency behind computation, but can result |
|
178 | 167 | in an imbalance of work when submitting many heterogenous tasks all at |
|
179 | 168 | once. Any positive value greater than one is a compromise between the |
|
180 | 169 | two. |
|
181 | 170 | |
|
182 | 171 | """ |
|
183 | 172 | ) |
|
184 | 173 | scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'), |
|
185 | 174 | 'leastload', config=True, allow_none=False, |
|
186 | 175 | help="""select the task scheduler scheme [default: Python LRU] |
|
187 | 176 | Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'""" |
|
188 | 177 | ) |
|
189 | 178 | def _scheme_name_changed(self, old, new): |
|
190 | 179 | self.log.debug("Using scheme %r"%new) |
|
191 | 180 | self.scheme = globals()[new] |
|
192 | 181 | |
|
193 | 182 | # input arguments: |
|
194 | 183 | scheme = Instance(FunctionType) # function for determining the destination |
|
195 | 184 | def _scheme_default(self): |
|
196 | 185 | return leastload |
|
197 | 186 | client_stream = Instance(zmqstream.ZMQStream) # client-facing stream |
|
198 | 187 | engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream |
|
199 | 188 | notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream |
|
200 | 189 | mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream |
|
201 | 190 | query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream |
|
202 | 191 | |
|
203 | 192 | # internals: |
|
204 | 193 | queue = Instance(deque) # sorted list of Jobs |
|
205 | 194 | def _queue_default(self): |
|
206 | 195 | return deque() |
|
207 | 196 | queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue) |
|
208 | 197 | graph = Dict() # dict by msg_id of [ msg_ids that depend on key ] |
|
209 | 198 | retries = Dict() # dict by msg_id of retries remaining (non-neg ints) |
|
210 | 199 | # waiting = List() # list of msg_ids ready to run, but haven't due to HWM |
|
211 | 200 | pending = Dict() # dict by engine_uuid of submitted tasks |
|
212 | 201 | completed = Dict() # dict by engine_uuid of completed tasks |
|
213 | 202 | failed = Dict() # dict by engine_uuid of failed tasks |
|
214 | 203 | destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) |
|
215 | 204 | clients = Dict() # dict by msg_id for who submitted the task |
|
216 | 205 | targets = List() # list of target IDENTs |
|
217 | 206 | loads = List() # list of engine loads |
|
218 | 207 | # full = Set() # set of IDENTs that have HWM outstanding tasks |
|
219 | 208 | all_completed = Set() # set of all completed tasks |
|
220 | 209 | all_failed = Set() # set of all failed tasks |
|
221 | 210 | all_done = Set() # set of all finished tasks=union(completed,failed) |
|
222 | 211 | all_ids = Set() # set of all submitted task IDs |
|
223 | 212 | |
|
224 | 213 | ident = CBytes() # ZMQ identity. This should just be self.session.session |
|
225 | 214 | # but ensure Bytes |
|
226 | 215 | def _ident_default(self): |
|
227 | 216 | return self.session.bsession |
|
228 | 217 | |
|
229 | 218 | def start(self): |
|
230 | 219 | self.query_stream.on_recv(self.dispatch_query_reply) |
|
231 | 220 | self.session.send(self.query_stream, "connection_request", {}) |
|
232 | 221 | |
|
233 | 222 | self.engine_stream.on_recv(self.dispatch_result, copy=False) |
|
234 | 223 | self.client_stream.on_recv(self.dispatch_submission, copy=False) |
|
235 | 224 | |
|
236 | 225 | self._notification_handlers = dict( |
|
237 | 226 | registration_notification = self._register_engine, |
|
238 | 227 | unregistration_notification = self._unregister_engine |
|
239 | 228 | ) |
|
240 | 229 | self.notifier_stream.on_recv(self.dispatch_notification) |
|
241 | 230 | self.log.info("Scheduler started [%s]" % self.scheme_name) |
|
242 | 231 | |
|
243 | 232 | def resume_receiving(self): |
|
244 | 233 | """Resume accepting jobs.""" |
|
245 | 234 | self.client_stream.on_recv(self.dispatch_submission, copy=False) |
|
246 | 235 | |
|
247 | 236 | def stop_receiving(self): |
|
248 | 237 | """Stop accepting jobs while there are no engines. |
|
249 | 238 | Leave them in the ZMQ queue.""" |
|
250 | 239 | self.client_stream.on_recv(None) |
|
251 | 240 | |
|
252 | 241 | #----------------------------------------------------------------------- |
|
253 | 242 | # [Un]Registration Handling |
|
254 | 243 | #----------------------------------------------------------------------- |
|
255 | 244 | |
|
256 | 245 | |
|
257 | 246 | def dispatch_query_reply(self, msg): |
|
258 | 247 | """handle reply to our initial connection request""" |
|
259 | 248 | try: |
|
260 | 249 | idents,msg = self.session.feed_identities(msg) |
|
261 | 250 | except ValueError: |
|
262 | 251 | self.log.warn("task::Invalid Message: %r",msg) |
|
263 | 252 | return |
|
264 | 253 | try: |
|
265 | 254 | msg = self.session.unserialize(msg) |
|
266 | 255 | except ValueError: |
|
267 | 256 | self.log.warn("task::Unauthorized message from: %r"%idents) |
|
268 | 257 | return |
|
269 | 258 | |
|
270 | 259 | content = msg['content'] |
|
271 | 260 | for uuid in content.get('engines', {}).values(): |
|
272 | 261 | self._register_engine(cast_bytes(uuid)) |
|
273 | 262 | |
|
274 | 263 | |
|
275 | 264 | @util.log_errors |
|
276 | 265 | def dispatch_notification(self, msg): |
|
277 | 266 | """dispatch register/unregister events.""" |
|
278 | 267 | try: |
|
279 | 268 | idents,msg = self.session.feed_identities(msg) |
|
280 | 269 | except ValueError: |
|
281 | 270 | self.log.warn("task::Invalid Message: %r",msg) |
|
282 | 271 | return |
|
283 | 272 | try: |
|
284 | 273 | msg = self.session.unserialize(msg) |
|
285 | 274 | except ValueError: |
|
286 | 275 | self.log.warn("task::Unauthorized message from: %r"%idents) |
|
287 | 276 | return |
|
288 | 277 | |
|
289 | 278 | msg_type = msg['header']['msg_type'] |
|
290 | 279 | |
|
291 | 280 | handler = self._notification_handlers.get(msg_type, None) |
|
292 | 281 | if handler is None: |
|
293 | 282 | self.log.error("Unhandled message type: %r"%msg_type) |
|
294 | 283 | else: |
|
295 | 284 | try: |
|
296 | 285 | handler(cast_bytes(msg['content']['uuid'])) |
|
297 | 286 | except Exception: |
|
298 | 287 | self.log.error("task::Invalid notification msg: %r", msg, exc_info=True) |
|
299 | 288 | |
|
300 | 289 | def _register_engine(self, uid): |
|
301 | 290 | """New engine with ident `uid` became available.""" |
|
302 | 291 | # head of the line: |
|
303 | 292 | self.targets.insert(0,uid) |
|
304 | 293 | self.loads.insert(0,0) |
|
305 | 294 | |
|
306 | 295 | # initialize sets |
|
307 | 296 | self.completed[uid] = set() |
|
308 | 297 | self.failed[uid] = set() |
|
309 | 298 | self.pending[uid] = {} |
|
310 | 299 | |
|
311 | 300 | # rescan the graph: |
|
312 | 301 | self.update_graph(None) |
|
313 | 302 | |
|
314 | 303 | def _unregister_engine(self, uid): |
|
315 | 304 | """Existing engine with ident `uid` became unavailable.""" |
|
316 | 305 | if len(self.targets) == 1: |
|
317 | 306 | # this was our only engine |
|
318 | 307 | pass |
|
319 | 308 | |
|
320 | 309 | # handle any potentially finished tasks: |
|
321 | 310 | self.engine_stream.flush() |
|
322 | 311 | |
|
323 | 312 | # don't pop destinations, because they might be used later |
|
324 | 313 | # map(self.destinations.pop, self.completed.pop(uid)) |
|
325 | 314 | # map(self.destinations.pop, self.failed.pop(uid)) |
|
326 | 315 | |
|
327 | 316 | # prevent this engine from receiving work |
|
328 | 317 | idx = self.targets.index(uid) |
|
329 | 318 | self.targets.pop(idx) |
|
330 | 319 | self.loads.pop(idx) |
|
331 | 320 | |
|
332 | 321 | # wait 5 seconds before cleaning up pending jobs, since the results might |
|
333 | 322 | # still be incoming |
|
334 | 323 | if self.pending[uid]: |
|
335 | 324 | dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop) |
|
336 | 325 | dc.start() |
|
337 | 326 | else: |
|
338 | 327 | self.completed.pop(uid) |
|
339 | 328 | self.failed.pop(uid) |
|
340 | 329 | |
|
341 | 330 | |
|
342 | 331 | def handle_stranded_tasks(self, engine): |
|
343 | 332 | """Deal with jobs resident in an engine that died.""" |
|
344 | 333 | lost = self.pending[engine] |
|
345 | 334 | for msg_id in lost.keys(): |
|
346 | 335 | if msg_id not in self.pending[engine]: |
|
347 | 336 | # prevent double-handling of messages |
|
348 | 337 | continue |
|
349 | 338 | |
|
350 | 339 | raw_msg = lost[msg_id].raw_msg |
|
351 | 340 | idents,msg = self.session.feed_identities(raw_msg, copy=False) |
|
352 | 341 | parent = self.session.unpack(msg[1].bytes) |
|
353 | 342 | idents = [engine, idents[0]] |
|
354 | 343 | |
|
355 | 344 | # build fake error reply |
|
356 | 345 | try: |
|
357 | 346 | raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id)) |
|
358 | 347 | except: |
|
359 | 348 | content = error.wrap_exception() |
|
360 | 349 | # build fake metadata |
|
361 | 350 | md = dict( |
|
362 | 351 | status=u'error', |
|
363 | 352 | engine=engine.decode('ascii'), |
|
364 | 353 | date=datetime.now(), |
|
365 | 354 | ) |
|
366 | 355 | msg = self.session.msg('apply_reply', content, parent=parent, metadata=md) |
|
367 | 356 | raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents))) |
|
368 | 357 | # and dispatch it |
|
369 | 358 | self.dispatch_result(raw_reply) |
|
370 | 359 | |
|
371 | 360 | # finally scrub completed/failed lists |
|
372 | 361 | self.completed.pop(engine) |
|
373 | 362 | self.failed.pop(engine) |
|
374 | 363 | |
|
375 | 364 | |
|
376 | 365 | #----------------------------------------------------------------------- |
|
377 | 366 | # Job Submission |
|
378 | 367 | #----------------------------------------------------------------------- |
|
379 | 368 | |
|
380 | 369 | |
|
381 | 370 | @util.log_errors |
|
382 | 371 | def dispatch_submission(self, raw_msg): |
|
383 | 372 | """Dispatch job submission to appropriate handlers.""" |
|
384 | 373 | # ensure targets up to date: |
|
385 | 374 | self.notifier_stream.flush() |
|
386 | 375 | try: |
|
387 | 376 | idents, msg = self.session.feed_identities(raw_msg, copy=False) |
|
388 | 377 | msg = self.session.unserialize(msg, content=False, copy=False) |
|
389 | 378 | except Exception: |
|
390 | 379 | self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True) |
|
391 | 380 | return |
|
392 | 381 | |
|
393 | 382 | |
|
394 | 383 | # send to monitor |
|
395 | 384 | self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False) |
|
396 | 385 | |
|
397 | 386 | header = msg['header'] |
|
398 | 387 | md = msg['metadata'] |
|
399 | 388 | msg_id = header['msg_id'] |
|
400 | 389 | self.all_ids.add(msg_id) |
|
401 | 390 | |
|
402 | 391 | # get targets as a set of bytes objects |
|
403 | 392 | # from a list of unicode objects |
|
404 | 393 | targets = md.get('targets', []) |
|
405 | 394 | targets = set(map(cast_bytes, targets)) |
|
406 | 395 | |
|
407 | 396 | retries = md.get('retries', 0) |
|
408 | 397 | self.retries[msg_id] = retries |
|
409 | 398 | |
|
410 | 399 | # time dependencies |
|
411 | 400 | after = md.get('after', None) |
|
412 | 401 | if after: |
|
413 | 402 | after = Dependency(after) |
|
414 | 403 | if after.all: |
|
415 | 404 | if after.success: |
|
416 | 405 | after = Dependency(after.difference(self.all_completed), |
|
417 | 406 | success=after.success, |
|
418 | 407 | failure=after.failure, |
|
419 | 408 | all=after.all, |
|
420 | 409 | ) |
|
421 | 410 | if after.failure: |
|
422 | 411 | after = Dependency(after.difference(self.all_failed), |
|
423 | 412 | success=after.success, |
|
424 | 413 | failure=after.failure, |
|
425 | 414 | all=after.all, |
|
426 | 415 | ) |
|
427 | 416 | if after.check(self.all_completed, self.all_failed): |
|
428 | 417 | # recast as empty set, if `after` already met, |
|
429 | 418 | # to prevent unnecessary set comparisons |
|
430 | 419 | after = MET |
|
431 | 420 | else: |
|
432 | 421 | after = MET |
|
433 | 422 | |
|
434 | 423 | # location dependencies |
|
435 | 424 | follow = Dependency(md.get('follow', [])) |
|
436 | 425 | |
|
437 | 426 | timeout = md.get('timeout', None) |
|
438 | 427 | if timeout: |
|
439 | 428 | timeout = float(timeout) |
|
440 | 429 | |
|
441 | 430 | job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg, |
|
442 | 431 | header=header, targets=targets, after=after, follow=follow, |
|
443 | 432 | timeout=timeout, metadata=md, |
|
444 | 433 | ) |
|
445 | 434 | # validate and reduce dependencies: |
|
446 | 435 | for dep in after,follow: |
|
447 | 436 | if not dep: # empty dependency |
|
448 | 437 | continue |
|
449 | 438 | # check valid: |
|
450 | 439 | if msg_id in dep or dep.difference(self.all_ids): |
|
451 | 440 | self.queue_map[msg_id] = job |
|
452 | 441 | return self.fail_unreachable(msg_id, error.InvalidDependency) |
|
453 | 442 | # check if unreachable: |
|
454 | 443 | if dep.unreachable(self.all_completed, self.all_failed): |
|
455 | 444 | self.queue_map[msg_id] = job |
|
456 | 445 | return self.fail_unreachable(msg_id) |
|
457 | 446 | |
|
458 | 447 | if after.check(self.all_completed, self.all_failed): |
|
459 | 448 | # time deps already met, try to run |
|
460 | 449 | if not self.maybe_run(job): |
|
461 | 450 | # can't run yet |
|
462 | 451 | if msg_id not in self.all_failed: |
|
463 | 452 | # could have failed as unreachable |
|
464 | 453 | self.save_unmet(job) |
|
465 | 454 | else: |
|
466 | 455 | self.save_unmet(job) |
|
467 | 456 | |
|
468 | 457 | def job_timeout(self, job, timeout_id): |
|
469 | 458 | """callback for a job's timeout. |
|
470 | 459 | |
|
471 | 460 | The job may or may not have been run at this point. |
|
472 | 461 | """ |
|
473 | 462 | if job.timeout_id != timeout_id: |
|
474 | 463 | # not the most recent call |
|
475 | 464 | return |
|
476 | 465 | now = time.time() |
|
477 | 466 | if job.timeout >= (now + 1): |
|
478 | 467 | self.log.warn("task %s timeout fired prematurely: %s > %s", |
|
479 | 468 | job.msg_id, job.timeout, now |
|
480 | 469 | ) |
|
481 | 470 | if job.msg_id in self.queue_map: |
|
482 | 471 | # still waiting, but ran out of time |
|
483 | 472 | self.log.info("task %r timed out", job.msg_id) |
|
484 | 473 | self.fail_unreachable(job.msg_id, error.TaskTimeout) |
|
485 | 474 | |
|
486 | 475 | def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): |
|
487 | 476 | """a task has become unreachable, send a reply with an ImpossibleDependency |
|
488 | 477 | error.""" |
|
489 | 478 | if msg_id not in self.queue_map: |
|
490 | 479 | self.log.error("task %r already failed!", msg_id) |
|
491 | 480 | return |
|
492 | 481 | job = self.queue_map.pop(msg_id) |
|
493 | 482 | # lazy-delete from the queue |
|
494 | 483 | job.removed = True |
|
495 | 484 | for mid in job.dependents: |
|
496 | 485 | if mid in self.graph: |
|
497 | 486 | self.graph[mid].remove(msg_id) |
|
498 | 487 | |
|
499 | 488 | try: |
|
500 | 489 | raise why() |
|
501 | 490 | except: |
|
502 | 491 | content = error.wrap_exception() |
|
503 | 492 | self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename']) |
|
504 | 493 | |
|
505 | 494 | self.all_done.add(msg_id) |
|
506 | 495 | self.all_failed.add(msg_id) |
|
507 | 496 | |
|
508 | 497 | msg = self.session.send(self.client_stream, 'apply_reply', content, |
|
509 | 498 | parent=job.header, ident=job.idents) |
|
510 | 499 | self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents) |
|
511 | 500 | |
|
512 | 501 | self.update_graph(msg_id, success=False) |
|
513 | 502 | |
|
514 | 503 | def available_engines(self): |
|
515 | 504 | """return a list of available engine indices based on HWM""" |
|
516 | 505 | if not self.hwm: |
|
517 | 506 | return list(range(len(self.targets))) |
|
518 | 507 | available = [] |
|
519 | 508 | for idx in range(len(self.targets)): |
|
520 | 509 | if self.loads[idx] < self.hwm: |
|
521 | 510 | available.append(idx) |
|
522 | 511 | return available |
|
523 | 512 | |
|
524 | 513 | def maybe_run(self, job): |
|
525 | 514 | """check location dependencies, and run if they are met.""" |
|
526 | 515 | msg_id = job.msg_id |
|
527 | 516 | self.log.debug("Attempting to assign task %s", msg_id) |
|
528 | 517 | available = self.available_engines() |
|
529 | 518 | if not available: |
|
530 | 519 | # no engines, definitely can't run |
|
531 | 520 | return False |
|
532 | 521 | |
|
533 | 522 | if job.follow or job.targets or job.blacklist or self.hwm: |
|
534 | 523 | # we need a can_run filter |
|
535 | 524 | def can_run(idx): |
|
536 | 525 | # check hwm |
|
537 | 526 | if self.hwm and self.loads[idx] == self.hwm: |
|
538 | 527 | return False |
|
539 | 528 | target = self.targets[idx] |
|
540 | 529 | # check blacklist |
|
541 | 530 | if target in job.blacklist: |
|
542 | 531 | return False |
|
543 | 532 | # check targets |
|
544 | 533 | if job.targets and target not in job.targets: |
|
545 | 534 | return False |
|
546 | 535 | # check follow |
|
547 | 536 | return job.follow.check(self.completed[target], self.failed[target]) |
|
548 | 537 | |
|
549 | 538 | indices = list(filter(can_run, available)) |
|
550 | 539 | |
|
551 | 540 | if not indices: |
|
552 | 541 | # couldn't run |
|
553 | 542 | if job.follow.all: |
|
554 | 543 | # check follow for impossibility |
|
555 | 544 | dests = set() |
|
556 | 545 | relevant = set() |
|
557 | 546 | if job.follow.success: |
|
558 | 547 | relevant = self.all_completed |
|
559 | 548 | if job.follow.failure: |
|
560 | 549 | relevant = relevant.union(self.all_failed) |
|
561 | 550 | for m in job.follow.intersection(relevant): |
|
562 | 551 | dests.add(self.destinations[m]) |
|
563 | 552 | if len(dests) > 1: |
|
564 | 553 | self.queue_map[msg_id] = job |
|
565 | 554 | self.fail_unreachable(msg_id) |
|
566 | 555 | return False |
|
567 | 556 | if job.targets: |
|
568 | 557 | # check blacklist+targets for impossibility |
|
569 | 558 | job.targets.difference_update(job.blacklist) |
|
570 | 559 | if not job.targets or not job.targets.intersection(self.targets): |
|
571 | 560 | self.queue_map[msg_id] = job |
|
572 | 561 | self.fail_unreachable(msg_id) |
|
573 | 562 | return False |
|
574 | 563 | return False |
|
575 | 564 | else: |
|
576 | 565 | indices = None |
|
577 | 566 | |
|
578 | 567 | self.submit_task(job, indices) |
|
579 | 568 | return True |
|
580 | 569 | |
|
581 | 570 | def save_unmet(self, job): |
|
582 | 571 | """Save a message for later submission when its dependencies are met.""" |
|
583 | 572 | msg_id = job.msg_id |
|
584 | 573 | self.log.debug("Adding task %s to the queue", msg_id) |
|
585 | 574 | self.queue_map[msg_id] = job |
|
586 | 575 | self.queue.append(job) |
|
587 | 576 | # track the ids in follow or after, but not those already finished |
|
588 | 577 | for dep_id in job.after.union(job.follow).difference(self.all_done): |
|
589 | 578 | if dep_id not in self.graph: |
|
590 | 579 | self.graph[dep_id] = set() |
|
591 | 580 | self.graph[dep_id].add(msg_id) |
|
592 | 581 | |
|
593 | 582 | # schedule timeout callback |
|
594 | 583 | if job.timeout: |
|
595 | 584 | timeout_id = job.timeout_id = job.timeout_id + 1 |
|
596 | 585 | self.loop.add_timeout(time.time() + job.timeout, |
|
597 | 586 | lambda : self.job_timeout(job, timeout_id) |
|
598 | 587 | ) |
|
599 | 588 | |
|
600 | 589 | |
|
601 | 590 | def submit_task(self, job, indices=None): |
|
602 | 591 | """Submit a task to any of a subset of our targets.""" |
|
603 | 592 | if indices: |
|
604 | 593 | loads = [self.loads[i] for i in indices] |
|
605 | 594 | else: |
|
606 | 595 | loads = self.loads |
|
607 | 596 | idx = self.scheme(loads) |
|
608 | 597 | if indices: |
|
609 | 598 | idx = indices[idx] |
|
610 | 599 | target = self.targets[idx] |
|
611 | 600 | # print (target, map(str, msg[:3])) |
|
612 | 601 | # send job to the engine |
|
613 | 602 | self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) |
|
614 | 603 | self.engine_stream.send_multipart(job.raw_msg, copy=False) |
|
615 | 604 | # update load |
|
616 | 605 | self.add_job(idx) |
|
617 | 606 | self.pending[target][job.msg_id] = job |
|
618 | 607 | # notify Hub |
|
619 | 608 | content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii')) |
|
620 | 609 | self.session.send(self.mon_stream, 'task_destination', content=content, |
|
621 | 610 | ident=[b'tracktask',self.ident]) |
|
622 | 611 | |
|
623 | 612 | |
|
624 | 613 | #----------------------------------------------------------------------- |
|
625 | 614 | # Result Handling |
|
626 | 615 | #----------------------------------------------------------------------- |
|
627 | 616 | |
|
628 | 617 | |
|
629 | 618 | @util.log_errors |
|
630 | 619 | def dispatch_result(self, raw_msg): |
|
631 | 620 | """dispatch method for result replies""" |
|
632 | 621 | try: |
|
633 | 622 | idents,msg = self.session.feed_identities(raw_msg, copy=False) |
|
634 | 623 | msg = self.session.unserialize(msg, content=False, copy=False) |
|
635 | 624 | engine = idents[0] |
|
636 | 625 | try: |
|
637 | 626 | idx = self.targets.index(engine) |
|
638 | 627 | except ValueError: |
|
639 | 628 | pass # skip load-update for dead engines |
|
640 | 629 | else: |
|
641 | 630 | self.finish_job(idx) |
|
642 | 631 | except Exception: |
|
643 | 632 | self.log.error("task::Invalid result: %r", raw_msg, exc_info=True) |
|
644 | 633 | return |
|
645 | 634 | |
|
646 | 635 | md = msg['metadata'] |
|
647 | 636 | parent = msg['parent_header'] |
|
648 | 637 | if md.get('dependencies_met', True): |
|
649 | 638 | success = (md['status'] == 'ok') |
|
650 | 639 | msg_id = parent['msg_id'] |
|
651 | 640 | retries = self.retries[msg_id] |
|
652 | 641 | if not success and retries > 0: |
|
653 | 642 | # failed |
|
654 | 643 | self.retries[msg_id] = retries - 1 |
|
655 | 644 | self.handle_unmet_dependency(idents, parent) |
|
656 | 645 | else: |
|
657 | 646 | del self.retries[msg_id] |
|
658 | 647 | # relay to client and update graph |
|
659 | 648 | self.handle_result(idents, parent, raw_msg, success) |
|
660 | 649 | # send to Hub monitor |
|
661 | 650 | self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False) |
|
662 | 651 | else: |
|
663 | 652 | self.handle_unmet_dependency(idents, parent) |
|
664 | 653 | |
|
665 | 654 | def handle_result(self, idents, parent, raw_msg, success=True): |
|
666 | 655 | """handle a real task result, either success or failure""" |
|
667 | 656 | # first, relay result to client |
|
668 | 657 | engine = idents[0] |
|
669 | 658 | client = idents[1] |
|
670 | 659 | # swap_ids for ROUTER-ROUTER mirror |
|
671 | 660 | raw_msg[:2] = [client,engine] |
|
672 | 661 | # print (map(str, raw_msg[:4])) |
|
673 | 662 | self.client_stream.send_multipart(raw_msg, copy=False) |
|
674 | 663 | # now, update our data structures |
|
675 | 664 | msg_id = parent['msg_id'] |
|
676 | 665 | self.pending[engine].pop(msg_id) |
|
677 | 666 | if success: |
|
678 | 667 | self.completed[engine].add(msg_id) |
|
679 | 668 | self.all_completed.add(msg_id) |
|
680 | 669 | else: |
|
681 | 670 | self.failed[engine].add(msg_id) |
|
682 | 671 | self.all_failed.add(msg_id) |
|
683 | 672 | self.all_done.add(msg_id) |
|
684 | 673 | self.destinations[msg_id] = engine |
|
685 | 674 | |
|
686 | 675 | self.update_graph(msg_id, success) |
|
687 | 676 | |
|
688 | 677 | def handle_unmet_dependency(self, idents, parent): |
|
689 | 678 | """handle an unmet dependency""" |
|
690 | 679 | engine = idents[0] |
|
691 | 680 | msg_id = parent['msg_id'] |
|
692 | 681 | |
|
693 | 682 | job = self.pending[engine].pop(msg_id) |
|
694 | 683 | job.blacklist.add(engine) |
|
695 | 684 | |
|
696 | 685 | if job.blacklist == job.targets: |
|
697 | 686 | self.queue_map[msg_id] = job |
|
698 | 687 | self.fail_unreachable(msg_id) |
|
699 | 688 | elif not self.maybe_run(job): |
|
700 | 689 | # resubmit failed |
|
701 | 690 | if msg_id not in self.all_failed: |
|
702 | 691 | # put it back in our dependency tree |
|
703 | 692 | self.save_unmet(job) |
|
704 | 693 | |
|
705 | 694 | if self.hwm: |
|
706 | 695 | try: |
|
707 | 696 | idx = self.targets.index(engine) |
|
708 | 697 | except ValueError: |
|
709 | 698 | pass # skip load-update for dead engines |
|
710 | 699 | else: |
|
711 | 700 | if self.loads[idx] == self.hwm-1: |
|
712 | 701 | self.update_graph(None) |
|
713 | 702 | |
|
714 | 703 | def update_graph(self, dep_id=None, success=True): |
|
715 | 704 | """dep_id just finished. Update our dependency |
|
716 | 705 | graph and submit any jobs that just became runnable. |
|
717 | 706 | |
|
718 | 707 | Called with dep_id=None to update entire graph for hwm, but without finishing a task. |
|
719 | 708 | """ |
|
720 | 709 | # print ("\n\n***********") |
|
721 | 710 | # pprint (dep_id) |
|
722 | 711 | # pprint (self.graph) |
|
723 | 712 | # pprint (self.queue_map) |
|
724 | 713 | # pprint (self.all_completed) |
|
725 | 714 | # pprint (self.all_failed) |
|
726 | 715 | # print ("\n\n***********\n\n") |
|
727 | 716 | # update any jobs that depended on the dependency |
|
728 | 717 | msg_ids = self.graph.pop(dep_id, []) |
|
729 | 718 | |
|
730 | 719 | # recheck *all* jobs if |
|
731 | 720 | # a) we have HWM and an engine just become no longer full |
|
732 | 721 | # or b) dep_id was given as None |
|
733 | 722 | |
|
734 | 723 | if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]): |
|
735 | 724 | jobs = self.queue |
|
736 | 725 | using_queue = True |
|
737 | 726 | else: |
|
738 | 727 | using_queue = False |
|
739 | 728 | jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids )) |
|
740 | 729 | |
|
741 | 730 | to_restore = [] |
|
742 | 731 | while jobs: |
|
743 | 732 | job = jobs.popleft() |
|
744 | 733 | if job.removed: |
|
745 | 734 | continue |
|
746 | 735 | msg_id = job.msg_id |
|
747 | 736 | |
|
748 | 737 | put_it_back = True |
|
749 | 738 | |
|
750 | 739 | if job.after.unreachable(self.all_completed, self.all_failed)\ |
|
751 | 740 | or job.follow.unreachable(self.all_completed, self.all_failed): |
|
752 | 741 | self.fail_unreachable(msg_id) |
|
753 | 742 | put_it_back = False |
|
754 | 743 | |
|
755 | 744 | elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run |
|
756 | 745 | if self.maybe_run(job): |
|
757 | 746 | put_it_back = False |
|
758 | 747 | self.queue_map.pop(msg_id) |
|
759 | 748 | for mid in job.dependents: |
|
760 | 749 | if mid in self.graph: |
|
761 | 750 | self.graph[mid].remove(msg_id) |
|
762 | 751 | |
|
763 | 752 | # abort the loop if we just filled up all of our engines. |
|
764 | 753 | # avoids an O(N) operation in situation of full queue, |
|
765 | 754 | # where graph update is triggered as soon as an engine becomes |
|
766 | 755 | # non-full, and all tasks after the first are checked, |
|
767 | 756 | # even though they can't run. |
|
768 | 757 | if not self.available_engines(): |
|
769 | 758 | break |
|
770 | 759 | |
|
771 | 760 | if using_queue and put_it_back: |
|
772 | 761 | # popped a job from the queue but it neither ran nor failed, |
|
773 | 762 | # so we need to put it back when we are done |
|
774 | 763 | # make sure to_restore preserves the same ordering |
|
775 | 764 | to_restore.append(job) |
|
776 | 765 | |
|
777 | 766 | # put back any tasks we popped but didn't run |
|
778 | 767 | if using_queue: |
|
779 | 768 | self.queue.extendleft(to_restore) |
|
780 | 769 | |
|
781 | 770 | #---------------------------------------------------------------------- |
|
782 | 771 | # methods to be overridden by subclasses |
|
783 | 772 | #---------------------------------------------------------------------- |
|
784 | 773 | |
|
785 | 774 | def add_job(self, idx): |
|
786 | 775 | """Called after self.targets[idx] just got the job with header. |
|
787 | 776 | Override with subclasses. The default ordering is simple LRU. |
|
788 | 777 | The default loads are the number of outstanding jobs.""" |
|
789 | 778 | self.loads[idx] += 1 |
|
790 | 779 | for lis in (self.targets, self.loads): |
|
791 | 780 | lis.append(lis.pop(idx)) |
|
792 | 781 | |
|
793 | 782 | |
|
794 | 783 | def finish_job(self, idx): |
|
795 | 784 | """Called after self.targets[idx] just finished a job. |
|
796 | 785 | Override with subclasses.""" |
|
797 | 786 | self.loads[idx] -= 1 |
|
798 | 787 | |
|
799 | 788 | |
|
800 | 789 | |
|
801 | 790 | def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None, |
|
802 | 791 | logname='root', log_url=None, loglevel=logging.DEBUG, |
|
803 | 792 | identity=b'task', in_thread=False): |
|
804 | 793 | |
|
805 | 794 | ZMQStream = zmqstream.ZMQStream |
|
806 | 795 | |
|
807 | 796 | if config: |
|
808 | 797 | # unwrap dict back into Config |
|
809 | 798 | config = Config(config) |
|
810 | 799 | |
|
811 | 800 | if in_thread: |
|
812 | 801 | # use instance() to get the same Context/Loop as our parent |
|
813 | 802 | ctx = zmq.Context.instance() |
|
814 | 803 | loop = ioloop.IOLoop.instance() |
|
815 | 804 | else: |
|
816 | 805 | # in a process, don't use instance() |
|
817 | 806 | # for safety with multiprocessing |
|
818 | 807 | ctx = zmq.Context() |
|
819 | 808 | loop = ioloop.IOLoop() |
|
820 | 809 | ins = ZMQStream(ctx.socket(zmq.ROUTER),loop) |
|
821 | 810 | util.set_hwm(ins, 0) |
|
822 | 811 | ins.setsockopt(zmq.IDENTITY, identity + b'_in') |
|
823 | 812 | ins.bind(in_addr) |
|
824 | 813 | |
|
825 | 814 | outs = ZMQStream(ctx.socket(zmq.ROUTER),loop) |
|
826 | 815 | util.set_hwm(outs, 0) |
|
827 | 816 | outs.setsockopt(zmq.IDENTITY, identity + b'_out') |
|
828 | 817 | outs.bind(out_addr) |
|
829 | 818 | mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop) |
|
830 | 819 | util.set_hwm(mons, 0) |
|
831 | 820 | mons.connect(mon_addr) |
|
832 | 821 | nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop) |
|
833 | 822 | nots.setsockopt(zmq.SUBSCRIBE, b'') |
|
834 | 823 | nots.connect(not_addr) |
|
835 | 824 | |
|
836 | 825 | querys = ZMQStream(ctx.socket(zmq.DEALER),loop) |
|
837 | 826 | querys.connect(reg_addr) |
|
838 | 827 | |
|
839 | 828 | # setup logging. |
|
840 | 829 | if in_thread: |
|
841 | 830 | log = Application.instance().log |
|
842 | 831 | else: |
|
843 | 832 | if log_url: |
|
844 | 833 | log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel) |
|
845 | 834 | else: |
|
846 | 835 | log = local_logger(logname, loglevel) |
|
847 | 836 | |
|
848 | 837 | scheduler = TaskScheduler(client_stream=ins, engine_stream=outs, |
|
849 | 838 | mon_stream=mons, notifier_stream=nots, |
|
850 | 839 | query_stream=querys, |
|
851 | 840 | loop=loop, log=log, |
|
852 | 841 | config=config) |
|
853 | 842 | scheduler.start() |
|
854 | 843 | if not in_thread: |
|
855 | 844 | try: |
|
856 | 845 | loop.start() |
|
857 | 846 | except KeyboardInterrupt: |
|
858 | 847 | scheduler.log.critical("Interrupted, exiting...") |
|
859 | 848 |
@@ -1,388 +1,389 b'' | |||
|
1 | 1 | """Some generic utilities for dealing with classes, urls, and serialization.""" |
|
2 | 2 | |
|
3 | 3 | # Copyright (c) IPython Development Team. |
|
4 | 4 | # Distributed under the terms of the Modified BSD License. |
|
5 | 5 | |
|
6 | 6 | import logging |
|
7 | 7 | import os |
|
8 | 8 | import re |
|
9 | 9 | import stat |
|
10 | 10 | import socket |
|
11 | 11 | import sys |
|
12 | 12 | import warnings |
|
13 | 13 | from signal import signal, SIGINT, SIGABRT, SIGTERM |
|
14 | 14 | try: |
|
15 | 15 | from signal import SIGKILL |
|
16 | 16 | except ImportError: |
|
17 | 17 | SIGKILL=None |
|
18 | 18 | from types import FunctionType |
|
19 | 19 | |
|
20 | 20 | try: |
|
21 | 21 | import cPickle |
|
22 | 22 | pickle = cPickle |
|
23 | 23 | except: |
|
24 | 24 | cPickle = None |
|
25 | 25 | import pickle |
|
26 | 26 | |
|
27 | 27 | import zmq |
|
28 | 28 | from zmq.log import handlers |
|
29 | 29 | |
|
30 | from IPython.utils.log import get_logger | |
|
30 | 31 | from IPython.external.decorator import decorator |
|
31 | 32 | |
|
32 | 33 | from IPython.config.application import Application |
|
33 | 34 | from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips |
|
34 | 35 | from IPython.utils.py3compat import string_types, iteritems, itervalues |
|
35 | 36 | from IPython.kernel.zmq.log import EnginePUBHandler |
|
36 | 37 | |
|
37 | 38 | |
|
38 | 39 | #----------------------------------------------------------------------------- |
|
39 | 40 | # Classes |
|
40 | 41 | #----------------------------------------------------------------------------- |
|
41 | 42 | |
|
42 | 43 | class Namespace(dict): |
|
43 | 44 | """Subclass of dict for attribute access to keys.""" |
|
44 | 45 | |
|
45 | 46 | def __getattr__(self, key): |
|
46 | 47 | """getattr aliased to getitem""" |
|
47 | 48 | if key in self: |
|
48 | 49 | return self[key] |
|
49 | 50 | else: |
|
50 | 51 | raise NameError(key) |
|
51 | 52 | |
|
52 | 53 | def __setattr__(self, key, value): |
|
53 | 54 | """setattr aliased to setitem, with strict""" |
|
54 | 55 | if hasattr(dict, key): |
|
55 | 56 | raise KeyError("Cannot override dict keys %r"%key) |
|
56 | 57 | self[key] = value |
|
57 | 58 | |
|
58 | 59 | |
|
59 | 60 | class ReverseDict(dict): |
|
60 | 61 | """simple double-keyed subset of dict methods.""" |
|
61 | 62 | |
|
62 | 63 | def __init__(self, *args, **kwargs): |
|
63 | 64 | dict.__init__(self, *args, **kwargs) |
|
64 | 65 | self._reverse = dict() |
|
65 | 66 | for key, value in iteritems(self): |
|
66 | 67 | self._reverse[value] = key |
|
67 | 68 | |
|
68 | 69 | def __getitem__(self, key): |
|
69 | 70 | try: |
|
70 | 71 | return dict.__getitem__(self, key) |
|
71 | 72 | except KeyError: |
|
72 | 73 | return self._reverse[key] |
|
73 | 74 | |
|
74 | 75 | def __setitem__(self, key, value): |
|
75 | 76 | if key in self._reverse: |
|
76 | 77 | raise KeyError("Can't have key %r on both sides!"%key) |
|
77 | 78 | dict.__setitem__(self, key, value) |
|
78 | 79 | self._reverse[value] = key |
|
79 | 80 | |
|
80 | 81 | def pop(self, key): |
|
81 | 82 | value = dict.pop(self, key) |
|
82 | 83 | self._reverse.pop(value) |
|
83 | 84 | return value |
|
84 | 85 | |
|
85 | 86 | def get(self, key, default=None): |
|
86 | 87 | try: |
|
87 | 88 | return self[key] |
|
88 | 89 | except KeyError: |
|
89 | 90 | return default |
|
90 | 91 | |
|
91 | 92 | #----------------------------------------------------------------------------- |
|
92 | 93 | # Functions |
|
93 | 94 | #----------------------------------------------------------------------------- |
|
94 | 95 | |
|
95 | 96 | @decorator |
|
96 | 97 | def log_errors(f, self, *args, **kwargs): |
|
97 | 98 | """decorator to log unhandled exceptions raised in a method. |
|
98 | 99 | |
|
99 | 100 | For use wrapping on_recv callbacks, so that exceptions |
|
100 | 101 | do not cause the stream to be closed. |
|
101 | 102 | """ |
|
102 | 103 | try: |
|
103 | 104 | return f(self, *args, **kwargs) |
|
104 | 105 | except Exception: |
|
105 | 106 | self.log.error("Uncaught exception in %r" % f, exc_info=True) |
|
106 | 107 | |
|
107 | 108 | |
|
108 | 109 | def is_url(url): |
|
109 | 110 | """boolean check for whether a string is a zmq url""" |
|
110 | 111 | if '://' not in url: |
|
111 | 112 | return False |
|
112 | 113 | proto, addr = url.split('://', 1) |
|
113 | 114 | if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']: |
|
114 | 115 | return False |
|
115 | 116 | return True |
|
116 | 117 | |
|
117 | 118 | def validate_url(url): |
|
118 | 119 | """validate a url for zeromq""" |
|
119 | 120 | if not isinstance(url, string_types): |
|
120 | 121 | raise TypeError("url must be a string, not %r"%type(url)) |
|
121 | 122 | url = url.lower() |
|
122 | 123 | |
|
123 | 124 | proto_addr = url.split('://') |
|
124 | 125 | assert len(proto_addr) == 2, 'Invalid url: %r'%url |
|
125 | 126 | proto, addr = proto_addr |
|
126 | 127 | assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto |
|
127 | 128 | |
|
128 | 129 | # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391 |
|
129 | 130 | # author: Remi Sabourin |
|
130 | 131 | pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$') |
|
131 | 132 | |
|
132 | 133 | if proto == 'tcp': |
|
133 | 134 | lis = addr.split(':') |
|
134 | 135 | assert len(lis) == 2, 'Invalid url: %r'%url |
|
135 | 136 | addr,s_port = lis |
|
136 | 137 | try: |
|
137 | 138 | port = int(s_port) |
|
138 | 139 | except ValueError: |
|
139 | 140 | raise AssertionError("Invalid port %r in url: %r"%(port, url)) |
|
140 | 141 | |
|
141 | 142 | assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url |
|
142 | 143 | |
|
143 | 144 | else: |
|
144 | 145 | # only validate tcp urls currently |
|
145 | 146 | pass |
|
146 | 147 | |
|
147 | 148 | return True |
|
148 | 149 | |
|
149 | 150 | |
|
150 | 151 | def validate_url_container(container): |
|
151 | 152 | """validate a potentially nested collection of urls.""" |
|
152 | 153 | if isinstance(container, string_types): |
|
153 | 154 | url = container |
|
154 | 155 | return validate_url(url) |
|
155 | 156 | elif isinstance(container, dict): |
|
156 | 157 | container = itervalues(container) |
|
157 | 158 | |
|
158 | 159 | for element in container: |
|
159 | 160 | validate_url_container(element) |
|
160 | 161 | |
|
161 | 162 | |
|
162 | 163 | def split_url(url): |
|
163 | 164 | """split a zmq url (tcp://ip:port) into ('tcp','ip','port').""" |
|
164 | 165 | proto_addr = url.split('://') |
|
165 | 166 | assert len(proto_addr) == 2, 'Invalid url: %r'%url |
|
166 | 167 | proto, addr = proto_addr |
|
167 | 168 | lis = addr.split(':') |
|
168 | 169 | assert len(lis) == 2, 'Invalid url: %r'%url |
|
169 | 170 | addr,s_port = lis |
|
170 | 171 | return proto,addr,s_port |
|
171 | 172 | |
|
172 | 173 | |
|
173 | 174 | def disambiguate_ip_address(ip, location=None): |
|
174 | 175 | """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address |
|
175 | 176 | |
|
176 | 177 | Explicit IP addresses are returned unmodified. |
|
177 | 178 | |
|
178 | 179 | Parameters |
|
179 | 180 | ---------- |
|
180 | 181 | |
|
181 | 182 | ip : IP address |
|
182 | 183 | An IP address, or the special values 0.0.0.0, or * |
|
183 | 184 | location: IP address, optional |
|
184 | 185 | A public IP of the target machine. |
|
185 | 186 | If location is an IP of the current machine, |
|
186 | 187 | localhost will be returned, |
|
187 | 188 | otherwise location will be returned. |
|
188 | 189 | """ |
|
189 | 190 | if ip in {'0.0.0.0', '*'}: |
|
190 | 191 | if not location: |
|
191 | 192 | # unspecified location, localhost is the only choice |
|
192 | 193 | ip = localhost() |
|
193 | 194 | elif is_public_ip(location): |
|
194 | 195 | # location is a public IP on this machine, use localhost |
|
195 | 196 | ip = localhost() |
|
196 | 197 | elif not public_ips(): |
|
197 | 198 | # this machine's public IPs cannot be determined, |
|
198 | 199 | # assume `location` is not this machine |
|
199 | 200 | warnings.warn("IPython could not determine public IPs", RuntimeWarning) |
|
200 | 201 | ip = location |
|
201 | 202 | else: |
|
202 | 203 | # location is not this machine, do not use loopback |
|
203 | 204 | ip = location |
|
204 | 205 | return ip |
|
205 | 206 | |
|
206 | 207 | |
|
207 | 208 | def disambiguate_url(url, location=None): |
|
208 | 209 | """turn multi-ip interfaces '0.0.0.0' and '*' into connectable |
|
209 | 210 | ones, based on the location (default interpretation is localhost). |
|
210 | 211 | |
|
211 | 212 | This is for zeromq urls, such as ``tcp://*:10101``. |
|
212 | 213 | """ |
|
213 | 214 | try: |
|
214 | 215 | proto,ip,port = split_url(url) |
|
215 | 216 | except AssertionError: |
|
216 | 217 | # probably not tcp url; could be ipc, etc. |
|
217 | 218 | return url |
|
218 | 219 | |
|
219 | 220 | ip = disambiguate_ip_address(ip,location) |
|
220 | 221 | |
|
221 | 222 | return "%s://%s:%s"%(proto,ip,port) |
|
222 | 223 | |
|
223 | 224 | |
|
224 | 225 | #-------------------------------------------------------------------------- |
|
225 | 226 | # helpers for implementing old MEC API via view.apply |
|
226 | 227 | #-------------------------------------------------------------------------- |
|
227 | 228 | |
|
228 | 229 | def interactive(f): |
|
229 | 230 | """decorator for making functions appear as interactively defined. |
|
230 | 231 | This results in the function being linked to the user_ns as globals() |
|
231 | 232 | instead of the module globals(). |
|
232 | 233 | """ |
|
233 | 234 | |
|
234 | 235 | # build new FunctionType, so it can have the right globals |
|
235 | 236 | # interactive functions never have closures, that's kind of the point |
|
236 | 237 | if isinstance(f, FunctionType): |
|
237 | 238 | mainmod = __import__('__main__') |
|
238 | 239 | f = FunctionType(f.__code__, mainmod.__dict__, |
|
239 | 240 | f.__name__, f.__defaults__, |
|
240 | 241 | ) |
|
241 | 242 | # associate with __main__ for uncanning |
|
242 | 243 | f.__module__ = '__main__' |
|
243 | 244 | return f |
|
244 | 245 | |
|
245 | 246 | @interactive |
|
246 | 247 | def _push(**ns): |
|
247 | 248 | """helper method for implementing `client.push` via `client.apply`""" |
|
248 | 249 | user_ns = globals() |
|
249 | 250 | tmp = '_IP_PUSH_TMP_' |
|
250 | 251 | while tmp in user_ns: |
|
251 | 252 | tmp = tmp + '_' |
|
252 | 253 | try: |
|
253 | 254 | for name, value in ns.items(): |
|
254 | 255 | user_ns[tmp] = value |
|
255 | 256 | exec("%s = %s" % (name, tmp), user_ns) |
|
256 | 257 | finally: |
|
257 | 258 | user_ns.pop(tmp, None) |
|
258 | 259 | |
|
259 | 260 | @interactive |
|
260 | 261 | def _pull(keys): |
|
261 | 262 | """helper method for implementing `client.pull` via `client.apply`""" |
|
262 | 263 | if isinstance(keys, (list,tuple, set)): |
|
263 | 264 | return [eval(key, globals()) for key in keys] |
|
264 | 265 | else: |
|
265 | 266 | return eval(keys, globals()) |
|
266 | 267 | |
|
267 | 268 | @interactive |
|
268 | 269 | def _execute(code): |
|
269 | 270 | """helper method for implementing `client.execute` via `client.apply`""" |
|
270 | 271 | exec(code, globals()) |
|
271 | 272 | |
|
272 | 273 | #-------------------------------------------------------------------------- |
|
273 | 274 | # extra process management utilities |
|
274 | 275 | #-------------------------------------------------------------------------- |
|
275 | 276 | |
|
276 | 277 | _random_ports = set() |
|
277 | 278 | |
|
278 | 279 | def select_random_ports(n): |
|
279 | 280 | """Selects and return n random ports that are available.""" |
|
280 | 281 | ports = [] |
|
281 | 282 | for i in range(n): |
|
282 | 283 | sock = socket.socket() |
|
283 | 284 | sock.bind(('', 0)) |
|
284 | 285 | while sock.getsockname()[1] in _random_ports: |
|
285 | 286 | sock.close() |
|
286 | 287 | sock = socket.socket() |
|
287 | 288 | sock.bind(('', 0)) |
|
288 | 289 | ports.append(sock) |
|
289 | 290 | for i, sock in enumerate(ports): |
|
290 | 291 | port = sock.getsockname()[1] |
|
291 | 292 | sock.close() |
|
292 | 293 | ports[i] = port |
|
293 | 294 | _random_ports.add(port) |
|
294 | 295 | return ports |
|
295 | 296 | |
|
296 | 297 | def signal_children(children): |
|
297 | 298 | """Relay interupt/term signals to children, for more solid process cleanup.""" |
|
298 | 299 | def terminate_children(sig, frame): |
|
299 | log = Application.instance().log | |
|
300 | log = get_logger() | |
|
300 | 301 | log.critical("Got signal %i, terminating children..."%sig) |
|
301 | 302 | for child in children: |
|
302 | 303 | child.terminate() |
|
303 | 304 | |
|
304 | 305 | sys.exit(sig != SIGINT) |
|
305 | 306 | # sys.exit(sig) |
|
306 | 307 | for sig in (SIGINT, SIGABRT, SIGTERM): |
|
307 | 308 | signal(sig, terminate_children) |
|
308 | 309 | |
|
309 | 310 | def generate_exec_key(keyfile): |
|
310 | 311 | import uuid |
|
311 | 312 | newkey = str(uuid.uuid4()) |
|
312 | 313 | with open(keyfile, 'w') as f: |
|
313 | 314 | # f.write('ipython-key ') |
|
314 | 315 | f.write(newkey+'\n') |
|
315 | 316 | # set user-only RW permissions (0600) |
|
316 | 317 | # this will have no effect on Windows |
|
317 | 318 | os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR) |
|
318 | 319 | |
|
319 | 320 | |
|
320 | 321 | def integer_loglevel(loglevel): |
|
321 | 322 | try: |
|
322 | 323 | loglevel = int(loglevel) |
|
323 | 324 | except ValueError: |
|
324 | 325 | if isinstance(loglevel, str): |
|
325 | 326 | loglevel = getattr(logging, loglevel) |
|
326 | 327 | return loglevel |
|
327 | 328 | |
|
328 | 329 | def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG): |
|
329 | 330 | logger = logging.getLogger(logname) |
|
330 | 331 | if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]): |
|
331 | 332 | # don't add a second PUBHandler |
|
332 | 333 | return |
|
333 | 334 | loglevel = integer_loglevel(loglevel) |
|
334 | 335 | lsock = context.socket(zmq.PUB) |
|
335 | 336 | lsock.connect(iface) |
|
336 | 337 | handler = handlers.PUBHandler(lsock) |
|
337 | 338 | handler.setLevel(loglevel) |
|
338 | 339 | handler.root_topic = root |
|
339 | 340 | logger.addHandler(handler) |
|
340 | 341 | logger.setLevel(loglevel) |
|
341 | 342 | |
|
342 | 343 | def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG): |
|
343 | 344 | logger = logging.getLogger() |
|
344 | 345 | if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]): |
|
345 | 346 | # don't add a second PUBHandler |
|
346 | 347 | return |
|
347 | 348 | loglevel = integer_loglevel(loglevel) |
|
348 | 349 | lsock = context.socket(zmq.PUB) |
|
349 | 350 | lsock.connect(iface) |
|
350 | 351 | handler = EnginePUBHandler(engine, lsock) |
|
351 | 352 | handler.setLevel(loglevel) |
|
352 | 353 | logger.addHandler(handler) |
|
353 | 354 | logger.setLevel(loglevel) |
|
354 | 355 | return logger |
|
355 | 356 | |
|
356 | 357 | def local_logger(logname, loglevel=logging.DEBUG): |
|
357 | 358 | loglevel = integer_loglevel(loglevel) |
|
358 | 359 | logger = logging.getLogger(logname) |
|
359 | 360 | if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]): |
|
360 | 361 | # don't add a second StreamHandler |
|
361 | 362 | return |
|
362 | 363 | handler = logging.StreamHandler() |
|
363 | 364 | handler.setLevel(loglevel) |
|
364 | 365 | formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s", |
|
365 | 366 | datefmt="%Y-%m-%d %H:%M:%S") |
|
366 | 367 | handler.setFormatter(formatter) |
|
367 | 368 | |
|
368 | 369 | logger.addHandler(handler) |
|
369 | 370 | logger.setLevel(loglevel) |
|
370 | 371 | return logger |
|
371 | 372 | |
|
372 | 373 | def set_hwm(sock, hwm=0): |
|
373 | 374 | """set zmq High Water Mark on a socket |
|
374 | 375 | |
|
375 | 376 | in a way that always works for various pyzmq / libzmq versions. |
|
376 | 377 | """ |
|
377 | 378 | import zmq |
|
378 | 379 | |
|
379 | 380 | for key in ('HWM', 'SNDHWM', 'RCVHWM'): |
|
380 | 381 | opt = getattr(zmq, key, None) |
|
381 | 382 | if opt is None: |
|
382 | 383 | continue |
|
383 | 384 | try: |
|
384 | 385 | sock.setsockopt(opt, hwm) |
|
385 | 386 | except zmq.ZMQError: |
|
386 | 387 | pass |
|
387 | 388 | |
|
388 | 389 |
@@ -1,433 +1,420 b'' | |||
|
1 | 1 | # encoding: utf-8 |
|
2 | 2 | """Pickle related utilities. Perhaps this should be called 'can'.""" |
|
3 | 3 | |
|
4 | 4 | # Copyright (c) IPython Development Team. |
|
5 | 5 | # Distributed under the terms of the Modified BSD License. |
|
6 | 6 | |
|
7 | 7 | import copy |
|
8 | 8 | import logging |
|
9 | 9 | import sys |
|
10 | 10 | from types import FunctionType |
|
11 | 11 | |
|
12 | 12 | try: |
|
13 | 13 | import cPickle as pickle |
|
14 | 14 | except ImportError: |
|
15 | 15 | import pickle |
|
16 | 16 | |
|
17 | 17 | from . import codeutil # This registers a hook when it's imported |
|
18 | 18 | from . import py3compat |
|
19 | 19 | from .importstring import import_item |
|
20 | 20 | from .py3compat import string_types, iteritems |
|
21 | 21 | |
|
22 | 22 | from IPython.config import Application |
|
23 | from IPython.utils.log import get_logger | |
|
23 | 24 | |
|
24 | 25 | if py3compat.PY3: |
|
25 | 26 | buffer = memoryview |
|
26 | 27 | class_type = type |
|
27 | 28 | else: |
|
28 | 29 | from types import ClassType |
|
29 | 30 | class_type = (type, ClassType) |
|
30 | 31 | |
|
31 | 32 | def _get_cell_type(a=None): |
|
32 | 33 | """the type of a closure cell doesn't seem to be importable, |
|
33 | 34 | so just create one |
|
34 | 35 | """ |
|
35 | 36 | def inner(): |
|
36 | 37 | return a |
|
37 | 38 | return type(py3compat.get_closure(inner)[0]) |
|
38 | 39 | |
|
39 | 40 | cell_type = _get_cell_type() |
|
40 | 41 | |
|
41 | 42 | #------------------------------------------------------------------------------- |
|
42 | 43 | # Functions |
|
43 | 44 | #------------------------------------------------------------------------------- |
|
44 | 45 | |
|
45 | 46 | |
|
46 | 47 | def use_dill(): |
|
47 | 48 | """use dill to expand serialization support |
|
48 | 49 | |
|
49 | 50 | adds support for object methods and closures to serialization. |
|
50 | 51 | """ |
|
51 | 52 | # import dill causes most of the magic |
|
52 | 53 | import dill |
|
53 | 54 | |
|
54 | 55 | # dill doesn't work with cPickle, |
|
55 | 56 | # tell the two relevant modules to use plain pickle |
|
56 | 57 | |
|
57 | 58 | global pickle |
|
58 | 59 | pickle = dill |
|
59 | 60 | |
|
60 | 61 | try: |
|
61 | 62 | from IPython.kernel.zmq import serialize |
|
62 | 63 | except ImportError: |
|
63 | 64 | pass |
|
64 | 65 | else: |
|
65 | 66 | serialize.pickle = dill |
|
66 | 67 | |
|
67 | 68 | # disable special function handling, let dill take care of it |
|
68 | 69 | can_map.pop(FunctionType, None) |
|
69 | 70 | |
|
70 | 71 | def use_cloudpickle(): |
|
71 | 72 | """use cloudpickle to expand serialization support |
|
72 | 73 | |
|
73 | 74 | adds support for object methods and closures to serialization. |
|
74 | 75 | """ |
|
75 | 76 | from cloud.serialization import cloudpickle |
|
76 | 77 | |
|
77 | 78 | global pickle |
|
78 | 79 | pickle = cloudpickle |
|
79 | 80 | |
|
80 | 81 | try: |
|
81 | 82 | from IPython.kernel.zmq import serialize |
|
82 | 83 | except ImportError: |
|
83 | 84 | pass |
|
84 | 85 | else: |
|
85 | 86 | serialize.pickle = cloudpickle |
|
86 | 87 | |
|
87 | 88 | # disable special function handling, let cloudpickle take care of it |
|
88 | 89 | can_map.pop(FunctionType, None) |
|
89 | 90 | |
|
90 | 91 | |
|
91 | 92 | #------------------------------------------------------------------------------- |
|
92 | 93 | # Classes |
|
93 | 94 | #------------------------------------------------------------------------------- |
|
94 | 95 | |
|
95 | 96 | |
|
96 | 97 | class CannedObject(object): |
|
97 | 98 | def __init__(self, obj, keys=[], hook=None): |
|
98 | 99 | """can an object for safe pickling |
|
99 | 100 | |
|
100 | 101 | Parameters |
|
101 | 102 | ========== |
|
102 | 103 | |
|
103 | 104 | obj: |
|
104 | 105 | The object to be canned |
|
105 | 106 | keys: list (optional) |
|
106 | 107 | list of attribute names that will be explicitly canned / uncanned |
|
107 | 108 | hook: callable (optional) |
|
108 | 109 | An optional extra callable, |
|
109 | 110 | which can do additional processing of the uncanned object. |
|
110 | 111 | |
|
111 | 112 | large data may be offloaded into the buffers list, |
|
112 | 113 | used for zero-copy transfers. |
|
113 | 114 | """ |
|
114 | 115 | self.keys = keys |
|
115 | 116 | self.obj = copy.copy(obj) |
|
116 | 117 | self.hook = can(hook) |
|
117 | 118 | for key in keys: |
|
118 | 119 | setattr(self.obj, key, can(getattr(obj, key))) |
|
119 | 120 | |
|
120 | 121 | self.buffers = [] |
|
121 | 122 | |
|
122 | 123 | def get_object(self, g=None): |
|
123 | 124 | if g is None: |
|
124 | 125 | g = {} |
|
125 | 126 | obj = self.obj |
|
126 | 127 | for key in self.keys: |
|
127 | 128 | setattr(obj, key, uncan(getattr(obj, key), g)) |
|
128 | 129 | |
|
129 | 130 | if self.hook: |
|
130 | 131 | self.hook = uncan(self.hook, g) |
|
131 | 132 | self.hook(obj, g) |
|
132 | 133 | return self.obj |
|
133 | 134 | |
|
134 | 135 | |
|
135 | 136 | class Reference(CannedObject): |
|
136 | 137 | """object for wrapping a remote reference by name.""" |
|
137 | 138 | def __init__(self, name): |
|
138 | 139 | if not isinstance(name, string_types): |
|
139 | 140 | raise TypeError("illegal name: %r"%name) |
|
140 | 141 | self.name = name |
|
141 | 142 | self.buffers = [] |
|
142 | 143 | |
|
143 | 144 | def __repr__(self): |
|
144 | 145 | return "<Reference: %r>"%self.name |
|
145 | 146 | |
|
146 | 147 | def get_object(self, g=None): |
|
147 | 148 | if g is None: |
|
148 | 149 | g = {} |
|
149 | 150 | |
|
150 | 151 | return eval(self.name, g) |
|
151 | 152 | |
|
152 | 153 | |
|
153 | 154 | class CannedCell(CannedObject): |
|
154 | 155 | """Can a closure cell""" |
|
155 | 156 | def __init__(self, cell): |
|
156 | 157 | self.cell_contents = can(cell.cell_contents) |
|
157 | 158 | |
|
158 | 159 | def get_object(self, g=None): |
|
159 | 160 | cell_contents = uncan(self.cell_contents, g) |
|
160 | 161 | def inner(): |
|
161 | 162 | return cell_contents |
|
162 | 163 | return py3compat.get_closure(inner)[0] |
|
163 | 164 | |
|
164 | 165 | |
|
165 | 166 | class CannedFunction(CannedObject): |
|
166 | 167 | |
|
167 | 168 | def __init__(self, f): |
|
168 | 169 | self._check_type(f) |
|
169 | 170 | self.code = f.__code__ |
|
170 | 171 | if f.__defaults__: |
|
171 | 172 | self.defaults = [ can(fd) for fd in f.__defaults__ ] |
|
172 | 173 | else: |
|
173 | 174 | self.defaults = None |
|
174 | 175 | |
|
175 | 176 | closure = py3compat.get_closure(f) |
|
176 | 177 | if closure: |
|
177 | 178 | self.closure = tuple( can(cell) for cell in closure ) |
|
178 | 179 | else: |
|
179 | 180 | self.closure = None |
|
180 | 181 | |
|
181 | 182 | self.module = f.__module__ or '__main__' |
|
182 | 183 | self.__name__ = f.__name__ |
|
183 | 184 | self.buffers = [] |
|
184 | 185 | |
|
185 | 186 | def _check_type(self, obj): |
|
186 | 187 | assert isinstance(obj, FunctionType), "Not a function type" |
|
187 | 188 | |
|
188 | 189 | def get_object(self, g=None): |
|
189 | 190 | # try to load function back into its module: |
|
190 | 191 | if not self.module.startswith('__'): |
|
191 | 192 | __import__(self.module) |
|
192 | 193 | g = sys.modules[self.module].__dict__ |
|
193 | 194 | |
|
194 | 195 | if g is None: |
|
195 | 196 | g = {} |
|
196 | 197 | if self.defaults: |
|
197 | 198 | defaults = tuple(uncan(cfd, g) for cfd in self.defaults) |
|
198 | 199 | else: |
|
199 | 200 | defaults = None |
|
200 | 201 | if self.closure: |
|
201 | 202 | closure = tuple(uncan(cell, g) for cell in self.closure) |
|
202 | 203 | else: |
|
203 | 204 | closure = None |
|
204 | 205 | newFunc = FunctionType(self.code, g, self.__name__, defaults, closure) |
|
205 | 206 | return newFunc |
|
206 | 207 | |
|
207 | 208 | class CannedClass(CannedObject): |
|
208 | 209 | |
|
209 | 210 | def __init__(self, cls): |
|
210 | 211 | self._check_type(cls) |
|
211 | 212 | self.name = cls.__name__ |
|
212 | 213 | self.old_style = not isinstance(cls, type) |
|
213 | 214 | self._canned_dict = {} |
|
214 | 215 | for k,v in cls.__dict__.items(): |
|
215 | 216 | if k not in ('__weakref__', '__dict__'): |
|
216 | 217 | self._canned_dict[k] = can(v) |
|
217 | 218 | if self.old_style: |
|
218 | 219 | mro = [] |
|
219 | 220 | else: |
|
220 | 221 | mro = cls.mro() |
|
221 | 222 | |
|
222 | 223 | self.parents = [ can(c) for c in mro[1:] ] |
|
223 | 224 | self.buffers = [] |
|
224 | 225 | |
|
225 | 226 | def _check_type(self, obj): |
|
226 | 227 | assert isinstance(obj, class_type), "Not a class type" |
|
227 | 228 | |
|
228 | 229 | def get_object(self, g=None): |
|
229 | 230 | parents = tuple(uncan(p, g) for p in self.parents) |
|
230 | 231 | return type(self.name, parents, uncan_dict(self._canned_dict, g=g)) |
|
231 | 232 | |
|
232 | 233 | class CannedArray(CannedObject): |
|
233 | 234 | def __init__(self, obj): |
|
234 | 235 | from numpy import ascontiguousarray |
|
235 | 236 | self.shape = obj.shape |
|
236 | 237 | self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str |
|
237 | 238 | self.pickled = False |
|
238 | 239 | if sum(obj.shape) == 0: |
|
239 | 240 | self.pickled = True |
|
240 | 241 | elif obj.dtype == 'O': |
|
241 | 242 | # can't handle object dtype with buffer approach |
|
242 | 243 | self.pickled = True |
|
243 | 244 | elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()): |
|
244 | 245 | self.pickled = True |
|
245 | 246 | if self.pickled: |
|
246 | 247 | # just pickle it |
|
247 | 248 | self.buffers = [pickle.dumps(obj, -1)] |
|
248 | 249 | else: |
|
249 | 250 | # ensure contiguous |
|
250 | 251 | obj = ascontiguousarray(obj, dtype=None) |
|
251 | 252 | self.buffers = [buffer(obj)] |
|
252 | 253 | |
|
253 | 254 | def get_object(self, g=None): |
|
254 | 255 | from numpy import frombuffer |
|
255 | 256 | data = self.buffers[0] |
|
256 | 257 | if self.pickled: |
|
257 | 258 | # no shape, we just pickled it |
|
258 | 259 | return pickle.loads(data) |
|
259 | 260 | else: |
|
260 | 261 | return frombuffer(data, dtype=self.dtype).reshape(self.shape) |
|
261 | 262 | |
|
262 | 263 | |
|
263 | 264 | class CannedBytes(CannedObject): |
|
264 | 265 | wrap = bytes |
|
265 | 266 | def __init__(self, obj): |
|
266 | 267 | self.buffers = [obj] |
|
267 | 268 | |
|
268 | 269 | def get_object(self, g=None): |
|
269 | 270 | data = self.buffers[0] |
|
270 | 271 | return self.wrap(data) |
|
271 | 272 | |
|
272 | 273 | def CannedBuffer(CannedBytes): |
|
273 | 274 | wrap = buffer |
|
274 | 275 | |
|
275 | 276 | #------------------------------------------------------------------------------- |
|
276 | 277 | # Functions |
|
277 | 278 | #------------------------------------------------------------------------------- |
|
278 | 279 | |
|
279 | def _logger(): | |
|
280 | """get the logger for the current Application | |
|
281 | ||
|
282 | the root logger will be used if no Application is running | |
|
283 | """ | |
|
284 | if Application.initialized(): | |
|
285 | logger = Application.instance().log | |
|
286 | else: | |
|
287 | logger = logging.getLogger() | |
|
288 | if not logger.handlers: | |
|
289 | logging.basicConfig() | |
|
290 | ||
|
291 | return logger | |
|
292 | ||
|
293 | 280 | def _import_mapping(mapping, original=None): |
|
294 | 281 | """import any string-keys in a type mapping |
|
295 | 282 | |
|
296 | 283 | """ |
|
297 | log = _logger() | |
|
284 | log = get_logger() | |
|
298 | 285 | log.debug("Importing canning map") |
|
299 | 286 | for key,value in list(mapping.items()): |
|
300 | 287 | if isinstance(key, string_types): |
|
301 | 288 | try: |
|
302 | 289 | cls = import_item(key) |
|
303 | 290 | except Exception: |
|
304 | 291 | if original and key not in original: |
|
305 | 292 | # only message on user-added classes |
|
306 | 293 | log.error("canning class not importable: %r", key, exc_info=True) |
|
307 | 294 | mapping.pop(key) |
|
308 | 295 | else: |
|
309 | 296 | mapping[cls] = mapping.pop(key) |
|
310 | 297 | |
|
311 | 298 | def istype(obj, check): |
|
312 | 299 | """like isinstance(obj, check), but strict |
|
313 | 300 | |
|
314 | 301 | This won't catch subclasses. |
|
315 | 302 | """ |
|
316 | 303 | if isinstance(check, tuple): |
|
317 | 304 | for cls in check: |
|
318 | 305 | if type(obj) is cls: |
|
319 | 306 | return True |
|
320 | 307 | return False |
|
321 | 308 | else: |
|
322 | 309 | return type(obj) is check |
|
323 | 310 | |
|
324 | 311 | def can(obj): |
|
325 | 312 | """prepare an object for pickling""" |
|
326 | 313 | |
|
327 | 314 | import_needed = False |
|
328 | 315 | |
|
329 | 316 | for cls,canner in iteritems(can_map): |
|
330 | 317 | if isinstance(cls, string_types): |
|
331 | 318 | import_needed = True |
|
332 | 319 | break |
|
333 | 320 | elif istype(obj, cls): |
|
334 | 321 | return canner(obj) |
|
335 | 322 | |
|
336 | 323 | if import_needed: |
|
337 | 324 | # perform can_map imports, then try again |
|
338 | 325 | # this will usually only happen once |
|
339 | 326 | _import_mapping(can_map, _original_can_map) |
|
340 | 327 | return can(obj) |
|
341 | 328 | |
|
342 | 329 | return obj |
|
343 | 330 | |
|
344 | 331 | def can_class(obj): |
|
345 | 332 | if isinstance(obj, class_type) and obj.__module__ == '__main__': |
|
346 | 333 | return CannedClass(obj) |
|
347 | 334 | else: |
|
348 | 335 | return obj |
|
349 | 336 | |
|
350 | 337 | def can_dict(obj): |
|
351 | 338 | """can the *values* of a dict""" |
|
352 | 339 | if istype(obj, dict): |
|
353 | 340 | newobj = {} |
|
354 | 341 | for k, v in iteritems(obj): |
|
355 | 342 | newobj[k] = can(v) |
|
356 | 343 | return newobj |
|
357 | 344 | else: |
|
358 | 345 | return obj |
|
359 | 346 | |
|
360 | 347 | sequence_types = (list, tuple, set) |
|
361 | 348 | |
|
362 | 349 | def can_sequence(obj): |
|
363 | 350 | """can the elements of a sequence""" |
|
364 | 351 | if istype(obj, sequence_types): |
|
365 | 352 | t = type(obj) |
|
366 | 353 | return t([can(i) for i in obj]) |
|
367 | 354 | else: |
|
368 | 355 | return obj |
|
369 | 356 | |
|
370 | 357 | def uncan(obj, g=None): |
|
371 | 358 | """invert canning""" |
|
372 | 359 | |
|
373 | 360 | import_needed = False |
|
374 | 361 | for cls,uncanner in iteritems(uncan_map): |
|
375 | 362 | if isinstance(cls, string_types): |
|
376 | 363 | import_needed = True |
|
377 | 364 | break |
|
378 | 365 | elif isinstance(obj, cls): |
|
379 | 366 | return uncanner(obj, g) |
|
380 | 367 | |
|
381 | 368 | if import_needed: |
|
382 | 369 | # perform uncan_map imports, then try again |
|
383 | 370 | # this will usually only happen once |
|
384 | 371 | _import_mapping(uncan_map, _original_uncan_map) |
|
385 | 372 | return uncan(obj, g) |
|
386 | 373 | |
|
387 | 374 | return obj |
|
388 | 375 | |
|
389 | 376 | def uncan_dict(obj, g=None): |
|
390 | 377 | if istype(obj, dict): |
|
391 | 378 | newobj = {} |
|
392 | 379 | for k, v in iteritems(obj): |
|
393 | 380 | newobj[k] = uncan(v,g) |
|
394 | 381 | return newobj |
|
395 | 382 | else: |
|
396 | 383 | return obj |
|
397 | 384 | |
|
398 | 385 | def uncan_sequence(obj, g=None): |
|
399 | 386 | if istype(obj, sequence_types): |
|
400 | 387 | t = type(obj) |
|
401 | 388 | return t([uncan(i,g) for i in obj]) |
|
402 | 389 | else: |
|
403 | 390 | return obj |
|
404 | 391 | |
|
405 | 392 | def _uncan_dependent_hook(dep, g=None): |
|
406 | 393 | dep.check_dependency() |
|
407 | 394 | |
|
408 | 395 | def can_dependent(obj): |
|
409 | 396 | return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook) |
|
410 | 397 | |
|
411 | 398 | #------------------------------------------------------------------------------- |
|
412 | 399 | # API dictionaries |
|
413 | 400 | #------------------------------------------------------------------------------- |
|
414 | 401 | |
|
415 | 402 | # These dicts can be extended for custom serialization of new objects |
|
416 | 403 | |
|
417 | 404 | can_map = { |
|
418 | 405 | 'IPython.parallel.dependent' : can_dependent, |
|
419 | 406 | 'numpy.ndarray' : CannedArray, |
|
420 | 407 | FunctionType : CannedFunction, |
|
421 | 408 | bytes : CannedBytes, |
|
422 | 409 | buffer : CannedBuffer, |
|
423 | 410 | cell_type : CannedCell, |
|
424 | 411 | class_type : can_class, |
|
425 | 412 | } |
|
426 | 413 | |
|
427 | 414 | uncan_map = { |
|
428 | 415 | CannedObject : lambda obj, g: obj.get_object(g), |
|
429 | 416 | } |
|
430 | 417 | |
|
431 | 418 | # for use in _import_mapping: |
|
432 | 419 | _original_can_map = can_map.copy() |
|
433 | 420 | _original_uncan_map = uncan_map.copy() |
General Comments 0
You need to be logged in to leave comments.
Login now