Show More
@@ -0,0 +1,25 b'' | |||||
|
1 | """Grab the global logger instance.""" | |||
|
2 | ||||
|
3 | # Copyright (c) IPython Development Team. | |||
|
4 | # Distributed under the terms of the Modified BSD License. | |||
|
5 | ||||
|
6 | import logging | |||
|
7 | ||||
|
8 | _logger = None | |||
|
9 | ||||
|
10 | def get_logger(): | |||
|
11 | """Grab the global logger instance. | |||
|
12 | ||||
|
13 | If a global IPython Application is instantiated, grab its logger. | |||
|
14 | Otherwise, grab the root logger. | |||
|
15 | """ | |||
|
16 | global _logger | |||
|
17 | ||||
|
18 | if _logger is None: | |||
|
19 | from IPython.config import Application | |||
|
20 | if Application.initialized(): | |||
|
21 | _logger = Application.instance().log | |||
|
22 | else: | |||
|
23 | logging.basicConfig() | |||
|
24 | _logger = logging.getLogger() | |||
|
25 | return _logger |
@@ -1,390 +1,366 b'' | |||||
1 | # encoding: utf-8 |
|
1 | # encoding: utf-8 | |
2 | """ |
|
2 | """A base class for objects that are configurable.""" | |
3 | A base class for objects that are configurable. |
|
|||
4 |
|
3 | |||
5 | Inheritance diagram: |
|
4 | # Copyright (c) IPython Development Team. | |
|
5 | # Distributed under the terms of the Modified BSD License. | |||
6 |
|
6 | |||
7 | .. inheritance-diagram:: IPython.config.configurable |
|
|||
8 | :parts: 3 |
|
|||
9 |
|
||||
10 | Authors: |
|
|||
11 |
|
||||
12 | * Brian Granger |
|
|||
13 | * Fernando Perez |
|
|||
14 | * Min RK |
|
|||
15 | """ |
|
|||
16 | from __future__ import print_function |
|
7 | from __future__ import print_function | |
17 |
|
8 | |||
18 | #----------------------------------------------------------------------------- |
|
|||
19 | # Copyright (C) 2008-2011 The IPython Development Team |
|
|||
20 | # |
|
|||
21 | # Distributed under the terms of the BSD License. The full license is in |
|
|||
22 | # the file COPYING, distributed as part of this software. |
|
|||
23 | #----------------------------------------------------------------------------- |
|
|||
24 |
|
||||
25 | #----------------------------------------------------------------------------- |
|
|||
26 | # Imports |
|
|||
27 | #----------------------------------------------------------------------------- |
|
|||
28 |
|
||||
29 | import logging |
|
9 | import logging | |
30 | from copy import deepcopy |
|
10 | from copy import deepcopy | |
31 |
|
11 | |||
32 | from .loader import Config, LazyConfigValue |
|
12 | from .loader import Config, LazyConfigValue | |
33 | from IPython.utils.traitlets import HasTraits, Instance |
|
13 | from IPython.utils.traitlets import HasTraits, Instance | |
34 | from IPython.utils.text import indent, wrap_paragraphs |
|
14 | from IPython.utils.text import indent, wrap_paragraphs | |
35 | from IPython.utils.py3compat import iteritems |
|
15 | from IPython.utils.py3compat import iteritems | |
36 |
|
16 | |||
37 |
|
17 | |||
38 | #----------------------------------------------------------------------------- |
|
18 | #----------------------------------------------------------------------------- | |
39 | # Helper classes for Configurables |
|
19 | # Helper classes for Configurables | |
40 | #----------------------------------------------------------------------------- |
|
20 | #----------------------------------------------------------------------------- | |
41 |
|
21 | |||
42 |
|
22 | |||
43 | class ConfigurableError(Exception): |
|
23 | class ConfigurableError(Exception): | |
44 | pass |
|
24 | pass | |
45 |
|
25 | |||
46 |
|
26 | |||
47 | class MultipleInstanceError(ConfigurableError): |
|
27 | class MultipleInstanceError(ConfigurableError): | |
48 | pass |
|
28 | pass | |
49 |
|
29 | |||
50 | #----------------------------------------------------------------------------- |
|
30 | #----------------------------------------------------------------------------- | |
51 | # Configurable implementation |
|
31 | # Configurable implementation | |
52 | #----------------------------------------------------------------------------- |
|
32 | #----------------------------------------------------------------------------- | |
53 |
|
33 | |||
54 | class Configurable(HasTraits): |
|
34 | class Configurable(HasTraits): | |
55 |
|
35 | |||
56 | config = Instance(Config, (), {}) |
|
36 | config = Instance(Config, (), {}) | |
57 | parent = Instance('IPython.config.configurable.Configurable') |
|
37 | parent = Instance('IPython.config.configurable.Configurable') | |
58 |
|
38 | |||
59 | def __init__(self, **kwargs): |
|
39 | def __init__(self, **kwargs): | |
60 | """Create a configurable given a config config. |
|
40 | """Create a configurable given a config config. | |
61 |
|
41 | |||
62 | Parameters |
|
42 | Parameters | |
63 | ---------- |
|
43 | ---------- | |
64 | config : Config |
|
44 | config : Config | |
65 | If this is empty, default values are used. If config is a |
|
45 | If this is empty, default values are used. If config is a | |
66 | :class:`Config` instance, it will be used to configure the |
|
46 | :class:`Config` instance, it will be used to configure the | |
67 | instance. |
|
47 | instance. | |
68 | parent : Configurable instance, optional |
|
48 | parent : Configurable instance, optional | |
69 | The parent Configurable instance of this object. |
|
49 | The parent Configurable instance of this object. | |
70 |
|
50 | |||
71 | Notes |
|
51 | Notes | |
72 | ----- |
|
52 | ----- | |
73 | Subclasses of Configurable must call the :meth:`__init__` method of |
|
53 | Subclasses of Configurable must call the :meth:`__init__` method of | |
74 | :class:`Configurable` *before* doing anything else and using |
|
54 | :class:`Configurable` *before* doing anything else and using | |
75 | :func:`super`:: |
|
55 | :func:`super`:: | |
76 |
|
56 | |||
77 | class MyConfigurable(Configurable): |
|
57 | class MyConfigurable(Configurable): | |
78 | def __init__(self, config=None): |
|
58 | def __init__(self, config=None): | |
79 | super(MyConfigurable, self).__init__(config=config) |
|
59 | super(MyConfigurable, self).__init__(config=config) | |
80 | # Then any other code you need to finish initialization. |
|
60 | # Then any other code you need to finish initialization. | |
81 |
|
61 | |||
82 | This ensures that instances will be configured properly. |
|
62 | This ensures that instances will be configured properly. | |
83 | """ |
|
63 | """ | |
84 | parent = kwargs.pop('parent', None) |
|
64 | parent = kwargs.pop('parent', None) | |
85 | if parent is not None: |
|
65 | if parent is not None: | |
86 | # config is implied from parent |
|
66 | # config is implied from parent | |
87 | if kwargs.get('config', None) is None: |
|
67 | if kwargs.get('config', None) is None: | |
88 | kwargs['config'] = parent.config |
|
68 | kwargs['config'] = parent.config | |
89 | self.parent = parent |
|
69 | self.parent = parent | |
90 |
|
70 | |||
91 | config = kwargs.pop('config', None) |
|
71 | config = kwargs.pop('config', None) | |
92 | if config is not None: |
|
72 | if config is not None: | |
93 | # We used to deepcopy, but for now we are trying to just save |
|
73 | # We used to deepcopy, but for now we are trying to just save | |
94 | # by reference. This *could* have side effects as all components |
|
74 | # by reference. This *could* have side effects as all components | |
95 | # will share config. In fact, I did find such a side effect in |
|
75 | # will share config. In fact, I did find such a side effect in | |
96 | # _config_changed below. If a config attribute value was a mutable type |
|
76 | # _config_changed below. If a config attribute value was a mutable type | |
97 | # all instances of a component were getting the same copy, effectively |
|
77 | # all instances of a component were getting the same copy, effectively | |
98 | # making that a class attribute. |
|
78 | # making that a class attribute. | |
99 | # self.config = deepcopy(config) |
|
79 | # self.config = deepcopy(config) | |
100 | self.config = config |
|
80 | self.config = config | |
101 | # This should go second so individual keyword arguments override |
|
81 | # This should go second so individual keyword arguments override | |
102 | # the values in config. |
|
82 | # the values in config. | |
103 | super(Configurable, self).__init__(**kwargs) |
|
83 | super(Configurable, self).__init__(**kwargs) | |
104 |
|
84 | |||
105 | #------------------------------------------------------------------------- |
|
85 | #------------------------------------------------------------------------- | |
106 | # Static trait notifiations |
|
86 | # Static trait notifiations | |
107 | #------------------------------------------------------------------------- |
|
87 | #------------------------------------------------------------------------- | |
108 |
|
88 | |||
109 | @classmethod |
|
89 | @classmethod | |
110 | def section_names(cls): |
|
90 | def section_names(cls): | |
111 | """return section names as a list""" |
|
91 | """return section names as a list""" | |
112 | return [c.__name__ for c in reversed(cls.__mro__) if |
|
92 | return [c.__name__ for c in reversed(cls.__mro__) if | |
113 | issubclass(c, Configurable) and issubclass(cls, c) |
|
93 | issubclass(c, Configurable) and issubclass(cls, c) | |
114 | ] |
|
94 | ] | |
115 |
|
95 | |||
116 | def _find_my_config(self, cfg): |
|
96 | def _find_my_config(self, cfg): | |
117 | """extract my config from a global Config object |
|
97 | """extract my config from a global Config object | |
118 |
|
98 | |||
119 | will construct a Config object of only the config values that apply to me |
|
99 | will construct a Config object of only the config values that apply to me | |
120 | based on my mro(), as well as those of my parent(s) if they exist. |
|
100 | based on my mro(), as well as those of my parent(s) if they exist. | |
121 |
|
101 | |||
122 | If I am Bar and my parent is Foo, and their parent is Tim, |
|
102 | If I am Bar and my parent is Foo, and their parent is Tim, | |
123 | this will return merge following config sections, in this order:: |
|
103 | this will return merge following config sections, in this order:: | |
124 |
|
104 | |||
125 | [Bar, Foo.bar, Tim.Foo.Bar] |
|
105 | [Bar, Foo.bar, Tim.Foo.Bar] | |
126 |
|
106 | |||
127 | With the last item being the highest priority. |
|
107 | With the last item being the highest priority. | |
128 | """ |
|
108 | """ | |
129 | cfgs = [cfg] |
|
109 | cfgs = [cfg] | |
130 | if self.parent: |
|
110 | if self.parent: | |
131 | cfgs.append(self.parent._find_my_config(cfg)) |
|
111 | cfgs.append(self.parent._find_my_config(cfg)) | |
132 | my_config = Config() |
|
112 | my_config = Config() | |
133 | for c in cfgs: |
|
113 | for c in cfgs: | |
134 | for sname in self.section_names(): |
|
114 | for sname in self.section_names(): | |
135 | # Don't do a blind getattr as that would cause the config to |
|
115 | # Don't do a blind getattr as that would cause the config to | |
136 | # dynamically create the section with name Class.__name__. |
|
116 | # dynamically create the section with name Class.__name__. | |
137 | if c._has_section(sname): |
|
117 | if c._has_section(sname): | |
138 | my_config.merge(c[sname]) |
|
118 | my_config.merge(c[sname]) | |
139 | return my_config |
|
119 | return my_config | |
140 |
|
120 | |||
141 | def _load_config(self, cfg, section_names=None, traits=None): |
|
121 | def _load_config(self, cfg, section_names=None, traits=None): | |
142 | """load traits from a Config object""" |
|
122 | """load traits from a Config object""" | |
143 |
|
123 | |||
144 | if traits is None: |
|
124 | if traits is None: | |
145 | traits = self.traits(config=True) |
|
125 | traits = self.traits(config=True) | |
146 | if section_names is None: |
|
126 | if section_names is None: | |
147 | section_names = self.section_names() |
|
127 | section_names = self.section_names() | |
148 |
|
128 | |||
149 | my_config = self._find_my_config(cfg) |
|
129 | my_config = self._find_my_config(cfg) | |
150 | for name, config_value in iteritems(my_config): |
|
130 | for name, config_value in iteritems(my_config): | |
151 | if name in traits: |
|
131 | if name in traits: | |
152 | if isinstance(config_value, LazyConfigValue): |
|
132 | if isinstance(config_value, LazyConfigValue): | |
153 | # ConfigValue is a wrapper for using append / update on containers |
|
133 | # ConfigValue is a wrapper for using append / update on containers | |
154 | # without having to copy the |
|
134 | # without having to copy the | |
155 | initial = getattr(self, name) |
|
135 | initial = getattr(self, name) | |
156 | config_value = config_value.get_value(initial) |
|
136 | config_value = config_value.get_value(initial) | |
157 | # We have to do a deepcopy here if we don't deepcopy the entire |
|
137 | # We have to do a deepcopy here if we don't deepcopy the entire | |
158 | # config object. If we don't, a mutable config_value will be |
|
138 | # config object. If we don't, a mutable config_value will be | |
159 | # shared by all instances, effectively making it a class attribute. |
|
139 | # shared by all instances, effectively making it a class attribute. | |
160 | setattr(self, name, deepcopy(config_value)) |
|
140 | setattr(self, name, deepcopy(config_value)) | |
161 |
|
141 | |||
162 | def _config_changed(self, name, old, new): |
|
142 | def _config_changed(self, name, old, new): | |
163 | """Update all the class traits having ``config=True`` as metadata. |
|
143 | """Update all the class traits having ``config=True`` as metadata. | |
164 |
|
144 | |||
165 | For any class trait with a ``config`` metadata attribute that is |
|
145 | For any class trait with a ``config`` metadata attribute that is | |
166 | ``True``, we update the trait with the value of the corresponding |
|
146 | ``True``, we update the trait with the value of the corresponding | |
167 | config entry. |
|
147 | config entry. | |
168 | """ |
|
148 | """ | |
169 | # Get all traits with a config metadata entry that is True |
|
149 | # Get all traits with a config metadata entry that is True | |
170 | traits = self.traits(config=True) |
|
150 | traits = self.traits(config=True) | |
171 |
|
151 | |||
172 | # We auto-load config section for this class as well as any parent |
|
152 | # We auto-load config section for this class as well as any parent | |
173 | # classes that are Configurable subclasses. This starts with Configurable |
|
153 | # classes that are Configurable subclasses. This starts with Configurable | |
174 | # and works down the mro loading the config for each section. |
|
154 | # and works down the mro loading the config for each section. | |
175 | section_names = self.section_names() |
|
155 | section_names = self.section_names() | |
176 | self._load_config(new, traits=traits, section_names=section_names) |
|
156 | self._load_config(new, traits=traits, section_names=section_names) | |
177 |
|
157 | |||
178 | def update_config(self, config): |
|
158 | def update_config(self, config): | |
179 | """Fire the traits events when the config is updated.""" |
|
159 | """Fire the traits events when the config is updated.""" | |
180 | # Save a copy of the current config. |
|
160 | # Save a copy of the current config. | |
181 | newconfig = deepcopy(self.config) |
|
161 | newconfig = deepcopy(self.config) | |
182 | # Merge the new config into the current one. |
|
162 | # Merge the new config into the current one. | |
183 | newconfig.merge(config) |
|
163 | newconfig.merge(config) | |
184 | # Save the combined config as self.config, which triggers the traits |
|
164 | # Save the combined config as self.config, which triggers the traits | |
185 | # events. |
|
165 | # events. | |
186 | self.config = newconfig |
|
166 | self.config = newconfig | |
187 |
|
167 | |||
188 | @classmethod |
|
168 | @classmethod | |
189 | def class_get_help(cls, inst=None): |
|
169 | def class_get_help(cls, inst=None): | |
190 | """Get the help string for this class in ReST format. |
|
170 | """Get the help string for this class in ReST format. | |
191 |
|
171 | |||
192 | If `inst` is given, it's current trait values will be used in place of |
|
172 | If `inst` is given, it's current trait values will be used in place of | |
193 | class defaults. |
|
173 | class defaults. | |
194 | """ |
|
174 | """ | |
195 | assert inst is None or isinstance(inst, cls) |
|
175 | assert inst is None or isinstance(inst, cls) | |
196 | final_help = [] |
|
176 | final_help = [] | |
197 | final_help.append(u'%s options' % cls.__name__) |
|
177 | final_help.append(u'%s options' % cls.__name__) | |
198 | final_help.append(len(final_help[0])*u'-') |
|
178 | final_help.append(len(final_help[0])*u'-') | |
199 | for k, v in sorted(cls.class_traits(config=True).items()): |
|
179 | for k, v in sorted(cls.class_traits(config=True).items()): | |
200 | help = cls.class_get_trait_help(v, inst) |
|
180 | help = cls.class_get_trait_help(v, inst) | |
201 | final_help.append(help) |
|
181 | final_help.append(help) | |
202 | return '\n'.join(final_help) |
|
182 | return '\n'.join(final_help) | |
203 |
|
183 | |||
204 | @classmethod |
|
184 | @classmethod | |
205 | def class_get_trait_help(cls, trait, inst=None): |
|
185 | def class_get_trait_help(cls, trait, inst=None): | |
206 | """Get the help string for a single trait. |
|
186 | """Get the help string for a single trait. | |
207 |
|
187 | |||
208 | If `inst` is given, it's current trait values will be used in place of |
|
188 | If `inst` is given, it's current trait values will be used in place of | |
209 | the class default. |
|
189 | the class default. | |
210 | """ |
|
190 | """ | |
211 | assert inst is None or isinstance(inst, cls) |
|
191 | assert inst is None or isinstance(inst, cls) | |
212 | lines = [] |
|
192 | lines = [] | |
213 | header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__) |
|
193 | header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__) | |
214 | lines.append(header) |
|
194 | lines.append(header) | |
215 | if inst is not None: |
|
195 | if inst is not None: | |
216 | lines.append(indent('Current: %r' % getattr(inst, trait.name), 4)) |
|
196 | lines.append(indent('Current: %r' % getattr(inst, trait.name), 4)) | |
217 | else: |
|
197 | else: | |
218 | try: |
|
198 | try: | |
219 | dvr = repr(trait.get_default_value()) |
|
199 | dvr = repr(trait.get_default_value()) | |
220 | except Exception: |
|
200 | except Exception: | |
221 | dvr = None # ignore defaults we can't construct |
|
201 | dvr = None # ignore defaults we can't construct | |
222 | if dvr is not None: |
|
202 | if dvr is not None: | |
223 | if len(dvr) > 64: |
|
203 | if len(dvr) > 64: | |
224 | dvr = dvr[:61]+'...' |
|
204 | dvr = dvr[:61]+'...' | |
225 | lines.append(indent('Default: %s' % dvr, 4)) |
|
205 | lines.append(indent('Default: %s' % dvr, 4)) | |
226 | if 'Enum' in trait.__class__.__name__: |
|
206 | if 'Enum' in trait.__class__.__name__: | |
227 | # include Enum choices |
|
207 | # include Enum choices | |
228 | lines.append(indent('Choices: %r' % (trait.values,))) |
|
208 | lines.append(indent('Choices: %r' % (trait.values,))) | |
229 |
|
209 | |||
230 | help = trait.get_metadata('help') |
|
210 | help = trait.get_metadata('help') | |
231 | if help is not None: |
|
211 | if help is not None: | |
232 | help = '\n'.join(wrap_paragraphs(help, 76)) |
|
212 | help = '\n'.join(wrap_paragraphs(help, 76)) | |
233 | lines.append(indent(help, 4)) |
|
213 | lines.append(indent(help, 4)) | |
234 | return '\n'.join(lines) |
|
214 | return '\n'.join(lines) | |
235 |
|
215 | |||
236 | @classmethod |
|
216 | @classmethod | |
237 | def class_print_help(cls, inst=None): |
|
217 | def class_print_help(cls, inst=None): | |
238 | """Get the help string for a single trait and print it.""" |
|
218 | """Get the help string for a single trait and print it.""" | |
239 | print(cls.class_get_help(inst)) |
|
219 | print(cls.class_get_help(inst)) | |
240 |
|
220 | |||
241 | @classmethod |
|
221 | @classmethod | |
242 | def class_config_section(cls): |
|
222 | def class_config_section(cls): | |
243 | """Get the config class config section""" |
|
223 | """Get the config class config section""" | |
244 | def c(s): |
|
224 | def c(s): | |
245 | """return a commented, wrapped block.""" |
|
225 | """return a commented, wrapped block.""" | |
246 | s = '\n\n'.join(wrap_paragraphs(s, 78)) |
|
226 | s = '\n\n'.join(wrap_paragraphs(s, 78)) | |
247 |
|
227 | |||
248 | return '# ' + s.replace('\n', '\n# ') |
|
228 | return '# ' + s.replace('\n', '\n# ') | |
249 |
|
229 | |||
250 | # section header |
|
230 | # section header | |
251 | breaker = '#' + '-'*78 |
|
231 | breaker = '#' + '-'*78 | |
252 | s = "# %s configuration" % cls.__name__ |
|
232 | s = "# %s configuration" % cls.__name__ | |
253 | lines = [breaker, s, breaker, ''] |
|
233 | lines = [breaker, s, breaker, ''] | |
254 | # get the description trait |
|
234 | # get the description trait | |
255 | desc = cls.class_traits().get('description') |
|
235 | desc = cls.class_traits().get('description') | |
256 | if desc: |
|
236 | if desc: | |
257 | desc = desc.default_value |
|
237 | desc = desc.default_value | |
258 | else: |
|
238 | else: | |
259 | # no description trait, use __doc__ |
|
239 | # no description trait, use __doc__ | |
260 | desc = getattr(cls, '__doc__', '') |
|
240 | desc = getattr(cls, '__doc__', '') | |
261 | if desc: |
|
241 | if desc: | |
262 | lines.append(c(desc)) |
|
242 | lines.append(c(desc)) | |
263 | lines.append('') |
|
243 | lines.append('') | |
264 |
|
244 | |||
265 | parents = [] |
|
245 | parents = [] | |
266 | for parent in cls.mro(): |
|
246 | for parent in cls.mro(): | |
267 | # only include parents that are not base classes |
|
247 | # only include parents that are not base classes | |
268 | # and are not the class itself |
|
248 | # and are not the class itself | |
269 | # and have some configurable traits to inherit |
|
249 | # and have some configurable traits to inherit | |
270 | if parent is not cls and issubclass(parent, Configurable) and \ |
|
250 | if parent is not cls and issubclass(parent, Configurable) and \ | |
271 | parent.class_traits(config=True): |
|
251 | parent.class_traits(config=True): | |
272 | parents.append(parent) |
|
252 | parents.append(parent) | |
273 |
|
253 | |||
274 | if parents: |
|
254 | if parents: | |
275 | pstr = ', '.join([ p.__name__ for p in parents ]) |
|
255 | pstr = ', '.join([ p.__name__ for p in parents ]) | |
276 | lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr))) |
|
256 | lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr))) | |
277 | lines.append('') |
|
257 | lines.append('') | |
278 |
|
258 | |||
279 | for name, trait in iteritems(cls.class_traits(config=True)): |
|
259 | for name, trait in iteritems(cls.class_traits(config=True)): | |
280 | help = trait.get_metadata('help') or '' |
|
260 | help = trait.get_metadata('help') or '' | |
281 | lines.append(c(help)) |
|
261 | lines.append(c(help)) | |
282 | lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value())) |
|
262 | lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value())) | |
283 | lines.append('') |
|
263 | lines.append('') | |
284 | return '\n'.join(lines) |
|
264 | return '\n'.join(lines) | |
285 |
|
265 | |||
286 |
|
266 | |||
287 |
|
267 | |||
288 | class SingletonConfigurable(Configurable): |
|
268 | class SingletonConfigurable(Configurable): | |
289 | """A configurable that only allows one instance. |
|
269 | """A configurable that only allows one instance. | |
290 |
|
270 | |||
291 | This class is for classes that should only have one instance of itself |
|
271 | This class is for classes that should only have one instance of itself | |
292 | or *any* subclass. To create and retrieve such a class use the |
|
272 | or *any* subclass. To create and retrieve such a class use the | |
293 | :meth:`SingletonConfigurable.instance` method. |
|
273 | :meth:`SingletonConfigurable.instance` method. | |
294 | """ |
|
274 | """ | |
295 |
|
275 | |||
296 | _instance = None |
|
276 | _instance = None | |
297 |
|
277 | |||
298 | @classmethod |
|
278 | @classmethod | |
299 | def _walk_mro(cls): |
|
279 | def _walk_mro(cls): | |
300 | """Walk the cls.mro() for parent classes that are also singletons |
|
280 | """Walk the cls.mro() for parent classes that are also singletons | |
301 |
|
281 | |||
302 | For use in instance() |
|
282 | For use in instance() | |
303 | """ |
|
283 | """ | |
304 |
|
284 | |||
305 | for subclass in cls.mro(): |
|
285 | for subclass in cls.mro(): | |
306 | if issubclass(cls, subclass) and \ |
|
286 | if issubclass(cls, subclass) and \ | |
307 | issubclass(subclass, SingletonConfigurable) and \ |
|
287 | issubclass(subclass, SingletonConfigurable) and \ | |
308 | subclass != SingletonConfigurable: |
|
288 | subclass != SingletonConfigurable: | |
309 | yield subclass |
|
289 | yield subclass | |
310 |
|
290 | |||
311 | @classmethod |
|
291 | @classmethod | |
312 | def clear_instance(cls): |
|
292 | def clear_instance(cls): | |
313 | """unset _instance for this class and singleton parents. |
|
293 | """unset _instance for this class and singleton parents. | |
314 | """ |
|
294 | """ | |
315 | if not cls.initialized(): |
|
295 | if not cls.initialized(): | |
316 | return |
|
296 | return | |
317 | for subclass in cls._walk_mro(): |
|
297 | for subclass in cls._walk_mro(): | |
318 | if isinstance(subclass._instance, cls): |
|
298 | if isinstance(subclass._instance, cls): | |
319 | # only clear instances that are instances |
|
299 | # only clear instances that are instances | |
320 | # of the calling class |
|
300 | # of the calling class | |
321 | subclass._instance = None |
|
301 | subclass._instance = None | |
322 |
|
302 | |||
323 | @classmethod |
|
303 | @classmethod | |
324 | def instance(cls, *args, **kwargs): |
|
304 | def instance(cls, *args, **kwargs): | |
325 | """Returns a global instance of this class. |
|
305 | """Returns a global instance of this class. | |
326 |
|
306 | |||
327 | This method create a new instance if none have previously been created |
|
307 | This method create a new instance if none have previously been created | |
328 | and returns a previously created instance is one already exists. |
|
308 | and returns a previously created instance is one already exists. | |
329 |
|
309 | |||
330 | The arguments and keyword arguments passed to this method are passed |
|
310 | The arguments and keyword arguments passed to this method are passed | |
331 | on to the :meth:`__init__` method of the class upon instantiation. |
|
311 | on to the :meth:`__init__` method of the class upon instantiation. | |
332 |
|
312 | |||
333 | Examples |
|
313 | Examples | |
334 | -------- |
|
314 | -------- | |
335 |
|
315 | |||
336 | Create a singleton class using instance, and retrieve it:: |
|
316 | Create a singleton class using instance, and retrieve it:: | |
337 |
|
317 | |||
338 | >>> from IPython.config.configurable import SingletonConfigurable |
|
318 | >>> from IPython.config.configurable import SingletonConfigurable | |
339 | >>> class Foo(SingletonConfigurable): pass |
|
319 | >>> class Foo(SingletonConfigurable): pass | |
340 | >>> foo = Foo.instance() |
|
320 | >>> foo = Foo.instance() | |
341 | >>> foo == Foo.instance() |
|
321 | >>> foo == Foo.instance() | |
342 | True |
|
322 | True | |
343 |
|
323 | |||
344 | Create a subclass that is retrived using the base class instance:: |
|
324 | Create a subclass that is retrived using the base class instance:: | |
345 |
|
325 | |||
346 | >>> class Bar(SingletonConfigurable): pass |
|
326 | >>> class Bar(SingletonConfigurable): pass | |
347 | >>> class Bam(Bar): pass |
|
327 | >>> class Bam(Bar): pass | |
348 | >>> bam = Bam.instance() |
|
328 | >>> bam = Bam.instance() | |
349 | >>> bam == Bar.instance() |
|
329 | >>> bam == Bar.instance() | |
350 | True |
|
330 | True | |
351 | """ |
|
331 | """ | |
352 | # Create and save the instance |
|
332 | # Create and save the instance | |
353 | if cls._instance is None: |
|
333 | if cls._instance is None: | |
354 | inst = cls(*args, **kwargs) |
|
334 | inst = cls(*args, **kwargs) | |
355 | # Now make sure that the instance will also be returned by |
|
335 | # Now make sure that the instance will also be returned by | |
356 | # parent classes' _instance attribute. |
|
336 | # parent classes' _instance attribute. | |
357 | for subclass in cls._walk_mro(): |
|
337 | for subclass in cls._walk_mro(): | |
358 | subclass._instance = inst |
|
338 | subclass._instance = inst | |
359 |
|
339 | |||
360 | if isinstance(cls._instance, cls): |
|
340 | if isinstance(cls._instance, cls): | |
361 | return cls._instance |
|
341 | return cls._instance | |
362 | else: |
|
342 | else: | |
363 | raise MultipleInstanceError( |
|
343 | raise MultipleInstanceError( | |
364 | 'Multiple incompatible subclass instances of ' |
|
344 | 'Multiple incompatible subclass instances of ' | |
365 | '%s are being created.' % cls.__name__ |
|
345 | '%s are being created.' % cls.__name__ | |
366 | ) |
|
346 | ) | |
367 |
|
347 | |||
368 | @classmethod |
|
348 | @classmethod | |
369 | def initialized(cls): |
|
349 | def initialized(cls): | |
370 | """Has an instance been created?""" |
|
350 | """Has an instance been created?""" | |
371 | return hasattr(cls, "_instance") and cls._instance is not None |
|
351 | return hasattr(cls, "_instance") and cls._instance is not None | |
372 |
|
352 | |||
373 |
|
353 | |||
374 | class LoggingConfigurable(Configurable): |
|
354 | class LoggingConfigurable(Configurable): | |
375 | """A parent class for Configurables that log. |
|
355 | """A parent class for Configurables that log. | |
376 |
|
356 | |||
377 | Subclasses have a log trait, and the default behavior |
|
357 | Subclasses have a log trait, and the default behavior | |
378 | is to get the logger from the currently running Application |
|
358 | is to get the logger from the currently running Application. | |
379 | via Application.instance().log. |
|
|||
380 | """ |
|
359 | """ | |
381 |
|
360 | |||
382 | log = Instance('logging.Logger') |
|
361 | log = Instance('logging.Logger') | |
383 | def _log_default(self): |
|
362 | def _log_default(self): | |
384 |
from IPython. |
|
363 | from IPython.utils import log | |
385 | if Application.initialized(): |
|
364 | return log.get_logger() | |
386 | return Application.instance().log |
|
|||
387 | else: |
|
|||
388 | return logging.getLogger() |
|
|||
389 |
|
365 | |||
390 |
|
366 |
@@ -1,846 +1,824 b'' | |||||
1 | """A simple configuration system. |
|
1 | # encoding: utf-8 | |
|
2 | """A simple configuration system.""" | |||
2 |
|
3 | |||
3 | Inheritance diagram: |
|
4 | # Copyright (c) IPython Development Team. | |
4 |
|
5 | # Distributed under the terms of the Modified BSD License. | ||
5 | .. inheritance-diagram:: IPython.config.loader |
|
|||
6 | :parts: 3 |
|
|||
7 |
|
||||
8 | Authors |
|
|||
9 | ------- |
|
|||
10 | * Brian Granger |
|
|||
11 | * Fernando Perez |
|
|||
12 | * Min RK |
|
|||
13 | """ |
|
|||
14 |
|
||||
15 | #----------------------------------------------------------------------------- |
|
|||
16 | # Copyright (C) 2008-2011 The IPython Development Team |
|
|||
17 | # |
|
|||
18 | # Distributed under the terms of the BSD License. The full license is in |
|
|||
19 | # the file COPYING, distributed as part of this software. |
|
|||
20 | #----------------------------------------------------------------------------- |
|
|||
21 |
|
||||
22 | #----------------------------------------------------------------------------- |
|
|||
23 | # Imports |
|
|||
24 | #----------------------------------------------------------------------------- |
|
|||
25 |
|
6 | |||
26 | import argparse |
|
7 | import argparse | |
27 | import copy |
|
8 | import copy | |
28 | import logging |
|
9 | import logging | |
29 | import os |
|
10 | import os | |
30 | import re |
|
11 | import re | |
31 | import sys |
|
12 | import sys | |
32 | import json |
|
13 | import json | |
33 |
|
14 | |||
34 | from IPython.utils.path import filefind, get_ipython_dir |
|
15 | from IPython.utils.path import filefind, get_ipython_dir | |
35 | from IPython.utils import py3compat |
|
16 | from IPython.utils import py3compat | |
36 | from IPython.utils.encoding import DEFAULT_ENCODING |
|
17 | from IPython.utils.encoding import DEFAULT_ENCODING | |
37 | from IPython.utils.py3compat import unicode_type, iteritems |
|
18 | from IPython.utils.py3compat import unicode_type, iteritems | |
38 | from IPython.utils.traitlets import HasTraits, List, Any |
|
19 | from IPython.utils.traitlets import HasTraits, List, Any | |
39 |
|
20 | |||
40 | #----------------------------------------------------------------------------- |
|
21 | #----------------------------------------------------------------------------- | |
41 | # Exceptions |
|
22 | # Exceptions | |
42 | #----------------------------------------------------------------------------- |
|
23 | #----------------------------------------------------------------------------- | |
43 |
|
24 | |||
44 |
|
25 | |||
45 | class ConfigError(Exception): |
|
26 | class ConfigError(Exception): | |
46 | pass |
|
27 | pass | |
47 |
|
28 | |||
48 | class ConfigLoaderError(ConfigError): |
|
29 | class ConfigLoaderError(ConfigError): | |
49 | pass |
|
30 | pass | |
50 |
|
31 | |||
51 | class ConfigFileNotFound(ConfigError): |
|
32 | class ConfigFileNotFound(ConfigError): | |
52 | pass |
|
33 | pass | |
53 |
|
34 | |||
54 | class ArgumentError(ConfigLoaderError): |
|
35 | class ArgumentError(ConfigLoaderError): | |
55 | pass |
|
36 | pass | |
56 |
|
37 | |||
57 | #----------------------------------------------------------------------------- |
|
38 | #----------------------------------------------------------------------------- | |
58 | # Argparse fix |
|
39 | # Argparse fix | |
59 | #----------------------------------------------------------------------------- |
|
40 | #----------------------------------------------------------------------------- | |
60 |
|
41 | |||
61 | # Unfortunately argparse by default prints help messages to stderr instead of |
|
42 | # Unfortunately argparse by default prints help messages to stderr instead of | |
62 | # stdout. This makes it annoying to capture long help screens at the command |
|
43 | # stdout. This makes it annoying to capture long help screens at the command | |
63 | # line, since one must know how to pipe stderr, which many users don't know how |
|
44 | # line, since one must know how to pipe stderr, which many users don't know how | |
64 | # to do. So we override the print_help method with one that defaults to |
|
45 | # to do. So we override the print_help method with one that defaults to | |
65 | # stdout and use our class instead. |
|
46 | # stdout and use our class instead. | |
66 |
|
47 | |||
67 | class ArgumentParser(argparse.ArgumentParser): |
|
48 | class ArgumentParser(argparse.ArgumentParser): | |
68 | """Simple argparse subclass that prints help to stdout by default.""" |
|
49 | """Simple argparse subclass that prints help to stdout by default.""" | |
69 |
|
50 | |||
70 | def print_help(self, file=None): |
|
51 | def print_help(self, file=None): | |
71 | if file is None: |
|
52 | if file is None: | |
72 | file = sys.stdout |
|
53 | file = sys.stdout | |
73 | return super(ArgumentParser, self).print_help(file) |
|
54 | return super(ArgumentParser, self).print_help(file) | |
74 |
|
55 | |||
75 | print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__ |
|
56 | print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__ | |
76 |
|
57 | |||
77 | #----------------------------------------------------------------------------- |
|
58 | #----------------------------------------------------------------------------- | |
78 | # Config class for holding config information |
|
59 | # Config class for holding config information | |
79 | #----------------------------------------------------------------------------- |
|
60 | #----------------------------------------------------------------------------- | |
80 |
|
61 | |||
81 | class LazyConfigValue(HasTraits): |
|
62 | class LazyConfigValue(HasTraits): | |
82 | """Proxy object for exposing methods on configurable containers |
|
63 | """Proxy object for exposing methods on configurable containers | |
83 |
|
64 | |||
84 | Exposes: |
|
65 | Exposes: | |
85 |
|
66 | |||
86 | - append, extend, insert on lists |
|
67 | - append, extend, insert on lists | |
87 | - update on dicts |
|
68 | - update on dicts | |
88 | - update, add on sets |
|
69 | - update, add on sets | |
89 | """ |
|
70 | """ | |
90 |
|
71 | |||
91 | _value = None |
|
72 | _value = None | |
92 |
|
73 | |||
93 | # list methods |
|
74 | # list methods | |
94 | _extend = List() |
|
75 | _extend = List() | |
95 | _prepend = List() |
|
76 | _prepend = List() | |
96 |
|
77 | |||
97 | def append(self, obj): |
|
78 | def append(self, obj): | |
98 | self._extend.append(obj) |
|
79 | self._extend.append(obj) | |
99 |
|
80 | |||
100 | def extend(self, other): |
|
81 | def extend(self, other): | |
101 | self._extend.extend(other) |
|
82 | self._extend.extend(other) | |
102 |
|
83 | |||
103 | def prepend(self, other): |
|
84 | def prepend(self, other): | |
104 | """like list.extend, but for the front""" |
|
85 | """like list.extend, but for the front""" | |
105 | self._prepend[:0] = other |
|
86 | self._prepend[:0] = other | |
106 |
|
87 | |||
107 | _inserts = List() |
|
88 | _inserts = List() | |
108 | def insert(self, index, other): |
|
89 | def insert(self, index, other): | |
109 | if not isinstance(index, int): |
|
90 | if not isinstance(index, int): | |
110 | raise TypeError("An integer is required") |
|
91 | raise TypeError("An integer is required") | |
111 | self._inserts.append((index, other)) |
|
92 | self._inserts.append((index, other)) | |
112 |
|
93 | |||
113 | # dict methods |
|
94 | # dict methods | |
114 | # update is used for both dict and set |
|
95 | # update is used for both dict and set | |
115 | _update = Any() |
|
96 | _update = Any() | |
116 | def update(self, other): |
|
97 | def update(self, other): | |
117 | if self._update is None: |
|
98 | if self._update is None: | |
118 | if isinstance(other, dict): |
|
99 | if isinstance(other, dict): | |
119 | self._update = {} |
|
100 | self._update = {} | |
120 | else: |
|
101 | else: | |
121 | self._update = set() |
|
102 | self._update = set() | |
122 | self._update.update(other) |
|
103 | self._update.update(other) | |
123 |
|
104 | |||
124 | # set methods |
|
105 | # set methods | |
125 | def add(self, obj): |
|
106 | def add(self, obj): | |
126 | self.update({obj}) |
|
107 | self.update({obj}) | |
127 |
|
108 | |||
128 | def get_value(self, initial): |
|
109 | def get_value(self, initial): | |
129 | """construct the value from the initial one |
|
110 | """construct the value from the initial one | |
130 |
|
111 | |||
131 | after applying any insert / extend / update changes |
|
112 | after applying any insert / extend / update changes | |
132 | """ |
|
113 | """ | |
133 | if self._value is not None: |
|
114 | if self._value is not None: | |
134 | return self._value |
|
115 | return self._value | |
135 | value = copy.deepcopy(initial) |
|
116 | value = copy.deepcopy(initial) | |
136 | if isinstance(value, list): |
|
117 | if isinstance(value, list): | |
137 | for idx, obj in self._inserts: |
|
118 | for idx, obj in self._inserts: | |
138 | value.insert(idx, obj) |
|
119 | value.insert(idx, obj) | |
139 | value[:0] = self._prepend |
|
120 | value[:0] = self._prepend | |
140 | value.extend(self._extend) |
|
121 | value.extend(self._extend) | |
141 |
|
122 | |||
142 | elif isinstance(value, dict): |
|
123 | elif isinstance(value, dict): | |
143 | if self._update: |
|
124 | if self._update: | |
144 | value.update(self._update) |
|
125 | value.update(self._update) | |
145 | elif isinstance(value, set): |
|
126 | elif isinstance(value, set): | |
146 | if self._update: |
|
127 | if self._update: | |
147 | value.update(self._update) |
|
128 | value.update(self._update) | |
148 | self._value = value |
|
129 | self._value = value | |
149 | return value |
|
130 | return value | |
150 |
|
131 | |||
151 | def to_dict(self): |
|
132 | def to_dict(self): | |
152 | """return JSONable dict form of my data |
|
133 | """return JSONable dict form of my data | |
153 |
|
134 | |||
154 | Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples. |
|
135 | Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples. | |
155 | """ |
|
136 | """ | |
156 | d = {} |
|
137 | d = {} | |
157 | if self._update: |
|
138 | if self._update: | |
158 | d['update'] = self._update |
|
139 | d['update'] = self._update | |
159 | if self._extend: |
|
140 | if self._extend: | |
160 | d['extend'] = self._extend |
|
141 | d['extend'] = self._extend | |
161 | if self._prepend: |
|
142 | if self._prepend: | |
162 | d['prepend'] = self._prepend |
|
143 | d['prepend'] = self._prepend | |
163 | elif self._inserts: |
|
144 | elif self._inserts: | |
164 | d['inserts'] = self._inserts |
|
145 | d['inserts'] = self._inserts | |
165 | return d |
|
146 | return d | |
166 |
|
147 | |||
167 |
|
148 | |||
168 | def _is_section_key(key): |
|
149 | def _is_section_key(key): | |
169 | """Is a Config key a section name (does it start with a capital)?""" |
|
150 | """Is a Config key a section name (does it start with a capital)?""" | |
170 | if key and key[0].upper()==key[0] and not key.startswith('_'): |
|
151 | if key and key[0].upper()==key[0] and not key.startswith('_'): | |
171 | return True |
|
152 | return True | |
172 | else: |
|
153 | else: | |
173 | return False |
|
154 | return False | |
174 |
|
155 | |||
175 |
|
156 | |||
176 | class Config(dict): |
|
157 | class Config(dict): | |
177 | """An attribute based dict that can do smart merges.""" |
|
158 | """An attribute based dict that can do smart merges.""" | |
178 |
|
159 | |||
179 | def __init__(self, *args, **kwds): |
|
160 | def __init__(self, *args, **kwds): | |
180 | dict.__init__(self, *args, **kwds) |
|
161 | dict.__init__(self, *args, **kwds) | |
181 | self._ensure_subconfig() |
|
162 | self._ensure_subconfig() | |
182 |
|
163 | |||
183 | def _ensure_subconfig(self): |
|
164 | def _ensure_subconfig(self): | |
184 | """ensure that sub-dicts that should be Config objects are |
|
165 | """ensure that sub-dicts that should be Config objects are | |
185 |
|
166 | |||
186 | casts dicts that are under section keys to Config objects, |
|
167 | casts dicts that are under section keys to Config objects, | |
187 | which is necessary for constructing Config objects from dict literals. |
|
168 | which is necessary for constructing Config objects from dict literals. | |
188 | """ |
|
169 | """ | |
189 | for key in self: |
|
170 | for key in self: | |
190 | obj = self[key] |
|
171 | obj = self[key] | |
191 | if _is_section_key(key) \ |
|
172 | if _is_section_key(key) \ | |
192 | and isinstance(obj, dict) \ |
|
173 | and isinstance(obj, dict) \ | |
193 | and not isinstance(obj, Config): |
|
174 | and not isinstance(obj, Config): | |
194 | setattr(self, key, Config(obj)) |
|
175 | setattr(self, key, Config(obj)) | |
195 |
|
176 | |||
196 | def _merge(self, other): |
|
177 | def _merge(self, other): | |
197 | """deprecated alias, use Config.merge()""" |
|
178 | """deprecated alias, use Config.merge()""" | |
198 | self.merge(other) |
|
179 | self.merge(other) | |
199 |
|
180 | |||
200 | def merge(self, other): |
|
181 | def merge(self, other): | |
201 | """merge another config object into this one""" |
|
182 | """merge another config object into this one""" | |
202 | to_update = {} |
|
183 | to_update = {} | |
203 | for k, v in iteritems(other): |
|
184 | for k, v in iteritems(other): | |
204 | if k not in self: |
|
185 | if k not in self: | |
205 | to_update[k] = copy.deepcopy(v) |
|
186 | to_update[k] = copy.deepcopy(v) | |
206 | else: # I have this key |
|
187 | else: # I have this key | |
207 | if isinstance(v, Config) and isinstance(self[k], Config): |
|
188 | if isinstance(v, Config) and isinstance(self[k], Config): | |
208 | # Recursively merge common sub Configs |
|
189 | # Recursively merge common sub Configs | |
209 | self[k].merge(v) |
|
190 | self[k].merge(v) | |
210 | else: |
|
191 | else: | |
211 | # Plain updates for non-Configs |
|
192 | # Plain updates for non-Configs | |
212 | to_update[k] = copy.deepcopy(v) |
|
193 | to_update[k] = copy.deepcopy(v) | |
213 |
|
194 | |||
214 | self.update(to_update) |
|
195 | self.update(to_update) | |
215 |
|
196 | |||
216 | def __contains__(self, key): |
|
197 | def __contains__(self, key): | |
217 | # allow nested contains of the form `"Section.key" in config` |
|
198 | # allow nested contains of the form `"Section.key" in config` | |
218 | if '.' in key: |
|
199 | if '.' in key: | |
219 | first, remainder = key.split('.', 1) |
|
200 | first, remainder = key.split('.', 1) | |
220 | if first not in self: |
|
201 | if first not in self: | |
221 | return False |
|
202 | return False | |
222 | return remainder in self[first] |
|
203 | return remainder in self[first] | |
223 |
|
204 | |||
224 | return super(Config, self).__contains__(key) |
|
205 | return super(Config, self).__contains__(key) | |
225 |
|
206 | |||
226 | # .has_key is deprecated for dictionaries. |
|
207 | # .has_key is deprecated for dictionaries. | |
227 | has_key = __contains__ |
|
208 | has_key = __contains__ | |
228 |
|
209 | |||
229 | def _has_section(self, key): |
|
210 | def _has_section(self, key): | |
230 | return _is_section_key(key) and key in self |
|
211 | return _is_section_key(key) and key in self | |
231 |
|
212 | |||
232 | def copy(self): |
|
213 | def copy(self): | |
233 | return type(self)(dict.copy(self)) |
|
214 | return type(self)(dict.copy(self)) | |
234 |
|
215 | |||
235 | def __copy__(self): |
|
216 | def __copy__(self): | |
236 | return self.copy() |
|
217 | return self.copy() | |
237 |
|
218 | |||
238 | def __deepcopy__(self, memo): |
|
219 | def __deepcopy__(self, memo): | |
239 | import copy |
|
220 | import copy | |
240 | return type(self)(copy.deepcopy(list(self.items()))) |
|
221 | return type(self)(copy.deepcopy(list(self.items()))) | |
241 |
|
222 | |||
242 | def __getitem__(self, key): |
|
223 | def __getitem__(self, key): | |
243 | try: |
|
224 | try: | |
244 | return dict.__getitem__(self, key) |
|
225 | return dict.__getitem__(self, key) | |
245 | except KeyError: |
|
226 | except KeyError: | |
246 | if _is_section_key(key): |
|
227 | if _is_section_key(key): | |
247 | c = Config() |
|
228 | c = Config() | |
248 | dict.__setitem__(self, key, c) |
|
229 | dict.__setitem__(self, key, c) | |
249 | return c |
|
230 | return c | |
250 | elif not key.startswith('_'): |
|
231 | elif not key.startswith('_'): | |
251 | # undefined, create lazy value, used for container methods |
|
232 | # undefined, create lazy value, used for container methods | |
252 | v = LazyConfigValue() |
|
233 | v = LazyConfigValue() | |
253 | dict.__setitem__(self, key, v) |
|
234 | dict.__setitem__(self, key, v) | |
254 | return v |
|
235 | return v | |
255 | else: |
|
236 | else: | |
256 | raise KeyError |
|
237 | raise KeyError | |
257 |
|
238 | |||
258 | def __setitem__(self, key, value): |
|
239 | def __setitem__(self, key, value): | |
259 | if _is_section_key(key): |
|
240 | if _is_section_key(key): | |
260 | if not isinstance(value, Config): |
|
241 | if not isinstance(value, Config): | |
261 | raise ValueError('values whose keys begin with an uppercase ' |
|
242 | raise ValueError('values whose keys begin with an uppercase ' | |
262 | 'char must be Config instances: %r, %r' % (key, value)) |
|
243 | 'char must be Config instances: %r, %r' % (key, value)) | |
263 | dict.__setitem__(self, key, value) |
|
244 | dict.__setitem__(self, key, value) | |
264 |
|
245 | |||
265 | def __getattr__(self, key): |
|
246 | def __getattr__(self, key): | |
266 | if key.startswith('__'): |
|
247 | if key.startswith('__'): | |
267 | return dict.__getattr__(self, key) |
|
248 | return dict.__getattr__(self, key) | |
268 | try: |
|
249 | try: | |
269 | return self.__getitem__(key) |
|
250 | return self.__getitem__(key) | |
270 | except KeyError as e: |
|
251 | except KeyError as e: | |
271 | raise AttributeError(e) |
|
252 | raise AttributeError(e) | |
272 |
|
253 | |||
273 | def __setattr__(self, key, value): |
|
254 | def __setattr__(self, key, value): | |
274 | if key.startswith('__'): |
|
255 | if key.startswith('__'): | |
275 | return dict.__setattr__(self, key, value) |
|
256 | return dict.__setattr__(self, key, value) | |
276 | try: |
|
257 | try: | |
277 | self.__setitem__(key, value) |
|
258 | self.__setitem__(key, value) | |
278 | except KeyError as e: |
|
259 | except KeyError as e: | |
279 | raise AttributeError(e) |
|
260 | raise AttributeError(e) | |
280 |
|
261 | |||
281 | def __delattr__(self, key): |
|
262 | def __delattr__(self, key): | |
282 | if key.startswith('__'): |
|
263 | if key.startswith('__'): | |
283 | return dict.__delattr__(self, key) |
|
264 | return dict.__delattr__(self, key) | |
284 | try: |
|
265 | try: | |
285 | dict.__delitem__(self, key) |
|
266 | dict.__delitem__(self, key) | |
286 | except KeyError as e: |
|
267 | except KeyError as e: | |
287 | raise AttributeError(e) |
|
268 | raise AttributeError(e) | |
288 |
|
269 | |||
289 |
|
270 | |||
290 | #----------------------------------------------------------------------------- |
|
271 | #----------------------------------------------------------------------------- | |
291 | # Config loading classes |
|
272 | # Config loading classes | |
292 | #----------------------------------------------------------------------------- |
|
273 | #----------------------------------------------------------------------------- | |
293 |
|
274 | |||
294 |
|
275 | |||
295 | class ConfigLoader(object): |
|
276 | class ConfigLoader(object): | |
296 | """A object for loading configurations from just about anywhere. |
|
277 | """A object for loading configurations from just about anywhere. | |
297 |
|
278 | |||
298 | The resulting configuration is packaged as a :class:`Config`. |
|
279 | The resulting configuration is packaged as a :class:`Config`. | |
299 |
|
280 | |||
300 | Notes |
|
281 | Notes | |
301 | ----- |
|
282 | ----- | |
302 | A :class:`ConfigLoader` does one thing: load a config from a source |
|
283 | A :class:`ConfigLoader` does one thing: load a config from a source | |
303 | (file, command line arguments) and returns the data as a :class:`Config` object. |
|
284 | (file, command line arguments) and returns the data as a :class:`Config` object. | |
304 | There are lots of things that :class:`ConfigLoader` does not do. It does |
|
285 | There are lots of things that :class:`ConfigLoader` does not do. It does | |
305 | not implement complex logic for finding config files. It does not handle |
|
286 | not implement complex logic for finding config files. It does not handle | |
306 | default values or merge multiple configs. These things need to be |
|
287 | default values or merge multiple configs. These things need to be | |
307 | handled elsewhere. |
|
288 | handled elsewhere. | |
308 | """ |
|
289 | """ | |
309 |
|
290 | |||
310 | def _log_default(self): |
|
291 | def _log_default(self): | |
311 |
from IPython. |
|
292 | from IPython.utils.log import get_logger | |
312 | if Application.initialized(): |
|
293 | return get_logger() | |
313 | return Application.instance().log |
|
|||
314 | else: |
|
|||
315 | return logging.getLogger() |
|
|||
316 |
|
294 | |||
317 | def __init__(self, log=None): |
|
295 | def __init__(self, log=None): | |
318 | """A base class for config loaders. |
|
296 | """A base class for config loaders. | |
319 |
|
297 | |||
320 | log : instance of :class:`logging.Logger` to use. |
|
298 | log : instance of :class:`logging.Logger` to use. | |
321 | By default loger of :meth:`IPython.config.application.Application.instance()` |
|
299 | By default loger of :meth:`IPython.config.application.Application.instance()` | |
322 | will be used |
|
300 | will be used | |
323 |
|
301 | |||
324 | Examples |
|
302 | Examples | |
325 | -------- |
|
303 | -------- | |
326 |
|
304 | |||
327 | >>> cl = ConfigLoader() |
|
305 | >>> cl = ConfigLoader() | |
328 | >>> config = cl.load_config() |
|
306 | >>> config = cl.load_config() | |
329 | >>> config |
|
307 | >>> config | |
330 | {} |
|
308 | {} | |
331 | """ |
|
309 | """ | |
332 | self.clear() |
|
310 | self.clear() | |
333 | if log is None: |
|
311 | if log is None: | |
334 | self.log = self._log_default() |
|
312 | self.log = self._log_default() | |
335 | self.log.debug('Using default logger') |
|
313 | self.log.debug('Using default logger') | |
336 | else: |
|
314 | else: | |
337 | self.log = log |
|
315 | self.log = log | |
338 |
|
316 | |||
339 | def clear(self): |
|
317 | def clear(self): | |
340 | self.config = Config() |
|
318 | self.config = Config() | |
341 |
|
319 | |||
342 | def load_config(self): |
|
320 | def load_config(self): | |
343 | """Load a config from somewhere, return a :class:`Config` instance. |
|
321 | """Load a config from somewhere, return a :class:`Config` instance. | |
344 |
|
322 | |||
345 | Usually, this will cause self.config to be set and then returned. |
|
323 | Usually, this will cause self.config to be set and then returned. | |
346 | However, in most cases, :meth:`ConfigLoader.clear` should be called |
|
324 | However, in most cases, :meth:`ConfigLoader.clear` should be called | |
347 | to erase any previous state. |
|
325 | to erase any previous state. | |
348 | """ |
|
326 | """ | |
349 | self.clear() |
|
327 | self.clear() | |
350 | return self.config |
|
328 | return self.config | |
351 |
|
329 | |||
352 |
|
330 | |||
353 | class FileConfigLoader(ConfigLoader): |
|
331 | class FileConfigLoader(ConfigLoader): | |
354 | """A base class for file based configurations. |
|
332 | """A base class for file based configurations. | |
355 |
|
333 | |||
356 | As we add more file based config loaders, the common logic should go |
|
334 | As we add more file based config loaders, the common logic should go | |
357 | here. |
|
335 | here. | |
358 | """ |
|
336 | """ | |
359 |
|
337 | |||
360 | def __init__(self, filename, path=None, **kw): |
|
338 | def __init__(self, filename, path=None, **kw): | |
361 | """Build a config loader for a filename and path. |
|
339 | """Build a config loader for a filename and path. | |
362 |
|
340 | |||
363 | Parameters |
|
341 | Parameters | |
364 | ---------- |
|
342 | ---------- | |
365 | filename : str |
|
343 | filename : str | |
366 | The file name of the config file. |
|
344 | The file name of the config file. | |
367 | path : str, list, tuple |
|
345 | path : str, list, tuple | |
368 | The path to search for the config file on, or a sequence of |
|
346 | The path to search for the config file on, or a sequence of | |
369 | paths to try in order. |
|
347 | paths to try in order. | |
370 | """ |
|
348 | """ | |
371 | super(FileConfigLoader, self).__init__(**kw) |
|
349 | super(FileConfigLoader, self).__init__(**kw) | |
372 | self.filename = filename |
|
350 | self.filename = filename | |
373 | self.path = path |
|
351 | self.path = path | |
374 | self.full_filename = '' |
|
352 | self.full_filename = '' | |
375 |
|
353 | |||
376 | def _find_file(self): |
|
354 | def _find_file(self): | |
377 | """Try to find the file by searching the paths.""" |
|
355 | """Try to find the file by searching the paths.""" | |
378 | self.full_filename = filefind(self.filename, self.path) |
|
356 | self.full_filename = filefind(self.filename, self.path) | |
379 |
|
357 | |||
380 | class JSONFileConfigLoader(FileConfigLoader): |
|
358 | class JSONFileConfigLoader(FileConfigLoader): | |
381 | """A Json file loader for config""" |
|
359 | """A Json file loader for config""" | |
382 |
|
360 | |||
383 | def load_config(self): |
|
361 | def load_config(self): | |
384 | """Load the config from a file and return it as a Config object.""" |
|
362 | """Load the config from a file and return it as a Config object.""" | |
385 | self.clear() |
|
363 | self.clear() | |
386 | try: |
|
364 | try: | |
387 | self._find_file() |
|
365 | self._find_file() | |
388 | except IOError as e: |
|
366 | except IOError as e: | |
389 | raise ConfigFileNotFound(str(e)) |
|
367 | raise ConfigFileNotFound(str(e)) | |
390 | dct = self._read_file_as_dict() |
|
368 | dct = self._read_file_as_dict() | |
391 | self.config = self._convert_to_config(dct) |
|
369 | self.config = self._convert_to_config(dct) | |
392 | return self.config |
|
370 | return self.config | |
393 |
|
371 | |||
394 | def _read_file_as_dict(self): |
|
372 | def _read_file_as_dict(self): | |
395 | with open(self.full_filename) as f: |
|
373 | with open(self.full_filename) as f: | |
396 | return json.load(f) |
|
374 | return json.load(f) | |
397 |
|
375 | |||
398 | def _convert_to_config(self, dictionary): |
|
376 | def _convert_to_config(self, dictionary): | |
399 | if 'version' in dictionary: |
|
377 | if 'version' in dictionary: | |
400 | version = dictionary.pop('version') |
|
378 | version = dictionary.pop('version') | |
401 | else: |
|
379 | else: | |
402 | version = 1 |
|
380 | version = 1 | |
403 | self.log.warn("Unrecognized JSON config file version, assuming version {}".format(version)) |
|
381 | self.log.warn("Unrecognized JSON config file version, assuming version {}".format(version)) | |
404 |
|
382 | |||
405 | if version == 1: |
|
383 | if version == 1: | |
406 | return Config(dictionary) |
|
384 | return Config(dictionary) | |
407 | else: |
|
385 | else: | |
408 | raise ValueError('Unknown version of JSON config file: {version}'.format(version=version)) |
|
386 | raise ValueError('Unknown version of JSON config file: {version}'.format(version=version)) | |
409 |
|
387 | |||
410 |
|
388 | |||
411 | class PyFileConfigLoader(FileConfigLoader): |
|
389 | class PyFileConfigLoader(FileConfigLoader): | |
412 | """A config loader for pure python files. |
|
390 | """A config loader for pure python files. | |
413 |
|
391 | |||
414 | This is responsible for locating a Python config file by filename and |
|
392 | This is responsible for locating a Python config file by filename and | |
415 | path, then executing it to construct a Config object. |
|
393 | path, then executing it to construct a Config object. | |
416 | """ |
|
394 | """ | |
417 |
|
395 | |||
418 | def load_config(self): |
|
396 | def load_config(self): | |
419 | """Load the config from a file and return it as a Config object.""" |
|
397 | """Load the config from a file and return it as a Config object.""" | |
420 | self.clear() |
|
398 | self.clear() | |
421 | try: |
|
399 | try: | |
422 | self._find_file() |
|
400 | self._find_file() | |
423 | except IOError as e: |
|
401 | except IOError as e: | |
424 | raise ConfigFileNotFound(str(e)) |
|
402 | raise ConfigFileNotFound(str(e)) | |
425 | self._read_file_as_dict() |
|
403 | self._read_file_as_dict() | |
426 | return self.config |
|
404 | return self.config | |
427 |
|
405 | |||
428 |
|
406 | |||
429 | def _read_file_as_dict(self): |
|
407 | def _read_file_as_dict(self): | |
430 | """Load the config file into self.config, with recursive loading.""" |
|
408 | """Load the config file into self.config, with recursive loading.""" | |
431 | # This closure is made available in the namespace that is used |
|
409 | # This closure is made available in the namespace that is used | |
432 | # to exec the config file. It allows users to call |
|
410 | # to exec the config file. It allows users to call | |
433 | # load_subconfig('myconfig.py') to load config files recursively. |
|
411 | # load_subconfig('myconfig.py') to load config files recursively. | |
434 | # It needs to be a closure because it has references to self.path |
|
412 | # It needs to be a closure because it has references to self.path | |
435 | # and self.config. The sub-config is loaded with the same path |
|
413 | # and self.config. The sub-config is loaded with the same path | |
436 | # as the parent, but it uses an empty config which is then merged |
|
414 | # as the parent, but it uses an empty config which is then merged | |
437 | # with the parents. |
|
415 | # with the parents. | |
438 |
|
416 | |||
439 | # If a profile is specified, the config file will be loaded |
|
417 | # If a profile is specified, the config file will be loaded | |
440 | # from that profile |
|
418 | # from that profile | |
441 |
|
419 | |||
442 | def load_subconfig(fname, profile=None): |
|
420 | def load_subconfig(fname, profile=None): | |
443 | # import here to prevent circular imports |
|
421 | # import here to prevent circular imports | |
444 | from IPython.core.profiledir import ProfileDir, ProfileDirError |
|
422 | from IPython.core.profiledir import ProfileDir, ProfileDirError | |
445 | if profile is not None: |
|
423 | if profile is not None: | |
446 | try: |
|
424 | try: | |
447 | profile_dir = ProfileDir.find_profile_dir_by_name( |
|
425 | profile_dir = ProfileDir.find_profile_dir_by_name( | |
448 | get_ipython_dir(), |
|
426 | get_ipython_dir(), | |
449 | profile, |
|
427 | profile, | |
450 | ) |
|
428 | ) | |
451 | except ProfileDirError: |
|
429 | except ProfileDirError: | |
452 | return |
|
430 | return | |
453 | path = profile_dir.location |
|
431 | path = profile_dir.location | |
454 | else: |
|
432 | else: | |
455 | path = self.path |
|
433 | path = self.path | |
456 | loader = PyFileConfigLoader(fname, path) |
|
434 | loader = PyFileConfigLoader(fname, path) | |
457 | try: |
|
435 | try: | |
458 | sub_config = loader.load_config() |
|
436 | sub_config = loader.load_config() | |
459 | except ConfigFileNotFound: |
|
437 | except ConfigFileNotFound: | |
460 | # Pass silently if the sub config is not there. This happens |
|
438 | # Pass silently if the sub config is not there. This happens | |
461 | # when a user s using a profile, but not the default config. |
|
439 | # when a user s using a profile, but not the default config. | |
462 | pass |
|
440 | pass | |
463 | else: |
|
441 | else: | |
464 | self.config.merge(sub_config) |
|
442 | self.config.merge(sub_config) | |
465 |
|
443 | |||
466 | # Again, this needs to be a closure and should be used in config |
|
444 | # Again, this needs to be a closure and should be used in config | |
467 | # files to get the config being loaded. |
|
445 | # files to get the config being loaded. | |
468 | def get_config(): |
|
446 | def get_config(): | |
469 | return self.config |
|
447 | return self.config | |
470 |
|
448 | |||
471 | namespace = dict( |
|
449 | namespace = dict( | |
472 | load_subconfig=load_subconfig, |
|
450 | load_subconfig=load_subconfig, | |
473 | get_config=get_config, |
|
451 | get_config=get_config, | |
474 | __file__=self.full_filename, |
|
452 | __file__=self.full_filename, | |
475 | ) |
|
453 | ) | |
476 | fs_encoding = sys.getfilesystemencoding() or 'ascii' |
|
454 | fs_encoding = sys.getfilesystemencoding() or 'ascii' | |
477 | conf_filename = self.full_filename.encode(fs_encoding) |
|
455 | conf_filename = self.full_filename.encode(fs_encoding) | |
478 | py3compat.execfile(conf_filename, namespace) |
|
456 | py3compat.execfile(conf_filename, namespace) | |
479 |
|
457 | |||
480 |
|
458 | |||
481 | class CommandLineConfigLoader(ConfigLoader): |
|
459 | class CommandLineConfigLoader(ConfigLoader): | |
482 | """A config loader for command line arguments. |
|
460 | """A config loader for command line arguments. | |
483 |
|
461 | |||
484 | As we add more command line based loaders, the common logic should go |
|
462 | As we add more command line based loaders, the common logic should go | |
485 | here. |
|
463 | here. | |
486 | """ |
|
464 | """ | |
487 |
|
465 | |||
488 | def _exec_config_str(self, lhs, rhs): |
|
466 | def _exec_config_str(self, lhs, rhs): | |
489 | """execute self.config.<lhs> = <rhs> |
|
467 | """execute self.config.<lhs> = <rhs> | |
490 |
|
468 | |||
491 | * expands ~ with expanduser |
|
469 | * expands ~ with expanduser | |
492 | * tries to assign with raw eval, otherwise assigns with just the string, |
|
470 | * tries to assign with raw eval, otherwise assigns with just the string, | |
493 | allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not* |
|
471 | allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not* | |
494 | equivalent are `--C.a=4` and `--C.a='4'`. |
|
472 | equivalent are `--C.a=4` and `--C.a='4'`. | |
495 | """ |
|
473 | """ | |
496 | rhs = os.path.expanduser(rhs) |
|
474 | rhs = os.path.expanduser(rhs) | |
497 | try: |
|
475 | try: | |
498 | # Try to see if regular Python syntax will work. This |
|
476 | # Try to see if regular Python syntax will work. This | |
499 | # won't handle strings as the quote marks are removed |
|
477 | # won't handle strings as the quote marks are removed | |
500 | # by the system shell. |
|
478 | # by the system shell. | |
501 | value = eval(rhs) |
|
479 | value = eval(rhs) | |
502 | except (NameError, SyntaxError): |
|
480 | except (NameError, SyntaxError): | |
503 | # This case happens if the rhs is a string. |
|
481 | # This case happens if the rhs is a string. | |
504 | value = rhs |
|
482 | value = rhs | |
505 |
|
483 | |||
506 | exec(u'self.config.%s = value' % lhs) |
|
484 | exec(u'self.config.%s = value' % lhs) | |
507 |
|
485 | |||
508 | def _load_flag(self, cfg): |
|
486 | def _load_flag(self, cfg): | |
509 | """update self.config from a flag, which can be a dict or Config""" |
|
487 | """update self.config from a flag, which can be a dict or Config""" | |
510 | if isinstance(cfg, (dict, Config)): |
|
488 | if isinstance(cfg, (dict, Config)): | |
511 | # don't clobber whole config sections, update |
|
489 | # don't clobber whole config sections, update | |
512 | # each section from config: |
|
490 | # each section from config: | |
513 | for sec,c in iteritems(cfg): |
|
491 | for sec,c in iteritems(cfg): | |
514 | self.config[sec].update(c) |
|
492 | self.config[sec].update(c) | |
515 | else: |
|
493 | else: | |
516 | raise TypeError("Invalid flag: %r" % cfg) |
|
494 | raise TypeError("Invalid flag: %r" % cfg) | |
517 |
|
495 | |||
518 | # raw --identifier=value pattern |
|
496 | # raw --identifier=value pattern | |
519 | # but *also* accept '-' as wordsep, for aliases |
|
497 | # but *also* accept '-' as wordsep, for aliases | |
520 | # accepts: --foo=a |
|
498 | # accepts: --foo=a | |
521 | # --Class.trait=value |
|
499 | # --Class.trait=value | |
522 | # --alias-name=value |
|
500 | # --alias-name=value | |
523 | # rejects: -foo=value |
|
501 | # rejects: -foo=value | |
524 | # --foo |
|
502 | # --foo | |
525 | # --Class.trait |
|
503 | # --Class.trait | |
526 | kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*') |
|
504 | kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*') | |
527 |
|
505 | |||
528 | # just flags, no assignments, with two *or one* leading '-' |
|
506 | # just flags, no assignments, with two *or one* leading '-' | |
529 | # accepts: --foo |
|
507 | # accepts: --foo | |
530 | # -foo-bar-again |
|
508 | # -foo-bar-again | |
531 | # rejects: --anything=anything |
|
509 | # rejects: --anything=anything | |
532 | # --two.word |
|
510 | # --two.word | |
533 |
|
511 | |||
534 | flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$') |
|
512 | flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$') | |
535 |
|
513 | |||
536 | class KeyValueConfigLoader(CommandLineConfigLoader): |
|
514 | class KeyValueConfigLoader(CommandLineConfigLoader): | |
537 | """A config loader that loads key value pairs from the command line. |
|
515 | """A config loader that loads key value pairs from the command line. | |
538 |
|
516 | |||
539 | This allows command line options to be gives in the following form:: |
|
517 | This allows command line options to be gives in the following form:: | |
540 |
|
518 | |||
541 | ipython --profile="foo" --InteractiveShell.autocall=False |
|
519 | ipython --profile="foo" --InteractiveShell.autocall=False | |
542 | """ |
|
520 | """ | |
543 |
|
521 | |||
544 | def __init__(self, argv=None, aliases=None, flags=None, **kw): |
|
522 | def __init__(self, argv=None, aliases=None, flags=None, **kw): | |
545 | """Create a key value pair config loader. |
|
523 | """Create a key value pair config loader. | |
546 |
|
524 | |||
547 | Parameters |
|
525 | Parameters | |
548 | ---------- |
|
526 | ---------- | |
549 | argv : list |
|
527 | argv : list | |
550 | A list that has the form of sys.argv[1:] which has unicode |
|
528 | A list that has the form of sys.argv[1:] which has unicode | |
551 | elements of the form u"key=value". If this is None (default), |
|
529 | elements of the form u"key=value". If this is None (default), | |
552 | then sys.argv[1:] will be used. |
|
530 | then sys.argv[1:] will be used. | |
553 | aliases : dict |
|
531 | aliases : dict | |
554 | A dict of aliases for configurable traits. |
|
532 | A dict of aliases for configurable traits. | |
555 | Keys are the short aliases, Values are the resolved trait. |
|
533 | Keys are the short aliases, Values are the resolved trait. | |
556 | Of the form: `{'alias' : 'Configurable.trait'}` |
|
534 | Of the form: `{'alias' : 'Configurable.trait'}` | |
557 | flags : dict |
|
535 | flags : dict | |
558 | A dict of flags, keyed by str name. Vaues can be Config objects, |
|
536 | A dict of flags, keyed by str name. Vaues can be Config objects, | |
559 | dicts, or "key=value" strings. If Config or dict, when the flag |
|
537 | dicts, or "key=value" strings. If Config or dict, when the flag | |
560 | is triggered, The flag is loaded as `self.config.update(m)`. |
|
538 | is triggered, The flag is loaded as `self.config.update(m)`. | |
561 |
|
539 | |||
562 | Returns |
|
540 | Returns | |
563 | ------- |
|
541 | ------- | |
564 | config : Config |
|
542 | config : Config | |
565 | The resulting Config object. |
|
543 | The resulting Config object. | |
566 |
|
544 | |||
567 | Examples |
|
545 | Examples | |
568 | -------- |
|
546 | -------- | |
569 |
|
547 | |||
570 | >>> from IPython.config.loader import KeyValueConfigLoader |
|
548 | >>> from IPython.config.loader import KeyValueConfigLoader | |
571 | >>> cl = KeyValueConfigLoader() |
|
549 | >>> cl = KeyValueConfigLoader() | |
572 | >>> d = cl.load_config(["--A.name='brian'","--B.number=0"]) |
|
550 | >>> d = cl.load_config(["--A.name='brian'","--B.number=0"]) | |
573 | >>> sorted(d.items()) |
|
551 | >>> sorted(d.items()) | |
574 | [('A', {'name': 'brian'}), ('B', {'number': 0})] |
|
552 | [('A', {'name': 'brian'}), ('B', {'number': 0})] | |
575 | """ |
|
553 | """ | |
576 | super(KeyValueConfigLoader, self).__init__(**kw) |
|
554 | super(KeyValueConfigLoader, self).__init__(**kw) | |
577 | if argv is None: |
|
555 | if argv is None: | |
578 | argv = sys.argv[1:] |
|
556 | argv = sys.argv[1:] | |
579 | self.argv = argv |
|
557 | self.argv = argv | |
580 | self.aliases = aliases or {} |
|
558 | self.aliases = aliases or {} | |
581 | self.flags = flags or {} |
|
559 | self.flags = flags or {} | |
582 |
|
560 | |||
583 |
|
561 | |||
584 | def clear(self): |
|
562 | def clear(self): | |
585 | super(KeyValueConfigLoader, self).clear() |
|
563 | super(KeyValueConfigLoader, self).clear() | |
586 | self.extra_args = [] |
|
564 | self.extra_args = [] | |
587 |
|
565 | |||
588 |
|
566 | |||
589 | def _decode_argv(self, argv, enc=None): |
|
567 | def _decode_argv(self, argv, enc=None): | |
590 | """decode argv if bytes, using stin.encoding, falling back on default enc""" |
|
568 | """decode argv if bytes, using stin.encoding, falling back on default enc""" | |
591 | uargv = [] |
|
569 | uargv = [] | |
592 | if enc is None: |
|
570 | if enc is None: | |
593 | enc = DEFAULT_ENCODING |
|
571 | enc = DEFAULT_ENCODING | |
594 | for arg in argv: |
|
572 | for arg in argv: | |
595 | if not isinstance(arg, unicode_type): |
|
573 | if not isinstance(arg, unicode_type): | |
596 | # only decode if not already decoded |
|
574 | # only decode if not already decoded | |
597 | arg = arg.decode(enc) |
|
575 | arg = arg.decode(enc) | |
598 | uargv.append(arg) |
|
576 | uargv.append(arg) | |
599 | return uargv |
|
577 | return uargv | |
600 |
|
578 | |||
601 |
|
579 | |||
602 | def load_config(self, argv=None, aliases=None, flags=None): |
|
580 | def load_config(self, argv=None, aliases=None, flags=None): | |
603 | """Parse the configuration and generate the Config object. |
|
581 | """Parse the configuration and generate the Config object. | |
604 |
|
582 | |||
605 | After loading, any arguments that are not key-value or |
|
583 | After loading, any arguments that are not key-value or | |
606 | flags will be stored in self.extra_args - a list of |
|
584 | flags will be stored in self.extra_args - a list of | |
607 | unparsed command-line arguments. This is used for |
|
585 | unparsed command-line arguments. This is used for | |
608 | arguments such as input files or subcommands. |
|
586 | arguments such as input files or subcommands. | |
609 |
|
587 | |||
610 | Parameters |
|
588 | Parameters | |
611 | ---------- |
|
589 | ---------- | |
612 | argv : list, optional |
|
590 | argv : list, optional | |
613 | A list that has the form of sys.argv[1:] which has unicode |
|
591 | A list that has the form of sys.argv[1:] which has unicode | |
614 | elements of the form u"key=value". If this is None (default), |
|
592 | elements of the form u"key=value". If this is None (default), | |
615 | then self.argv will be used. |
|
593 | then self.argv will be used. | |
616 | aliases : dict |
|
594 | aliases : dict | |
617 | A dict of aliases for configurable traits. |
|
595 | A dict of aliases for configurable traits. | |
618 | Keys are the short aliases, Values are the resolved trait. |
|
596 | Keys are the short aliases, Values are the resolved trait. | |
619 | Of the form: `{'alias' : 'Configurable.trait'}` |
|
597 | Of the form: `{'alias' : 'Configurable.trait'}` | |
620 | flags : dict |
|
598 | flags : dict | |
621 | A dict of flags, keyed by str name. Values can be Config objects |
|
599 | A dict of flags, keyed by str name. Values can be Config objects | |
622 | or dicts. When the flag is triggered, The config is loaded as |
|
600 | or dicts. When the flag is triggered, The config is loaded as | |
623 | `self.config.update(cfg)`. |
|
601 | `self.config.update(cfg)`. | |
624 | """ |
|
602 | """ | |
625 | self.clear() |
|
603 | self.clear() | |
626 | if argv is None: |
|
604 | if argv is None: | |
627 | argv = self.argv |
|
605 | argv = self.argv | |
628 | if aliases is None: |
|
606 | if aliases is None: | |
629 | aliases = self.aliases |
|
607 | aliases = self.aliases | |
630 | if flags is None: |
|
608 | if flags is None: | |
631 | flags = self.flags |
|
609 | flags = self.flags | |
632 |
|
610 | |||
633 | # ensure argv is a list of unicode strings: |
|
611 | # ensure argv is a list of unicode strings: | |
634 | uargv = self._decode_argv(argv) |
|
612 | uargv = self._decode_argv(argv) | |
635 | for idx,raw in enumerate(uargv): |
|
613 | for idx,raw in enumerate(uargv): | |
636 | # strip leading '-' |
|
614 | # strip leading '-' | |
637 | item = raw.lstrip('-') |
|
615 | item = raw.lstrip('-') | |
638 |
|
616 | |||
639 | if raw == '--': |
|
617 | if raw == '--': | |
640 | # don't parse arguments after '--' |
|
618 | # don't parse arguments after '--' | |
641 | # this is useful for relaying arguments to scripts, e.g. |
|
619 | # this is useful for relaying arguments to scripts, e.g. | |
642 | # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py |
|
620 | # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py | |
643 | self.extra_args.extend(uargv[idx+1:]) |
|
621 | self.extra_args.extend(uargv[idx+1:]) | |
644 | break |
|
622 | break | |
645 |
|
623 | |||
646 | if kv_pattern.match(raw): |
|
624 | if kv_pattern.match(raw): | |
647 | lhs,rhs = item.split('=',1) |
|
625 | lhs,rhs = item.split('=',1) | |
648 | # Substitute longnames for aliases. |
|
626 | # Substitute longnames for aliases. | |
649 | if lhs in aliases: |
|
627 | if lhs in aliases: | |
650 | lhs = aliases[lhs] |
|
628 | lhs = aliases[lhs] | |
651 | if '.' not in lhs: |
|
629 | if '.' not in lhs: | |
652 | # probably a mistyped alias, but not technically illegal |
|
630 | # probably a mistyped alias, but not technically illegal | |
653 | self.log.warn("Unrecognized alias: '%s', it will probably have no effect.", raw) |
|
631 | self.log.warn("Unrecognized alias: '%s', it will probably have no effect.", raw) | |
654 | try: |
|
632 | try: | |
655 | self._exec_config_str(lhs, rhs) |
|
633 | self._exec_config_str(lhs, rhs) | |
656 | except Exception: |
|
634 | except Exception: | |
657 | raise ArgumentError("Invalid argument: '%s'" % raw) |
|
635 | raise ArgumentError("Invalid argument: '%s'" % raw) | |
658 |
|
636 | |||
659 | elif flag_pattern.match(raw): |
|
637 | elif flag_pattern.match(raw): | |
660 | if item in flags: |
|
638 | if item in flags: | |
661 | cfg,help = flags[item] |
|
639 | cfg,help = flags[item] | |
662 | self._load_flag(cfg) |
|
640 | self._load_flag(cfg) | |
663 | else: |
|
641 | else: | |
664 | raise ArgumentError("Unrecognized flag: '%s'"%raw) |
|
642 | raise ArgumentError("Unrecognized flag: '%s'"%raw) | |
665 | elif raw.startswith('-'): |
|
643 | elif raw.startswith('-'): | |
666 | kv = '--'+item |
|
644 | kv = '--'+item | |
667 | if kv_pattern.match(kv): |
|
645 | if kv_pattern.match(kv): | |
668 | raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv)) |
|
646 | raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv)) | |
669 | else: |
|
647 | else: | |
670 | raise ArgumentError("Invalid argument: '%s'"%raw) |
|
648 | raise ArgumentError("Invalid argument: '%s'"%raw) | |
671 | else: |
|
649 | else: | |
672 | # keep all args that aren't valid in a list, |
|
650 | # keep all args that aren't valid in a list, | |
673 | # in case our parent knows what to do with them. |
|
651 | # in case our parent knows what to do with them. | |
674 | self.extra_args.append(item) |
|
652 | self.extra_args.append(item) | |
675 | return self.config |
|
653 | return self.config | |
676 |
|
654 | |||
677 | class ArgParseConfigLoader(CommandLineConfigLoader): |
|
655 | class ArgParseConfigLoader(CommandLineConfigLoader): | |
678 | """A loader that uses the argparse module to load from the command line.""" |
|
656 | """A loader that uses the argparse module to load from the command line.""" | |
679 |
|
657 | |||
680 | def __init__(self, argv=None, aliases=None, flags=None, log=None, *parser_args, **parser_kw): |
|
658 | def __init__(self, argv=None, aliases=None, flags=None, log=None, *parser_args, **parser_kw): | |
681 | """Create a config loader for use with argparse. |
|
659 | """Create a config loader for use with argparse. | |
682 |
|
660 | |||
683 | Parameters |
|
661 | Parameters | |
684 | ---------- |
|
662 | ---------- | |
685 |
|
663 | |||
686 | argv : optional, list |
|
664 | argv : optional, list | |
687 | If given, used to read command-line arguments from, otherwise |
|
665 | If given, used to read command-line arguments from, otherwise | |
688 | sys.argv[1:] is used. |
|
666 | sys.argv[1:] is used. | |
689 |
|
667 | |||
690 | parser_args : tuple |
|
668 | parser_args : tuple | |
691 | A tuple of positional arguments that will be passed to the |
|
669 | A tuple of positional arguments that will be passed to the | |
692 | constructor of :class:`argparse.ArgumentParser`. |
|
670 | constructor of :class:`argparse.ArgumentParser`. | |
693 |
|
671 | |||
694 | parser_kw : dict |
|
672 | parser_kw : dict | |
695 | A tuple of keyword arguments that will be passed to the |
|
673 | A tuple of keyword arguments that will be passed to the | |
696 | constructor of :class:`argparse.ArgumentParser`. |
|
674 | constructor of :class:`argparse.ArgumentParser`. | |
697 |
|
675 | |||
698 | Returns |
|
676 | Returns | |
699 | ------- |
|
677 | ------- | |
700 | config : Config |
|
678 | config : Config | |
701 | The resulting Config object. |
|
679 | The resulting Config object. | |
702 | """ |
|
680 | """ | |
703 | super(CommandLineConfigLoader, self).__init__(log=log) |
|
681 | super(CommandLineConfigLoader, self).__init__(log=log) | |
704 | self.clear() |
|
682 | self.clear() | |
705 | if argv is None: |
|
683 | if argv is None: | |
706 | argv = sys.argv[1:] |
|
684 | argv = sys.argv[1:] | |
707 | self.argv = argv |
|
685 | self.argv = argv | |
708 | self.aliases = aliases or {} |
|
686 | self.aliases = aliases or {} | |
709 | self.flags = flags or {} |
|
687 | self.flags = flags or {} | |
710 |
|
688 | |||
711 | self.parser_args = parser_args |
|
689 | self.parser_args = parser_args | |
712 | self.version = parser_kw.pop("version", None) |
|
690 | self.version = parser_kw.pop("version", None) | |
713 | kwargs = dict(argument_default=argparse.SUPPRESS) |
|
691 | kwargs = dict(argument_default=argparse.SUPPRESS) | |
714 | kwargs.update(parser_kw) |
|
692 | kwargs.update(parser_kw) | |
715 | self.parser_kw = kwargs |
|
693 | self.parser_kw = kwargs | |
716 |
|
694 | |||
717 | def load_config(self, argv=None, aliases=None, flags=None): |
|
695 | def load_config(self, argv=None, aliases=None, flags=None): | |
718 | """Parse command line arguments and return as a Config object. |
|
696 | """Parse command line arguments and return as a Config object. | |
719 |
|
697 | |||
720 | Parameters |
|
698 | Parameters | |
721 | ---------- |
|
699 | ---------- | |
722 |
|
700 | |||
723 | args : optional, list |
|
701 | args : optional, list | |
724 | If given, a list with the structure of sys.argv[1:] to parse |
|
702 | If given, a list with the structure of sys.argv[1:] to parse | |
725 | arguments from. If not given, the instance's self.argv attribute |
|
703 | arguments from. If not given, the instance's self.argv attribute | |
726 | (given at construction time) is used.""" |
|
704 | (given at construction time) is used.""" | |
727 | self.clear() |
|
705 | self.clear() | |
728 | if argv is None: |
|
706 | if argv is None: | |
729 | argv = self.argv |
|
707 | argv = self.argv | |
730 | if aliases is None: |
|
708 | if aliases is None: | |
731 | aliases = self.aliases |
|
709 | aliases = self.aliases | |
732 | if flags is None: |
|
710 | if flags is None: | |
733 | flags = self.flags |
|
711 | flags = self.flags | |
734 | self._create_parser(aliases, flags) |
|
712 | self._create_parser(aliases, flags) | |
735 | self._parse_args(argv) |
|
713 | self._parse_args(argv) | |
736 | self._convert_to_config() |
|
714 | self._convert_to_config() | |
737 | return self.config |
|
715 | return self.config | |
738 |
|
716 | |||
739 | def get_extra_args(self): |
|
717 | def get_extra_args(self): | |
740 | if hasattr(self, 'extra_args'): |
|
718 | if hasattr(self, 'extra_args'): | |
741 | return self.extra_args |
|
719 | return self.extra_args | |
742 | else: |
|
720 | else: | |
743 | return [] |
|
721 | return [] | |
744 |
|
722 | |||
745 | def _create_parser(self, aliases=None, flags=None): |
|
723 | def _create_parser(self, aliases=None, flags=None): | |
746 | self.parser = ArgumentParser(*self.parser_args, **self.parser_kw) |
|
724 | self.parser = ArgumentParser(*self.parser_args, **self.parser_kw) | |
747 | self._add_arguments(aliases, flags) |
|
725 | self._add_arguments(aliases, flags) | |
748 |
|
726 | |||
749 | def _add_arguments(self, aliases=None, flags=None): |
|
727 | def _add_arguments(self, aliases=None, flags=None): | |
750 | raise NotImplementedError("subclasses must implement _add_arguments") |
|
728 | raise NotImplementedError("subclasses must implement _add_arguments") | |
751 |
|
729 | |||
752 | def _parse_args(self, args): |
|
730 | def _parse_args(self, args): | |
753 | """self.parser->self.parsed_data""" |
|
731 | """self.parser->self.parsed_data""" | |
754 | # decode sys.argv to support unicode command-line options |
|
732 | # decode sys.argv to support unicode command-line options | |
755 | enc = DEFAULT_ENCODING |
|
733 | enc = DEFAULT_ENCODING | |
756 | uargs = [py3compat.cast_unicode(a, enc) for a in args] |
|
734 | uargs = [py3compat.cast_unicode(a, enc) for a in args] | |
757 | self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs) |
|
735 | self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs) | |
758 |
|
736 | |||
759 | def _convert_to_config(self): |
|
737 | def _convert_to_config(self): | |
760 | """self.parsed_data->self.config""" |
|
738 | """self.parsed_data->self.config""" | |
761 | for k, v in iteritems(vars(self.parsed_data)): |
|
739 | for k, v in iteritems(vars(self.parsed_data)): | |
762 | exec("self.config.%s = v"%k, locals(), globals()) |
|
740 | exec("self.config.%s = v"%k, locals(), globals()) | |
763 |
|
741 | |||
764 | class KVArgParseConfigLoader(ArgParseConfigLoader): |
|
742 | class KVArgParseConfigLoader(ArgParseConfigLoader): | |
765 | """A config loader that loads aliases and flags with argparse, |
|
743 | """A config loader that loads aliases and flags with argparse, | |
766 | but will use KVLoader for the rest. This allows better parsing |
|
744 | but will use KVLoader for the rest. This allows better parsing | |
767 | of common args, such as `ipython -c 'print 5'`, but still gets |
|
745 | of common args, such as `ipython -c 'print 5'`, but still gets | |
768 | arbitrary config with `ipython --InteractiveShell.use_readline=False`""" |
|
746 | arbitrary config with `ipython --InteractiveShell.use_readline=False`""" | |
769 |
|
747 | |||
770 | def _add_arguments(self, aliases=None, flags=None): |
|
748 | def _add_arguments(self, aliases=None, flags=None): | |
771 | self.alias_flags = {} |
|
749 | self.alias_flags = {} | |
772 | # print aliases, flags |
|
750 | # print aliases, flags | |
773 | if aliases is None: |
|
751 | if aliases is None: | |
774 | aliases = self.aliases |
|
752 | aliases = self.aliases | |
775 | if flags is None: |
|
753 | if flags is None: | |
776 | flags = self.flags |
|
754 | flags = self.flags | |
777 | paa = self.parser.add_argument |
|
755 | paa = self.parser.add_argument | |
778 | for key,value in iteritems(aliases): |
|
756 | for key,value in iteritems(aliases): | |
779 | if key in flags: |
|
757 | if key in flags: | |
780 | # flags |
|
758 | # flags | |
781 | nargs = '?' |
|
759 | nargs = '?' | |
782 | else: |
|
760 | else: | |
783 | nargs = None |
|
761 | nargs = None | |
784 | if len(key) is 1: |
|
762 | if len(key) is 1: | |
785 | paa('-'+key, '--'+key, type=unicode_type, dest=value, nargs=nargs) |
|
763 | paa('-'+key, '--'+key, type=unicode_type, dest=value, nargs=nargs) | |
786 | else: |
|
764 | else: | |
787 | paa('--'+key, type=unicode_type, dest=value, nargs=nargs) |
|
765 | paa('--'+key, type=unicode_type, dest=value, nargs=nargs) | |
788 | for key, (value, help) in iteritems(flags): |
|
766 | for key, (value, help) in iteritems(flags): | |
789 | if key in self.aliases: |
|
767 | if key in self.aliases: | |
790 | # |
|
768 | # | |
791 | self.alias_flags[self.aliases[key]] = value |
|
769 | self.alias_flags[self.aliases[key]] = value | |
792 | continue |
|
770 | continue | |
793 | if len(key) is 1: |
|
771 | if len(key) is 1: | |
794 | paa('-'+key, '--'+key, action='append_const', dest='_flags', const=value) |
|
772 | paa('-'+key, '--'+key, action='append_const', dest='_flags', const=value) | |
795 | else: |
|
773 | else: | |
796 | paa('--'+key, action='append_const', dest='_flags', const=value) |
|
774 | paa('--'+key, action='append_const', dest='_flags', const=value) | |
797 |
|
775 | |||
798 | def _convert_to_config(self): |
|
776 | def _convert_to_config(self): | |
799 | """self.parsed_data->self.config, parse unrecognized extra args via KVLoader.""" |
|
777 | """self.parsed_data->self.config, parse unrecognized extra args via KVLoader.""" | |
800 | # remove subconfigs list from namespace before transforming the Namespace |
|
778 | # remove subconfigs list from namespace before transforming the Namespace | |
801 | if '_flags' in self.parsed_data: |
|
779 | if '_flags' in self.parsed_data: | |
802 | subcs = self.parsed_data._flags |
|
780 | subcs = self.parsed_data._flags | |
803 | del self.parsed_data._flags |
|
781 | del self.parsed_data._flags | |
804 | else: |
|
782 | else: | |
805 | subcs = [] |
|
783 | subcs = [] | |
806 |
|
784 | |||
807 | for k, v in iteritems(vars(self.parsed_data)): |
|
785 | for k, v in iteritems(vars(self.parsed_data)): | |
808 | if v is None: |
|
786 | if v is None: | |
809 | # it was a flag that shares the name of an alias |
|
787 | # it was a flag that shares the name of an alias | |
810 | subcs.append(self.alias_flags[k]) |
|
788 | subcs.append(self.alias_flags[k]) | |
811 | else: |
|
789 | else: | |
812 | # eval the KV assignment |
|
790 | # eval the KV assignment | |
813 | self._exec_config_str(k, v) |
|
791 | self._exec_config_str(k, v) | |
814 |
|
792 | |||
815 | for subc in subcs: |
|
793 | for subc in subcs: | |
816 | self._load_flag(subc) |
|
794 | self._load_flag(subc) | |
817 |
|
795 | |||
818 | if self.extra_args: |
|
796 | if self.extra_args: | |
819 | sub_parser = KeyValueConfigLoader(log=self.log) |
|
797 | sub_parser = KeyValueConfigLoader(log=self.log) | |
820 | sub_parser.load_config(self.extra_args) |
|
798 | sub_parser.load_config(self.extra_args) | |
821 | self.config.merge(sub_parser.config) |
|
799 | self.config.merge(sub_parser.config) | |
822 | self.extra_args = sub_parser.extra_args |
|
800 | self.extra_args = sub_parser.extra_args | |
823 |
|
801 | |||
824 |
|
802 | |||
825 | def load_pyconfig_files(config_files, path): |
|
803 | def load_pyconfig_files(config_files, path): | |
826 | """Load multiple Python config files, merging each of them in turn. |
|
804 | """Load multiple Python config files, merging each of them in turn. | |
827 |
|
805 | |||
828 | Parameters |
|
806 | Parameters | |
829 | ========== |
|
807 | ========== | |
830 | config_files : list of str |
|
808 | config_files : list of str | |
831 | List of config files names to load and merge into the config. |
|
809 | List of config files names to load and merge into the config. | |
832 | path : unicode |
|
810 | path : unicode | |
833 | The full path to the location of the config files. |
|
811 | The full path to the location of the config files. | |
834 | """ |
|
812 | """ | |
835 | config = Config() |
|
813 | config = Config() | |
836 | for cf in config_files: |
|
814 | for cf in config_files: | |
837 | loader = PyFileConfigLoader(cf, path=path) |
|
815 | loader = PyFileConfigLoader(cf, path=path) | |
838 | try: |
|
816 | try: | |
839 | next_config = loader.load_config() |
|
817 | next_config = loader.load_config() | |
840 | except ConfigFileNotFound: |
|
818 | except ConfigFileNotFound: | |
841 | pass |
|
819 | pass | |
842 | except: |
|
820 | except: | |
843 | raise |
|
821 | raise | |
844 | else: |
|
822 | else: | |
845 | config.merge(next_config) |
|
823 | config.merge(next_config) | |
846 | return config |
|
824 | return config |
@@ -1,390 +1,376 b'' | |||||
1 | """Base Tornado handlers for the notebook. |
|
1 | """Base Tornado handlers for the notebook.""" | |
2 |
|
||||
3 | Authors: |
|
|||
4 |
|
||||
5 | * Brian Granger |
|
|||
6 | """ |
|
|||
7 |
|
||||
8 | #----------------------------------------------------------------------------- |
|
|||
9 | # Copyright (C) 2011 The IPython Development Team |
|
|||
10 | # |
|
|||
11 | # Distributed under the terms of the BSD License. The full license is in |
|
|||
12 | # the file COPYING, distributed as part of this software. |
|
|||
13 | #----------------------------------------------------------------------------- |
|
|||
14 |
|
||||
15 | #----------------------------------------------------------------------------- |
|
|||
16 | # Imports |
|
|||
17 | #----------------------------------------------------------------------------- |
|
|||
18 |
|
2 | |||
|
3 | # Copyright (c) IPython Development Team. | |||
|
4 | # Distributed under the terms of the Modified BSD License. | |||
19 |
|
5 | |||
20 | import functools |
|
6 | import functools | |
21 | import json |
|
7 | import json | |
22 | import logging |
|
8 | import logging | |
23 | import os |
|
9 | import os | |
24 | import re |
|
10 | import re | |
25 | import sys |
|
11 | import sys | |
26 | import traceback |
|
12 | import traceback | |
27 | try: |
|
13 | try: | |
28 | # py3 |
|
14 | # py3 | |
29 | from http.client import responses |
|
15 | from http.client import responses | |
30 | except ImportError: |
|
16 | except ImportError: | |
31 | from httplib import responses |
|
17 | from httplib import responses | |
32 |
|
18 | |||
33 | from jinja2 import TemplateNotFound |
|
19 | from jinja2 import TemplateNotFound | |
34 | from tornado import web |
|
20 | from tornado import web | |
35 |
|
21 | |||
36 | try: |
|
22 | try: | |
37 | from tornado.log import app_log |
|
23 | from tornado.log import app_log | |
38 | except ImportError: |
|
24 | except ImportError: | |
39 | app_log = logging.getLogger() |
|
25 | app_log = logging.getLogger() | |
40 |
|
26 | |||
41 | from IPython.config import Application |
|
27 | from IPython.config import Application | |
42 | from IPython.utils.path import filefind |
|
28 | from IPython.utils.path import filefind | |
43 | from IPython.utils.py3compat import string_types |
|
29 | from IPython.utils.py3compat import string_types | |
44 | from IPython.html.utils import is_hidden |
|
30 | from IPython.html.utils import is_hidden | |
45 |
|
31 | |||
46 | #----------------------------------------------------------------------------- |
|
32 | #----------------------------------------------------------------------------- | |
47 | # Top-level handlers |
|
33 | # Top-level handlers | |
48 | #----------------------------------------------------------------------------- |
|
34 | #----------------------------------------------------------------------------- | |
49 | non_alphanum = re.compile(r'[^A-Za-z0-9]') |
|
35 | non_alphanum = re.compile(r'[^A-Za-z0-9]') | |
50 |
|
36 | |||
51 | class AuthenticatedHandler(web.RequestHandler): |
|
37 | class AuthenticatedHandler(web.RequestHandler): | |
52 | """A RequestHandler with an authenticated user.""" |
|
38 | """A RequestHandler with an authenticated user.""" | |
53 |
|
39 | |||
54 | def set_default_headers(self): |
|
40 | def set_default_headers(self): | |
55 | headers = self.settings.get('headers', {}) |
|
41 | headers = self.settings.get('headers', {}) | |
56 | for header_name,value in headers.items() : |
|
42 | for header_name,value in headers.items() : | |
57 | try: |
|
43 | try: | |
58 | self.set_header(header_name, value) |
|
44 | self.set_header(header_name, value) | |
59 | except Exception: |
|
45 | except Exception: | |
60 | # tornado raise Exception (not a subclass) |
|
46 | # tornado raise Exception (not a subclass) | |
61 | # if method is unsupported (websocket and Access-Control-Allow-Origin |
|
47 | # if method is unsupported (websocket and Access-Control-Allow-Origin | |
62 | # for example, so just ignore) |
|
48 | # for example, so just ignore) | |
63 | pass |
|
49 | pass | |
64 |
|
50 | |||
65 | def clear_login_cookie(self): |
|
51 | def clear_login_cookie(self): | |
66 | self.clear_cookie(self.cookie_name) |
|
52 | self.clear_cookie(self.cookie_name) | |
67 |
|
53 | |||
68 | def get_current_user(self): |
|
54 | def get_current_user(self): | |
69 | user_id = self.get_secure_cookie(self.cookie_name) |
|
55 | user_id = self.get_secure_cookie(self.cookie_name) | |
70 | # For now the user_id should not return empty, but it could eventually |
|
56 | # For now the user_id should not return empty, but it could eventually | |
71 | if user_id == '': |
|
57 | if user_id == '': | |
72 | user_id = 'anonymous' |
|
58 | user_id = 'anonymous' | |
73 | if user_id is None: |
|
59 | if user_id is None: | |
74 | # prevent extra Invalid cookie sig warnings: |
|
60 | # prevent extra Invalid cookie sig warnings: | |
75 | self.clear_login_cookie() |
|
61 | self.clear_login_cookie() | |
76 | if not self.login_available: |
|
62 | if not self.login_available: | |
77 | user_id = 'anonymous' |
|
63 | user_id = 'anonymous' | |
78 | return user_id |
|
64 | return user_id | |
79 |
|
65 | |||
80 | @property |
|
66 | @property | |
81 | def cookie_name(self): |
|
67 | def cookie_name(self): | |
82 | default_cookie_name = non_alphanum.sub('-', 'username-{}'.format( |
|
68 | default_cookie_name = non_alphanum.sub('-', 'username-{}'.format( | |
83 | self.request.host |
|
69 | self.request.host | |
84 | )) |
|
70 | )) | |
85 | return self.settings.get('cookie_name', default_cookie_name) |
|
71 | return self.settings.get('cookie_name', default_cookie_name) | |
86 |
|
72 | |||
87 | @property |
|
73 | @property | |
88 | def password(self): |
|
74 | def password(self): | |
89 | """our password""" |
|
75 | """our password""" | |
90 | return self.settings.get('password', '') |
|
76 | return self.settings.get('password', '') | |
91 |
|
77 | |||
92 | @property |
|
78 | @property | |
93 | def logged_in(self): |
|
79 | def logged_in(self): | |
94 | """Is a user currently logged in? |
|
80 | """Is a user currently logged in? | |
95 |
|
81 | |||
96 | """ |
|
82 | """ | |
97 | user = self.get_current_user() |
|
83 | user = self.get_current_user() | |
98 | return (user and not user == 'anonymous') |
|
84 | return (user and not user == 'anonymous') | |
99 |
|
85 | |||
100 | @property |
|
86 | @property | |
101 | def login_available(self): |
|
87 | def login_available(self): | |
102 | """May a user proceed to log in? |
|
88 | """May a user proceed to log in? | |
103 |
|
89 | |||
104 | This returns True if login capability is available, irrespective of |
|
90 | This returns True if login capability is available, irrespective of | |
105 | whether the user is already logged in or not. |
|
91 | whether the user is already logged in or not. | |
106 |
|
92 | |||
107 | """ |
|
93 | """ | |
108 | return bool(self.settings.get('password', '')) |
|
94 | return bool(self.settings.get('password', '')) | |
109 |
|
95 | |||
110 |
|
96 | |||
111 | class IPythonHandler(AuthenticatedHandler): |
|
97 | class IPythonHandler(AuthenticatedHandler): | |
112 | """IPython-specific extensions to authenticated handling |
|
98 | """IPython-specific extensions to authenticated handling | |
113 |
|
99 | |||
114 | Mostly property shortcuts to IPython-specific settings. |
|
100 | Mostly property shortcuts to IPython-specific settings. | |
115 | """ |
|
101 | """ | |
116 |
|
102 | |||
117 | @property |
|
103 | @property | |
118 | def config(self): |
|
104 | def config(self): | |
119 | return self.settings.get('config', None) |
|
105 | return self.settings.get('config', None) | |
120 |
|
106 | |||
121 | @property |
|
107 | @property | |
122 | def log(self): |
|
108 | def log(self): | |
123 | """use the IPython log by default, falling back on tornado's logger""" |
|
109 | """use the IPython log by default, falling back on tornado's logger""" | |
124 | if Application.initialized(): |
|
110 | if Application.initialized(): | |
125 | return Application.instance().log |
|
111 | return Application.instance().log | |
126 | else: |
|
112 | else: | |
127 | return app_log |
|
113 | return app_log | |
128 |
|
114 | |||
129 | #--------------------------------------------------------------- |
|
115 | #--------------------------------------------------------------- | |
130 | # URLs |
|
116 | # URLs | |
131 | #--------------------------------------------------------------- |
|
117 | #--------------------------------------------------------------- | |
132 |
|
118 | |||
133 | @property |
|
119 | @property | |
134 | def mathjax_url(self): |
|
120 | def mathjax_url(self): | |
135 | return self.settings.get('mathjax_url', '') |
|
121 | return self.settings.get('mathjax_url', '') | |
136 |
|
122 | |||
137 | @property |
|
123 | @property | |
138 | def base_url(self): |
|
124 | def base_url(self): | |
139 | return self.settings.get('base_url', '/') |
|
125 | return self.settings.get('base_url', '/') | |
140 |
|
126 | |||
141 | #--------------------------------------------------------------- |
|
127 | #--------------------------------------------------------------- | |
142 | # Manager objects |
|
128 | # Manager objects | |
143 | #--------------------------------------------------------------- |
|
129 | #--------------------------------------------------------------- | |
144 |
|
130 | |||
145 | @property |
|
131 | @property | |
146 | def kernel_manager(self): |
|
132 | def kernel_manager(self): | |
147 | return self.settings['kernel_manager'] |
|
133 | return self.settings['kernel_manager'] | |
148 |
|
134 | |||
149 | @property |
|
135 | @property | |
150 | def notebook_manager(self): |
|
136 | def notebook_manager(self): | |
151 | return self.settings['notebook_manager'] |
|
137 | return self.settings['notebook_manager'] | |
152 |
|
138 | |||
153 | @property |
|
139 | @property | |
154 | def cluster_manager(self): |
|
140 | def cluster_manager(self): | |
155 | return self.settings['cluster_manager'] |
|
141 | return self.settings['cluster_manager'] | |
156 |
|
142 | |||
157 | @property |
|
143 | @property | |
158 | def session_manager(self): |
|
144 | def session_manager(self): | |
159 | return self.settings['session_manager'] |
|
145 | return self.settings['session_manager'] | |
160 |
|
146 | |||
161 | @property |
|
147 | @property | |
162 | def kernel_spec_manager(self): |
|
148 | def kernel_spec_manager(self): | |
163 | return self.settings['kernel_spec_manager'] |
|
149 | return self.settings['kernel_spec_manager'] | |
164 |
|
150 | |||
165 | @property |
|
151 | @property | |
166 | def project_dir(self): |
|
152 | def project_dir(self): | |
167 | return self.notebook_manager.notebook_dir |
|
153 | return self.notebook_manager.notebook_dir | |
168 |
|
154 | |||
169 | #--------------------------------------------------------------- |
|
155 | #--------------------------------------------------------------- | |
170 | # template rendering |
|
156 | # template rendering | |
171 | #--------------------------------------------------------------- |
|
157 | #--------------------------------------------------------------- | |
172 |
|
158 | |||
173 | def get_template(self, name): |
|
159 | def get_template(self, name): | |
174 | """Return the jinja template object for a given name""" |
|
160 | """Return the jinja template object for a given name""" | |
175 | return self.settings['jinja2_env'].get_template(name) |
|
161 | return self.settings['jinja2_env'].get_template(name) | |
176 |
|
162 | |||
177 | def render_template(self, name, **ns): |
|
163 | def render_template(self, name, **ns): | |
178 | ns.update(self.template_namespace) |
|
164 | ns.update(self.template_namespace) | |
179 | template = self.get_template(name) |
|
165 | template = self.get_template(name) | |
180 | return template.render(**ns) |
|
166 | return template.render(**ns) | |
181 |
|
167 | |||
182 | @property |
|
168 | @property | |
183 | def template_namespace(self): |
|
169 | def template_namespace(self): | |
184 | return dict( |
|
170 | return dict( | |
185 | base_url=self.base_url, |
|
171 | base_url=self.base_url, | |
186 | logged_in=self.logged_in, |
|
172 | logged_in=self.logged_in, | |
187 | login_available=self.login_available, |
|
173 | login_available=self.login_available, | |
188 | static_url=self.static_url, |
|
174 | static_url=self.static_url, | |
189 | ) |
|
175 | ) | |
190 |
|
176 | |||
191 | def get_json_body(self): |
|
177 | def get_json_body(self): | |
192 | """Return the body of the request as JSON data.""" |
|
178 | """Return the body of the request as JSON data.""" | |
193 | if not self.request.body: |
|
179 | if not self.request.body: | |
194 | return None |
|
180 | return None | |
195 | # Do we need to call body.decode('utf-8') here? |
|
181 | # Do we need to call body.decode('utf-8') here? | |
196 | body = self.request.body.strip().decode(u'utf-8') |
|
182 | body = self.request.body.strip().decode(u'utf-8') | |
197 | try: |
|
183 | try: | |
198 | model = json.loads(body) |
|
184 | model = json.loads(body) | |
199 | except Exception: |
|
185 | except Exception: | |
200 | self.log.debug("Bad JSON: %r", body) |
|
186 | self.log.debug("Bad JSON: %r", body) | |
201 | self.log.error("Couldn't parse JSON", exc_info=True) |
|
187 | self.log.error("Couldn't parse JSON", exc_info=True) | |
202 | raise web.HTTPError(400, u'Invalid JSON in body of request') |
|
188 | raise web.HTTPError(400, u'Invalid JSON in body of request') | |
203 | return model |
|
189 | return model | |
204 |
|
190 | |||
205 | def get_error_html(self, status_code, **kwargs): |
|
191 | def get_error_html(self, status_code, **kwargs): | |
206 | """render custom error pages""" |
|
192 | """render custom error pages""" | |
207 | exception = kwargs.get('exception') |
|
193 | exception = kwargs.get('exception') | |
208 | message = '' |
|
194 | message = '' | |
209 | status_message = responses.get(status_code, 'Unknown HTTP Error') |
|
195 | status_message = responses.get(status_code, 'Unknown HTTP Error') | |
210 | if exception: |
|
196 | if exception: | |
211 | # get the custom message, if defined |
|
197 | # get the custom message, if defined | |
212 | try: |
|
198 | try: | |
213 | message = exception.log_message % exception.args |
|
199 | message = exception.log_message % exception.args | |
214 | except Exception: |
|
200 | except Exception: | |
215 | pass |
|
201 | pass | |
216 |
|
202 | |||
217 | # construct the custom reason, if defined |
|
203 | # construct the custom reason, if defined | |
218 | reason = getattr(exception, 'reason', '') |
|
204 | reason = getattr(exception, 'reason', '') | |
219 | if reason: |
|
205 | if reason: | |
220 | status_message = reason |
|
206 | status_message = reason | |
221 |
|
207 | |||
222 | # build template namespace |
|
208 | # build template namespace | |
223 | ns = dict( |
|
209 | ns = dict( | |
224 | status_code=status_code, |
|
210 | status_code=status_code, | |
225 | status_message=status_message, |
|
211 | status_message=status_message, | |
226 | message=message, |
|
212 | message=message, | |
227 | exception=exception, |
|
213 | exception=exception, | |
228 | ) |
|
214 | ) | |
229 |
|
215 | |||
230 | # render the template |
|
216 | # render the template | |
231 | try: |
|
217 | try: | |
232 | html = self.render_template('%s.html' % status_code, **ns) |
|
218 | html = self.render_template('%s.html' % status_code, **ns) | |
233 | except TemplateNotFound: |
|
219 | except TemplateNotFound: | |
234 | self.log.debug("No template for %d", status_code) |
|
220 | self.log.debug("No template for %d", status_code) | |
235 | html = self.render_template('error.html', **ns) |
|
221 | html = self.render_template('error.html', **ns) | |
236 | return html |
|
222 | return html | |
237 |
|
223 | |||
238 |
|
224 | |||
239 | class Template404(IPythonHandler): |
|
225 | class Template404(IPythonHandler): | |
240 | """Render our 404 template""" |
|
226 | """Render our 404 template""" | |
241 | def prepare(self): |
|
227 | def prepare(self): | |
242 | raise web.HTTPError(404) |
|
228 | raise web.HTTPError(404) | |
243 |
|
229 | |||
244 |
|
230 | |||
245 | class AuthenticatedFileHandler(IPythonHandler, web.StaticFileHandler): |
|
231 | class AuthenticatedFileHandler(IPythonHandler, web.StaticFileHandler): | |
246 | """static files should only be accessible when logged in""" |
|
232 | """static files should only be accessible when logged in""" | |
247 |
|
233 | |||
248 | @web.authenticated |
|
234 | @web.authenticated | |
249 | def get(self, path): |
|
235 | def get(self, path): | |
250 | if os.path.splitext(path)[1] == '.ipynb': |
|
236 | if os.path.splitext(path)[1] == '.ipynb': | |
251 | name = os.path.basename(path) |
|
237 | name = os.path.basename(path) | |
252 | self.set_header('Content-Type', 'application/json') |
|
238 | self.set_header('Content-Type', 'application/json') | |
253 | self.set_header('Content-Disposition','attachment; filename="%s"' % name) |
|
239 | self.set_header('Content-Disposition','attachment; filename="%s"' % name) | |
254 |
|
240 | |||
255 | return web.StaticFileHandler.get(self, path) |
|
241 | return web.StaticFileHandler.get(self, path) | |
256 |
|
242 | |||
257 | def compute_etag(self): |
|
243 | def compute_etag(self): | |
258 | return None |
|
244 | return None | |
259 |
|
245 | |||
260 | def validate_absolute_path(self, root, absolute_path): |
|
246 | def validate_absolute_path(self, root, absolute_path): | |
261 | """Validate and return the absolute path. |
|
247 | """Validate and return the absolute path. | |
262 |
|
248 | |||
263 | Requires tornado 3.1 |
|
249 | Requires tornado 3.1 | |
264 |
|
250 | |||
265 | Adding to tornado's own handling, forbids the serving of hidden files. |
|
251 | Adding to tornado's own handling, forbids the serving of hidden files. | |
266 | """ |
|
252 | """ | |
267 | abs_path = super(AuthenticatedFileHandler, self).validate_absolute_path(root, absolute_path) |
|
253 | abs_path = super(AuthenticatedFileHandler, self).validate_absolute_path(root, absolute_path) | |
268 | abs_root = os.path.abspath(root) |
|
254 | abs_root = os.path.abspath(root) | |
269 | if is_hidden(abs_path, abs_root): |
|
255 | if is_hidden(abs_path, abs_root): | |
270 | self.log.info("Refusing to serve hidden file, via 404 Error") |
|
256 | self.log.info("Refusing to serve hidden file, via 404 Error") | |
271 | raise web.HTTPError(404) |
|
257 | raise web.HTTPError(404) | |
272 | return abs_path |
|
258 | return abs_path | |
273 |
|
259 | |||
274 |
|
260 | |||
275 | def json_errors(method): |
|
261 | def json_errors(method): | |
276 | """Decorate methods with this to return GitHub style JSON errors. |
|
262 | """Decorate methods with this to return GitHub style JSON errors. | |
277 |
|
263 | |||
278 | This should be used on any JSON API on any handler method that can raise HTTPErrors. |
|
264 | This should be used on any JSON API on any handler method that can raise HTTPErrors. | |
279 |
|
265 | |||
280 | This will grab the latest HTTPError exception using sys.exc_info |
|
266 | This will grab the latest HTTPError exception using sys.exc_info | |
281 | and then: |
|
267 | and then: | |
282 |
|
268 | |||
283 | 1. Set the HTTP status code based on the HTTPError |
|
269 | 1. Set the HTTP status code based on the HTTPError | |
284 | 2. Create and return a JSON body with a message field describing |
|
270 | 2. Create and return a JSON body with a message field describing | |
285 | the error in a human readable form. |
|
271 | the error in a human readable form. | |
286 | """ |
|
272 | """ | |
287 | @functools.wraps(method) |
|
273 | @functools.wraps(method) | |
288 | def wrapper(self, *args, **kwargs): |
|
274 | def wrapper(self, *args, **kwargs): | |
289 | try: |
|
275 | try: | |
290 | result = method(self, *args, **kwargs) |
|
276 | result = method(self, *args, **kwargs) | |
291 | except web.HTTPError as e: |
|
277 | except web.HTTPError as e: | |
292 | status = e.status_code |
|
278 | status = e.status_code | |
293 | message = e.log_message |
|
279 | message = e.log_message | |
294 | self.log.warn(message) |
|
280 | self.log.warn(message) | |
295 | self.set_status(e.status_code) |
|
281 | self.set_status(e.status_code) | |
296 | self.finish(json.dumps(dict(message=message))) |
|
282 | self.finish(json.dumps(dict(message=message))) | |
297 | except Exception: |
|
283 | except Exception: | |
298 | self.log.error("Unhandled error in API request", exc_info=True) |
|
284 | self.log.error("Unhandled error in API request", exc_info=True) | |
299 | status = 500 |
|
285 | status = 500 | |
300 | message = "Unknown server error" |
|
286 | message = "Unknown server error" | |
301 | t, value, tb = sys.exc_info() |
|
287 | t, value, tb = sys.exc_info() | |
302 | self.set_status(status) |
|
288 | self.set_status(status) | |
303 | tb_text = ''.join(traceback.format_exception(t, value, tb)) |
|
289 | tb_text = ''.join(traceback.format_exception(t, value, tb)) | |
304 | reply = dict(message=message, traceback=tb_text) |
|
290 | reply = dict(message=message, traceback=tb_text) | |
305 | self.finish(json.dumps(reply)) |
|
291 | self.finish(json.dumps(reply)) | |
306 | else: |
|
292 | else: | |
307 | return result |
|
293 | return result | |
308 | return wrapper |
|
294 | return wrapper | |
309 |
|
295 | |||
310 |
|
296 | |||
311 |
|
297 | |||
312 | #----------------------------------------------------------------------------- |
|
298 | #----------------------------------------------------------------------------- | |
313 | # File handler |
|
299 | # File handler | |
314 | #----------------------------------------------------------------------------- |
|
300 | #----------------------------------------------------------------------------- | |
315 |
|
301 | |||
316 | # to minimize subclass changes: |
|
302 | # to minimize subclass changes: | |
317 | HTTPError = web.HTTPError |
|
303 | HTTPError = web.HTTPError | |
318 |
|
304 | |||
319 | class FileFindHandler(web.StaticFileHandler): |
|
305 | class FileFindHandler(web.StaticFileHandler): | |
320 | """subclass of StaticFileHandler for serving files from a search path""" |
|
306 | """subclass of StaticFileHandler for serving files from a search path""" | |
321 |
|
307 | |||
322 | # cache search results, don't search for files more than once |
|
308 | # cache search results, don't search for files more than once | |
323 | _static_paths = {} |
|
309 | _static_paths = {} | |
324 |
|
310 | |||
325 | def initialize(self, path, default_filename=None): |
|
311 | def initialize(self, path, default_filename=None): | |
326 | if isinstance(path, string_types): |
|
312 | if isinstance(path, string_types): | |
327 | path = [path] |
|
313 | path = [path] | |
328 |
|
314 | |||
329 | self.root = tuple( |
|
315 | self.root = tuple( | |
330 | os.path.abspath(os.path.expanduser(p)) + os.sep for p in path |
|
316 | os.path.abspath(os.path.expanduser(p)) + os.sep for p in path | |
331 | ) |
|
317 | ) | |
332 | self.default_filename = default_filename |
|
318 | self.default_filename = default_filename | |
333 |
|
319 | |||
334 | def compute_etag(self): |
|
320 | def compute_etag(self): | |
335 | return None |
|
321 | return None | |
336 |
|
322 | |||
337 | @classmethod |
|
323 | @classmethod | |
338 | def get_absolute_path(cls, roots, path): |
|
324 | def get_absolute_path(cls, roots, path): | |
339 | """locate a file to serve on our static file search path""" |
|
325 | """locate a file to serve on our static file search path""" | |
340 | with cls._lock: |
|
326 | with cls._lock: | |
341 | if path in cls._static_paths: |
|
327 | if path in cls._static_paths: | |
342 | return cls._static_paths[path] |
|
328 | return cls._static_paths[path] | |
343 | try: |
|
329 | try: | |
344 | abspath = os.path.abspath(filefind(path, roots)) |
|
330 | abspath = os.path.abspath(filefind(path, roots)) | |
345 | except IOError: |
|
331 | except IOError: | |
346 | # IOError means not found |
|
332 | # IOError means not found | |
347 | return '' |
|
333 | return '' | |
348 |
|
334 | |||
349 | cls._static_paths[path] = abspath |
|
335 | cls._static_paths[path] = abspath | |
350 | return abspath |
|
336 | return abspath | |
351 |
|
337 | |||
352 | def validate_absolute_path(self, root, absolute_path): |
|
338 | def validate_absolute_path(self, root, absolute_path): | |
353 | """check if the file should be served (raises 404, 403, etc.)""" |
|
339 | """check if the file should be served (raises 404, 403, etc.)""" | |
354 | if absolute_path == '': |
|
340 | if absolute_path == '': | |
355 | raise web.HTTPError(404) |
|
341 | raise web.HTTPError(404) | |
356 |
|
342 | |||
357 | for root in self.root: |
|
343 | for root in self.root: | |
358 | if (absolute_path + os.sep).startswith(root): |
|
344 | if (absolute_path + os.sep).startswith(root): | |
359 | break |
|
345 | break | |
360 |
|
346 | |||
361 | return super(FileFindHandler, self).validate_absolute_path(root, absolute_path) |
|
347 | return super(FileFindHandler, self).validate_absolute_path(root, absolute_path) | |
362 |
|
348 | |||
363 |
|
349 | |||
364 | class TrailingSlashHandler(web.RequestHandler): |
|
350 | class TrailingSlashHandler(web.RequestHandler): | |
365 | """Simple redirect handler that strips trailing slashes |
|
351 | """Simple redirect handler that strips trailing slashes | |
366 |
|
352 | |||
367 | This should be the first, highest priority handler. |
|
353 | This should be the first, highest priority handler. | |
368 | """ |
|
354 | """ | |
369 |
|
355 | |||
370 | SUPPORTED_METHODS = ['GET'] |
|
356 | SUPPORTED_METHODS = ['GET'] | |
371 |
|
357 | |||
372 | def get(self): |
|
358 | def get(self): | |
373 | self.redirect(self.request.uri.rstrip('/')) |
|
359 | self.redirect(self.request.uri.rstrip('/')) | |
374 |
|
360 | |||
375 | #----------------------------------------------------------------------------- |
|
361 | #----------------------------------------------------------------------------- | |
376 | # URL pattern fragments for re-use |
|
362 | # URL pattern fragments for re-use | |
377 | #----------------------------------------------------------------------------- |
|
363 | #----------------------------------------------------------------------------- | |
378 |
|
364 | |||
379 | path_regex = r"(?P<path>(?:/.*)*)" |
|
365 | path_regex = r"(?P<path>(?:/.*)*)" | |
380 | notebook_name_regex = r"(?P<name>[^/]+\.ipynb)" |
|
366 | notebook_name_regex = r"(?P<name>[^/]+\.ipynb)" | |
381 | notebook_path_regex = "%s/%s" % (path_regex, notebook_name_regex) |
|
367 | notebook_path_regex = "%s/%s" % (path_regex, notebook_name_regex) | |
382 |
|
368 | |||
383 | #----------------------------------------------------------------------------- |
|
369 | #----------------------------------------------------------------------------- | |
384 | # URL to handler mappings |
|
370 | # URL to handler mappings | |
385 | #----------------------------------------------------------------------------- |
|
371 | #----------------------------------------------------------------------------- | |
386 |
|
372 | |||
387 |
|
373 | |||
388 | default_handlers = [ |
|
374 | default_handlers = [ | |
389 | (r".*/", TrailingSlashHandler) |
|
375 | (r".*/", TrailingSlashHandler) | |
390 | ] |
|
376 | ] |
@@ -1,238 +1,217 b'' | |||||
1 | """The official API for working with notebooks in the current format version. |
|
1 | """The official API for working with notebooks in the current format version.""" | |
2 |
|
||||
3 | Authors: |
|
|||
4 |
|
||||
5 | * Brian Granger |
|
|||
6 | * Jonathan Frederic |
|
|||
7 | """ |
|
|||
8 |
|
||||
9 | #----------------------------------------------------------------------------- |
|
|||
10 | # Copyright (C) 2008-2011 The IPython Development Team |
|
|||
11 | # |
|
|||
12 | # Distributed under the terms of the BSD License. The full license is in |
|
|||
13 | # the file COPYING, distributed as part of this software. |
|
|||
14 | #----------------------------------------------------------------------------- |
|
|||
15 |
|
||||
16 | #----------------------------------------------------------------------------- |
|
|||
17 | # Imports |
|
|||
18 | #----------------------------------------------------------------------------- |
|
|||
19 |
|
2 | |||
20 | from __future__ import print_function |
|
3 | from __future__ import print_function | |
21 |
|
4 | |||
22 | from xml.etree import ElementTree as ET |
|
5 | from xml.etree import ElementTree as ET | |
23 | import re |
|
6 | import re | |
24 |
|
7 | |||
25 | from IPython.utils.py3compat import unicode_type |
|
8 | from IPython.utils.py3compat import unicode_type | |
26 |
|
9 | |||
27 | from IPython.nbformat.v3 import ( |
|
10 | from IPython.nbformat.v3 import ( | |
28 | NotebookNode, |
|
11 | NotebookNode, | |
29 | new_code_cell, new_text_cell, new_notebook, new_output, new_worksheet, |
|
12 | new_code_cell, new_text_cell, new_notebook, new_output, new_worksheet, | |
30 | parse_filename, new_metadata, new_author, new_heading_cell, nbformat, |
|
13 | parse_filename, new_metadata, new_author, new_heading_cell, nbformat, | |
31 | nbformat_minor, nbformat_schema, to_notebook_json |
|
14 | nbformat_minor, nbformat_schema, to_notebook_json | |
32 | ) |
|
15 | ) | |
33 | from IPython.nbformat import v3 as _v_latest |
|
16 | from IPython.nbformat import v3 as _v_latest | |
34 |
|
17 | |||
35 | from .reader import reads as reader_reads |
|
18 | from .reader import reads as reader_reads | |
36 | from .reader import versions |
|
19 | from .reader import versions | |
37 | from .convert import convert |
|
20 | from .convert import convert | |
38 | from .validator import validate |
|
21 | from .validator import validate | |
39 |
|
22 | |||
40 | import logging |
|
23 | from IPython.utils.log import get_logger | |
41 | logger = logging.getLogger('NotebookApp') |
|
|||
42 |
|
24 | |||
43 | #----------------------------------------------------------------------------- |
|
|||
44 | # Code |
|
|||
45 | #----------------------------------------------------------------------------- |
|
|||
46 |
|
25 | |||
47 | current_nbformat = nbformat |
|
26 | current_nbformat = nbformat | |
48 | current_nbformat_minor = nbformat_minor |
|
27 | current_nbformat_minor = nbformat_minor | |
49 | current_nbformat_module = _v_latest.__name__ |
|
28 | current_nbformat_module = _v_latest.__name__ | |
50 |
|
29 | |||
51 |
|
30 | |||
52 | def docstring_nbformat_mod(func): |
|
31 | def docstring_nbformat_mod(func): | |
53 | """Decorator for docstrings referring to classes/functions accessed through |
|
32 | """Decorator for docstrings referring to classes/functions accessed through | |
54 | nbformat.current. |
|
33 | nbformat.current. | |
55 |
|
34 | |||
56 | Put {nbformat_mod} in the docstring in place of 'IPython.nbformat.v3'. |
|
35 | Put {nbformat_mod} in the docstring in place of 'IPython.nbformat.v3'. | |
57 | """ |
|
36 | """ | |
58 | func.__doc__ = func.__doc__.format(nbformat_mod=current_nbformat_module) |
|
37 | func.__doc__ = func.__doc__.format(nbformat_mod=current_nbformat_module) | |
59 | return func |
|
38 | return func | |
60 |
|
39 | |||
61 |
|
40 | |||
62 | class NBFormatError(ValueError): |
|
41 | class NBFormatError(ValueError): | |
63 | pass |
|
42 | pass | |
64 |
|
43 | |||
65 |
|
44 | |||
66 | def parse_py(s, **kwargs): |
|
45 | def parse_py(s, **kwargs): | |
67 | """Parse a string into a (nbformat, string) tuple.""" |
|
46 | """Parse a string into a (nbformat, string) tuple.""" | |
68 | nbf = current_nbformat |
|
47 | nbf = current_nbformat | |
69 | nbm = current_nbformat_minor |
|
48 | nbm = current_nbformat_minor | |
70 |
|
49 | |||
71 | pattern = r'# <nbformat>(?P<nbformat>\d+[\.\d+]*)</nbformat>' |
|
50 | pattern = r'# <nbformat>(?P<nbformat>\d+[\.\d+]*)</nbformat>' | |
72 | m = re.search(pattern,s) |
|
51 | m = re.search(pattern,s) | |
73 | if m is not None: |
|
52 | if m is not None: | |
74 | digits = m.group('nbformat').split('.') |
|
53 | digits = m.group('nbformat').split('.') | |
75 | nbf = int(digits[0]) |
|
54 | nbf = int(digits[0]) | |
76 | if len(digits) > 1: |
|
55 | if len(digits) > 1: | |
77 | nbm = int(digits[1]) |
|
56 | nbm = int(digits[1]) | |
78 |
|
57 | |||
79 | return nbf, nbm, s |
|
58 | return nbf, nbm, s | |
80 |
|
59 | |||
81 |
|
60 | |||
82 | def reads_json(nbjson, **kwargs): |
|
61 | def reads_json(nbjson, **kwargs): | |
83 | """Read a JSON notebook from a string and return the NotebookNode |
|
62 | """Read a JSON notebook from a string and return the NotebookNode | |
84 | object. Report if any JSON format errors are detected. |
|
63 | object. Report if any JSON format errors are detected. | |
85 |
|
64 | |||
86 | """ |
|
65 | """ | |
87 | nb = reader_reads(nbjson, **kwargs) |
|
66 | nb = reader_reads(nbjson, **kwargs) | |
88 | nb_current = convert(nb, current_nbformat) |
|
67 | nb_current = convert(nb, current_nbformat) | |
89 | errors = validate(nb_current) |
|
68 | errors = validate(nb_current) | |
90 | if errors: |
|
69 | if errors: | |
91 | logger.error( |
|
70 | get_logger().error( | |
92 | "Notebook JSON is invalid (%d errors detected during read)", |
|
71 | "Notebook JSON is invalid (%d errors detected during read)", | |
93 | len(errors)) |
|
72 | len(errors)) | |
94 | return nb_current |
|
73 | return nb_current | |
95 |
|
74 | |||
96 |
|
75 | |||
97 | def writes_json(nb, **kwargs): |
|
76 | def writes_json(nb, **kwargs): | |
98 | """Take a NotebookNode object and write out a JSON string. Report if |
|
77 | """Take a NotebookNode object and write out a JSON string. Report if | |
99 | any JSON format errors are detected. |
|
78 | any JSON format errors are detected. | |
100 |
|
79 | |||
101 | """ |
|
80 | """ | |
102 | errors = validate(nb) |
|
81 | errors = validate(nb) | |
103 | if errors: |
|
82 | if errors: | |
104 | logger.error( |
|
83 | get_logger().error( | |
105 | "Notebook JSON is invalid (%d errors detected during write)", |
|
84 | "Notebook JSON is invalid (%d errors detected during write)", | |
106 | len(errors)) |
|
85 | len(errors)) | |
107 | nbjson = versions[current_nbformat].writes_json(nb, **kwargs) |
|
86 | nbjson = versions[current_nbformat].writes_json(nb, **kwargs) | |
108 | return nbjson |
|
87 | return nbjson | |
109 |
|
88 | |||
110 |
|
89 | |||
111 | def reads_py(s, **kwargs): |
|
90 | def reads_py(s, **kwargs): | |
112 | """Read a .py notebook from a string and return the NotebookNode object.""" |
|
91 | """Read a .py notebook from a string and return the NotebookNode object.""" | |
113 | nbf, nbm, s = parse_py(s, **kwargs) |
|
92 | nbf, nbm, s = parse_py(s, **kwargs) | |
114 | if nbf in (2, 3): |
|
93 | if nbf in (2, 3): | |
115 | nb = versions[nbf].to_notebook_py(s, **kwargs) |
|
94 | nb = versions[nbf].to_notebook_py(s, **kwargs) | |
116 | else: |
|
95 | else: | |
117 | raise NBFormatError('Unsupported PY nbformat version: %i' % nbf) |
|
96 | raise NBFormatError('Unsupported PY nbformat version: %i' % nbf) | |
118 | return nb |
|
97 | return nb | |
119 |
|
98 | |||
120 |
|
99 | |||
121 | def writes_py(nb, **kwargs): |
|
100 | def writes_py(nb, **kwargs): | |
122 | # nbformat 3 is the latest format that supports py |
|
101 | # nbformat 3 is the latest format that supports py | |
123 | return versions[3].writes_py(nb, **kwargs) |
|
102 | return versions[3].writes_py(nb, **kwargs) | |
124 |
|
103 | |||
125 |
|
104 | |||
126 | # High level API |
|
105 | # High level API | |
127 |
|
106 | |||
128 |
|
107 | |||
129 | def reads(s, format, **kwargs): |
|
108 | def reads(s, format, **kwargs): | |
130 | """Read a notebook from a string and return the NotebookNode object. |
|
109 | """Read a notebook from a string and return the NotebookNode object. | |
131 |
|
110 | |||
132 | This function properly handles notebooks of any version. The notebook |
|
111 | This function properly handles notebooks of any version. The notebook | |
133 | returned will always be in the current version's format. |
|
112 | returned will always be in the current version's format. | |
134 |
|
113 | |||
135 | Parameters |
|
114 | Parameters | |
136 | ---------- |
|
115 | ---------- | |
137 | s : unicode |
|
116 | s : unicode | |
138 | The raw unicode string to read the notebook from. |
|
117 | The raw unicode string to read the notebook from. | |
139 | format : (u'json', u'ipynb', u'py') |
|
118 | format : (u'json', u'ipynb', u'py') | |
140 | The format that the string is in. |
|
119 | The format that the string is in. | |
141 |
|
120 | |||
142 | Returns |
|
121 | Returns | |
143 | ------- |
|
122 | ------- | |
144 | nb : NotebookNode |
|
123 | nb : NotebookNode | |
145 | The notebook that was read. |
|
124 | The notebook that was read. | |
146 | """ |
|
125 | """ | |
147 | format = unicode_type(format) |
|
126 | format = unicode_type(format) | |
148 | if format == u'json' or format == u'ipynb': |
|
127 | if format == u'json' or format == u'ipynb': | |
149 | return reads_json(s, **kwargs) |
|
128 | return reads_json(s, **kwargs) | |
150 | elif format == u'py': |
|
129 | elif format == u'py': | |
151 | return reads_py(s, **kwargs) |
|
130 | return reads_py(s, **kwargs) | |
152 | else: |
|
131 | else: | |
153 | raise NBFormatError('Unsupported format: %s' % format) |
|
132 | raise NBFormatError('Unsupported format: %s' % format) | |
154 |
|
133 | |||
155 |
|
134 | |||
156 | def writes(nb, format, **kwargs): |
|
135 | def writes(nb, format, **kwargs): | |
157 | """Write a notebook to a string in a given format in the current nbformat version. |
|
136 | """Write a notebook to a string in a given format in the current nbformat version. | |
158 |
|
137 | |||
159 | This function always writes the notebook in the current nbformat version. |
|
138 | This function always writes the notebook in the current nbformat version. | |
160 |
|
139 | |||
161 | Parameters |
|
140 | Parameters | |
162 | ---------- |
|
141 | ---------- | |
163 | nb : NotebookNode |
|
142 | nb : NotebookNode | |
164 | The notebook to write. |
|
143 | The notebook to write. | |
165 | format : (u'json', u'ipynb', u'py') |
|
144 | format : (u'json', u'ipynb', u'py') | |
166 | The format to write the notebook in. |
|
145 | The format to write the notebook in. | |
167 |
|
146 | |||
168 | Returns |
|
147 | Returns | |
169 | ------- |
|
148 | ------- | |
170 | s : unicode |
|
149 | s : unicode | |
171 | The notebook string. |
|
150 | The notebook string. | |
172 | """ |
|
151 | """ | |
173 | format = unicode_type(format) |
|
152 | format = unicode_type(format) | |
174 | if format == u'json' or format == u'ipynb': |
|
153 | if format == u'json' or format == u'ipynb': | |
175 | return writes_json(nb, **kwargs) |
|
154 | return writes_json(nb, **kwargs) | |
176 | elif format == u'py': |
|
155 | elif format == u'py': | |
177 | return writes_py(nb, **kwargs) |
|
156 | return writes_py(nb, **kwargs) | |
178 | else: |
|
157 | else: | |
179 | raise NBFormatError('Unsupported format: %s' % format) |
|
158 | raise NBFormatError('Unsupported format: %s' % format) | |
180 |
|
159 | |||
181 |
|
160 | |||
182 | def read(fp, format, **kwargs): |
|
161 | def read(fp, format, **kwargs): | |
183 | """Read a notebook from a file and return the NotebookNode object. |
|
162 | """Read a notebook from a file and return the NotebookNode object. | |
184 |
|
163 | |||
185 | This function properly handles notebooks of any version. The notebook |
|
164 | This function properly handles notebooks of any version. The notebook | |
186 | returned will always be in the current version's format. |
|
165 | returned will always be in the current version's format. | |
187 |
|
166 | |||
188 | Parameters |
|
167 | Parameters | |
189 | ---------- |
|
168 | ---------- | |
190 | fp : file |
|
169 | fp : file | |
191 | Any file-like object with a read method. |
|
170 | Any file-like object with a read method. | |
192 | format : (u'json', u'ipynb', u'py') |
|
171 | format : (u'json', u'ipynb', u'py') | |
193 | The format that the string is in. |
|
172 | The format that the string is in. | |
194 |
|
173 | |||
195 | Returns |
|
174 | Returns | |
196 | ------- |
|
175 | ------- | |
197 | nb : NotebookNode |
|
176 | nb : NotebookNode | |
198 | The notebook that was read. |
|
177 | The notebook that was read. | |
199 | """ |
|
178 | """ | |
200 | return reads(fp.read(), format, **kwargs) |
|
179 | return reads(fp.read(), format, **kwargs) | |
201 |
|
180 | |||
202 |
|
181 | |||
203 | def write(nb, fp, format, **kwargs): |
|
182 | def write(nb, fp, format, **kwargs): | |
204 | """Write a notebook to a file in a given format in the current nbformat version. |
|
183 | """Write a notebook to a file in a given format in the current nbformat version. | |
205 |
|
184 | |||
206 | This function always writes the notebook in the current nbformat version. |
|
185 | This function always writes the notebook in the current nbformat version. | |
207 |
|
186 | |||
208 | Parameters |
|
187 | Parameters | |
209 | ---------- |
|
188 | ---------- | |
210 | nb : NotebookNode |
|
189 | nb : NotebookNode | |
211 | The notebook to write. |
|
190 | The notebook to write. | |
212 | fp : file |
|
191 | fp : file | |
213 | Any file-like object with a write method. |
|
192 | Any file-like object with a write method. | |
214 | format : (u'json', u'ipynb', u'py') |
|
193 | format : (u'json', u'ipynb', u'py') | |
215 | The format to write the notebook in. |
|
194 | The format to write the notebook in. | |
216 |
|
195 | |||
217 | Returns |
|
196 | Returns | |
218 | ------- |
|
197 | ------- | |
219 | s : unicode |
|
198 | s : unicode | |
220 | The notebook string. |
|
199 | The notebook string. | |
221 | """ |
|
200 | """ | |
222 | return fp.write(writes(nb, format, **kwargs)) |
|
201 | return fp.write(writes(nb, format, **kwargs)) | |
223 |
|
202 | |||
224 | def _convert_to_metadata(): |
|
203 | def _convert_to_metadata(): | |
225 | """Convert to a notebook having notebook metadata.""" |
|
204 | """Convert to a notebook having notebook metadata.""" | |
226 | import glob |
|
205 | import glob | |
227 | for fname in glob.glob('*.ipynb'): |
|
206 | for fname in glob.glob('*.ipynb'): | |
228 | print('Converting file:',fname) |
|
207 | print('Converting file:',fname) | |
229 | with open(fname,'r') as f: |
|
208 | with open(fname,'r') as f: | |
230 | nb = read(f,u'json') |
|
209 | nb = read(f,u'json') | |
231 | md = new_metadata() |
|
210 | md = new_metadata() | |
232 | if u'name' in nb: |
|
211 | if u'name' in nb: | |
233 | md.name = nb.name |
|
212 | md.name = nb.name | |
234 | del nb[u'name'] |
|
213 | del nb[u'name'] | |
235 | nb.metadata = md |
|
214 | nb.metadata = md | |
236 | with open(fname,'w') as f: |
|
215 | with open(fname,'w') as f: | |
237 | write(nb, f, u'json') |
|
216 | write(nb, f, u'json') | |
238 |
|
217 |
@@ -1,859 +1,848 b'' | |||||
1 | """The Python scheduler for rich scheduling. |
|
1 | """The Python scheduler for rich scheduling. | |
2 |
|
2 | |||
3 | The Pure ZMQ scheduler does not allow routing schemes other than LRU, |
|
3 | The Pure ZMQ scheduler does not allow routing schemes other than LRU, | |
4 | nor does it check msg_id DAG dependencies. For those, a slightly slower |
|
4 | nor does it check msg_id DAG dependencies. For those, a slightly slower | |
5 | Python Scheduler exists. |
|
5 | Python Scheduler exists. | |
6 |
|
||||
7 | Authors: |
|
|||
8 |
|
||||
9 | * Min RK |
|
|||
10 | """ |
|
6 | """ | |
11 | #----------------------------------------------------------------------------- |
|
|||
12 | # Copyright (C) 2010-2011 The IPython Development Team |
|
|||
13 | # |
|
|||
14 | # Distributed under the terms of the BSD License. The full license is in |
|
|||
15 | # the file COPYING, distributed as part of this software. |
|
|||
16 | #----------------------------------------------------------------------------- |
|
|||
17 |
|
7 | |||
18 | #---------------------------------------------------------------------- |
|
8 | # Copyright (c) IPython Development Team. | |
19 | # Imports |
|
9 | # Distributed under the terms of the Modified BSD License. | |
20 | #---------------------------------------------------------------------- |
|
|||
21 |
|
10 | |||
22 | import logging |
|
11 | import logging | |
23 | import sys |
|
12 | import sys | |
24 | import time |
|
13 | import time | |
25 |
|
14 | |||
26 | from collections import deque |
|
15 | from collections import deque | |
27 | from datetime import datetime |
|
16 | from datetime import datetime | |
28 | from random import randint, random |
|
17 | from random import randint, random | |
29 | from types import FunctionType |
|
18 | from types import FunctionType | |
30 |
|
19 | |||
31 | try: |
|
20 | try: | |
32 | import numpy |
|
21 | import numpy | |
33 | except ImportError: |
|
22 | except ImportError: | |
34 | numpy = None |
|
23 | numpy = None | |
35 |
|
24 | |||
36 | import zmq |
|
25 | import zmq | |
37 | from zmq.eventloop import ioloop, zmqstream |
|
26 | from zmq.eventloop import ioloop, zmqstream | |
38 |
|
27 | |||
39 | # local imports |
|
28 | # local imports | |
40 | from IPython.external.decorator import decorator |
|
29 | from IPython.external.decorator import decorator | |
41 | from IPython.config.application import Application |
|
30 | from IPython.config.application import Application | |
42 | from IPython.config.loader import Config |
|
31 | from IPython.config.loader import Config | |
43 | from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes |
|
32 | from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes | |
44 | from IPython.utils.py3compat import cast_bytes |
|
33 | from IPython.utils.py3compat import cast_bytes | |
45 |
|
34 | |||
46 | from IPython.parallel import error, util |
|
35 | from IPython.parallel import error, util | |
47 | from IPython.parallel.factory import SessionFactory |
|
36 | from IPython.parallel.factory import SessionFactory | |
48 | from IPython.parallel.util import connect_logger, local_logger |
|
37 | from IPython.parallel.util import connect_logger, local_logger | |
49 |
|
38 | |||
50 | from .dependency import Dependency |
|
39 | from .dependency import Dependency | |
51 |
|
40 | |||
52 | @decorator |
|
41 | @decorator | |
53 | def logged(f,self,*args,**kwargs): |
|
42 | def logged(f,self,*args,**kwargs): | |
54 | # print ("#--------------------") |
|
43 | # print ("#--------------------") | |
55 | self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs) |
|
44 | self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs) | |
56 | # print ("#--") |
|
45 | # print ("#--") | |
57 | return f(self,*args, **kwargs) |
|
46 | return f(self,*args, **kwargs) | |
58 |
|
47 | |||
59 | #---------------------------------------------------------------------- |
|
48 | #---------------------------------------------------------------------- | |
60 | # Chooser functions |
|
49 | # Chooser functions | |
61 | #---------------------------------------------------------------------- |
|
50 | #---------------------------------------------------------------------- | |
62 |
|
51 | |||
63 | def plainrandom(loads): |
|
52 | def plainrandom(loads): | |
64 | """Plain random pick.""" |
|
53 | """Plain random pick.""" | |
65 | n = len(loads) |
|
54 | n = len(loads) | |
66 | return randint(0,n-1) |
|
55 | return randint(0,n-1) | |
67 |
|
56 | |||
68 | def lru(loads): |
|
57 | def lru(loads): | |
69 | """Always pick the front of the line. |
|
58 | """Always pick the front of the line. | |
70 |
|
59 | |||
71 | The content of `loads` is ignored. |
|
60 | The content of `loads` is ignored. | |
72 |
|
61 | |||
73 | Assumes LRU ordering of loads, with oldest first. |
|
62 | Assumes LRU ordering of loads, with oldest first. | |
74 | """ |
|
63 | """ | |
75 | return 0 |
|
64 | return 0 | |
76 |
|
65 | |||
77 | def twobin(loads): |
|
66 | def twobin(loads): | |
78 | """Pick two at random, use the LRU of the two. |
|
67 | """Pick two at random, use the LRU of the two. | |
79 |
|
68 | |||
80 | The content of loads is ignored. |
|
69 | The content of loads is ignored. | |
81 |
|
70 | |||
82 | Assumes LRU ordering of loads, with oldest first. |
|
71 | Assumes LRU ordering of loads, with oldest first. | |
83 | """ |
|
72 | """ | |
84 | n = len(loads) |
|
73 | n = len(loads) | |
85 | a = randint(0,n-1) |
|
74 | a = randint(0,n-1) | |
86 | b = randint(0,n-1) |
|
75 | b = randint(0,n-1) | |
87 | return min(a,b) |
|
76 | return min(a,b) | |
88 |
|
77 | |||
89 | def weighted(loads): |
|
78 | def weighted(loads): | |
90 | """Pick two at random using inverse load as weight. |
|
79 | """Pick two at random using inverse load as weight. | |
91 |
|
80 | |||
92 | Return the less loaded of the two. |
|
81 | Return the less loaded of the two. | |
93 | """ |
|
82 | """ | |
94 | # weight 0 a million times more than 1: |
|
83 | # weight 0 a million times more than 1: | |
95 | weights = 1./(1e-6+numpy.array(loads)) |
|
84 | weights = 1./(1e-6+numpy.array(loads)) | |
96 | sums = weights.cumsum() |
|
85 | sums = weights.cumsum() | |
97 | t = sums[-1] |
|
86 | t = sums[-1] | |
98 | x = random()*t |
|
87 | x = random()*t | |
99 | y = random()*t |
|
88 | y = random()*t | |
100 | idx = 0 |
|
89 | idx = 0 | |
101 | idy = 0 |
|
90 | idy = 0 | |
102 | while sums[idx] < x: |
|
91 | while sums[idx] < x: | |
103 | idx += 1 |
|
92 | idx += 1 | |
104 | while sums[idy] < y: |
|
93 | while sums[idy] < y: | |
105 | idy += 1 |
|
94 | idy += 1 | |
106 | if weights[idy] > weights[idx]: |
|
95 | if weights[idy] > weights[idx]: | |
107 | return idy |
|
96 | return idy | |
108 | else: |
|
97 | else: | |
109 | return idx |
|
98 | return idx | |
110 |
|
99 | |||
111 | def leastload(loads): |
|
100 | def leastload(loads): | |
112 | """Always choose the lowest load. |
|
101 | """Always choose the lowest load. | |
113 |
|
102 | |||
114 | If the lowest load occurs more than once, the first |
|
103 | If the lowest load occurs more than once, the first | |
115 | occurance will be used. If loads has LRU ordering, this means |
|
104 | occurance will be used. If loads has LRU ordering, this means | |
116 | the LRU of those with the lowest load is chosen. |
|
105 | the LRU of those with the lowest load is chosen. | |
117 | """ |
|
106 | """ | |
118 | return loads.index(min(loads)) |
|
107 | return loads.index(min(loads)) | |
119 |
|
108 | |||
120 | #--------------------------------------------------------------------- |
|
109 | #--------------------------------------------------------------------- | |
121 | # Classes |
|
110 | # Classes | |
122 | #--------------------------------------------------------------------- |
|
111 | #--------------------------------------------------------------------- | |
123 |
|
112 | |||
124 |
|
113 | |||
125 | # store empty default dependency: |
|
114 | # store empty default dependency: | |
126 | MET = Dependency([]) |
|
115 | MET = Dependency([]) | |
127 |
|
116 | |||
128 |
|
117 | |||
129 | class Job(object): |
|
118 | class Job(object): | |
130 | """Simple container for a job""" |
|
119 | """Simple container for a job""" | |
131 | def __init__(self, msg_id, raw_msg, idents, msg, header, metadata, |
|
120 | def __init__(self, msg_id, raw_msg, idents, msg, header, metadata, | |
132 | targets, after, follow, timeout): |
|
121 | targets, after, follow, timeout): | |
133 | self.msg_id = msg_id |
|
122 | self.msg_id = msg_id | |
134 | self.raw_msg = raw_msg |
|
123 | self.raw_msg = raw_msg | |
135 | self.idents = idents |
|
124 | self.idents = idents | |
136 | self.msg = msg |
|
125 | self.msg = msg | |
137 | self.header = header |
|
126 | self.header = header | |
138 | self.metadata = metadata |
|
127 | self.metadata = metadata | |
139 | self.targets = targets |
|
128 | self.targets = targets | |
140 | self.after = after |
|
129 | self.after = after | |
141 | self.follow = follow |
|
130 | self.follow = follow | |
142 | self.timeout = timeout |
|
131 | self.timeout = timeout | |
143 |
|
132 | |||
144 | self.removed = False # used for lazy-delete from sorted queue |
|
133 | self.removed = False # used for lazy-delete from sorted queue | |
145 | self.timestamp = time.time() |
|
134 | self.timestamp = time.time() | |
146 | self.timeout_id = 0 |
|
135 | self.timeout_id = 0 | |
147 | self.blacklist = set() |
|
136 | self.blacklist = set() | |
148 |
|
137 | |||
149 | def __lt__(self, other): |
|
138 | def __lt__(self, other): | |
150 | return self.timestamp < other.timestamp |
|
139 | return self.timestamp < other.timestamp | |
151 |
|
140 | |||
152 | def __cmp__(self, other): |
|
141 | def __cmp__(self, other): | |
153 | return cmp(self.timestamp, other.timestamp) |
|
142 | return cmp(self.timestamp, other.timestamp) | |
154 |
|
143 | |||
155 | @property |
|
144 | @property | |
156 | def dependents(self): |
|
145 | def dependents(self): | |
157 | return self.follow.union(self.after) |
|
146 | return self.follow.union(self.after) | |
158 |
|
147 | |||
159 |
|
148 | |||
160 | class TaskScheduler(SessionFactory): |
|
149 | class TaskScheduler(SessionFactory): | |
161 | """Python TaskScheduler object. |
|
150 | """Python TaskScheduler object. | |
162 |
|
151 | |||
163 | This is the simplest object that supports msg_id based |
|
152 | This is the simplest object that supports msg_id based | |
164 | DAG dependencies. *Only* task msg_ids are checked, not |
|
153 | DAG dependencies. *Only* task msg_ids are checked, not | |
165 | msg_ids of jobs submitted via the MUX queue. |
|
154 | msg_ids of jobs submitted via the MUX queue. | |
166 |
|
155 | |||
167 | """ |
|
156 | """ | |
168 |
|
157 | |||
169 | hwm = Integer(1, config=True, |
|
158 | hwm = Integer(1, config=True, | |
170 | help="""specify the High Water Mark (HWM) for the downstream |
|
159 | help="""specify the High Water Mark (HWM) for the downstream | |
171 | socket in the Task scheduler. This is the maximum number |
|
160 | socket in the Task scheduler. This is the maximum number | |
172 | of allowed outstanding tasks on each engine. |
|
161 | of allowed outstanding tasks on each engine. | |
173 |
|
162 | |||
174 | The default (1) means that only one task can be outstanding on each |
|
163 | The default (1) means that only one task can be outstanding on each | |
175 | engine. Setting TaskScheduler.hwm=0 means there is no limit, and the |
|
164 | engine. Setting TaskScheduler.hwm=0 means there is no limit, and the | |
176 | engines continue to be assigned tasks while they are working, |
|
165 | engines continue to be assigned tasks while they are working, | |
177 | effectively hiding network latency behind computation, but can result |
|
166 | effectively hiding network latency behind computation, but can result | |
178 | in an imbalance of work when submitting many heterogenous tasks all at |
|
167 | in an imbalance of work when submitting many heterogenous tasks all at | |
179 | once. Any positive value greater than one is a compromise between the |
|
168 | once. Any positive value greater than one is a compromise between the | |
180 | two. |
|
169 | two. | |
181 |
|
170 | |||
182 | """ |
|
171 | """ | |
183 | ) |
|
172 | ) | |
184 | scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'), |
|
173 | scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'), | |
185 | 'leastload', config=True, allow_none=False, |
|
174 | 'leastload', config=True, allow_none=False, | |
186 | help="""select the task scheduler scheme [default: Python LRU] |
|
175 | help="""select the task scheduler scheme [default: Python LRU] | |
187 | Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'""" |
|
176 | Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'""" | |
188 | ) |
|
177 | ) | |
189 | def _scheme_name_changed(self, old, new): |
|
178 | def _scheme_name_changed(self, old, new): | |
190 | self.log.debug("Using scheme %r"%new) |
|
179 | self.log.debug("Using scheme %r"%new) | |
191 | self.scheme = globals()[new] |
|
180 | self.scheme = globals()[new] | |
192 |
|
181 | |||
193 | # input arguments: |
|
182 | # input arguments: | |
194 | scheme = Instance(FunctionType) # function for determining the destination |
|
183 | scheme = Instance(FunctionType) # function for determining the destination | |
195 | def _scheme_default(self): |
|
184 | def _scheme_default(self): | |
196 | return leastload |
|
185 | return leastload | |
197 | client_stream = Instance(zmqstream.ZMQStream) # client-facing stream |
|
186 | client_stream = Instance(zmqstream.ZMQStream) # client-facing stream | |
198 | engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream |
|
187 | engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream | |
199 | notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream |
|
188 | notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream | |
200 | mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream |
|
189 | mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream | |
201 | query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream |
|
190 | query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream | |
202 |
|
191 | |||
203 | # internals: |
|
192 | # internals: | |
204 | queue = Instance(deque) # sorted list of Jobs |
|
193 | queue = Instance(deque) # sorted list of Jobs | |
205 | def _queue_default(self): |
|
194 | def _queue_default(self): | |
206 | return deque() |
|
195 | return deque() | |
207 | queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue) |
|
196 | queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue) | |
208 | graph = Dict() # dict by msg_id of [ msg_ids that depend on key ] |
|
197 | graph = Dict() # dict by msg_id of [ msg_ids that depend on key ] | |
209 | retries = Dict() # dict by msg_id of retries remaining (non-neg ints) |
|
198 | retries = Dict() # dict by msg_id of retries remaining (non-neg ints) | |
210 | # waiting = List() # list of msg_ids ready to run, but haven't due to HWM |
|
199 | # waiting = List() # list of msg_ids ready to run, but haven't due to HWM | |
211 | pending = Dict() # dict by engine_uuid of submitted tasks |
|
200 | pending = Dict() # dict by engine_uuid of submitted tasks | |
212 | completed = Dict() # dict by engine_uuid of completed tasks |
|
201 | completed = Dict() # dict by engine_uuid of completed tasks | |
213 | failed = Dict() # dict by engine_uuid of failed tasks |
|
202 | failed = Dict() # dict by engine_uuid of failed tasks | |
214 | destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) |
|
203 | destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed) | |
215 | clients = Dict() # dict by msg_id for who submitted the task |
|
204 | clients = Dict() # dict by msg_id for who submitted the task | |
216 | targets = List() # list of target IDENTs |
|
205 | targets = List() # list of target IDENTs | |
217 | loads = List() # list of engine loads |
|
206 | loads = List() # list of engine loads | |
218 | # full = Set() # set of IDENTs that have HWM outstanding tasks |
|
207 | # full = Set() # set of IDENTs that have HWM outstanding tasks | |
219 | all_completed = Set() # set of all completed tasks |
|
208 | all_completed = Set() # set of all completed tasks | |
220 | all_failed = Set() # set of all failed tasks |
|
209 | all_failed = Set() # set of all failed tasks | |
221 | all_done = Set() # set of all finished tasks=union(completed,failed) |
|
210 | all_done = Set() # set of all finished tasks=union(completed,failed) | |
222 | all_ids = Set() # set of all submitted task IDs |
|
211 | all_ids = Set() # set of all submitted task IDs | |
223 |
|
212 | |||
224 | ident = CBytes() # ZMQ identity. This should just be self.session.session |
|
213 | ident = CBytes() # ZMQ identity. This should just be self.session.session | |
225 | # but ensure Bytes |
|
214 | # but ensure Bytes | |
226 | def _ident_default(self): |
|
215 | def _ident_default(self): | |
227 | return self.session.bsession |
|
216 | return self.session.bsession | |
228 |
|
217 | |||
229 | def start(self): |
|
218 | def start(self): | |
230 | self.query_stream.on_recv(self.dispatch_query_reply) |
|
219 | self.query_stream.on_recv(self.dispatch_query_reply) | |
231 | self.session.send(self.query_stream, "connection_request", {}) |
|
220 | self.session.send(self.query_stream, "connection_request", {}) | |
232 |
|
221 | |||
233 | self.engine_stream.on_recv(self.dispatch_result, copy=False) |
|
222 | self.engine_stream.on_recv(self.dispatch_result, copy=False) | |
234 | self.client_stream.on_recv(self.dispatch_submission, copy=False) |
|
223 | self.client_stream.on_recv(self.dispatch_submission, copy=False) | |
235 |
|
224 | |||
236 | self._notification_handlers = dict( |
|
225 | self._notification_handlers = dict( | |
237 | registration_notification = self._register_engine, |
|
226 | registration_notification = self._register_engine, | |
238 | unregistration_notification = self._unregister_engine |
|
227 | unregistration_notification = self._unregister_engine | |
239 | ) |
|
228 | ) | |
240 | self.notifier_stream.on_recv(self.dispatch_notification) |
|
229 | self.notifier_stream.on_recv(self.dispatch_notification) | |
241 | self.log.info("Scheduler started [%s]" % self.scheme_name) |
|
230 | self.log.info("Scheduler started [%s]" % self.scheme_name) | |
242 |
|
231 | |||
243 | def resume_receiving(self): |
|
232 | def resume_receiving(self): | |
244 | """Resume accepting jobs.""" |
|
233 | """Resume accepting jobs.""" | |
245 | self.client_stream.on_recv(self.dispatch_submission, copy=False) |
|
234 | self.client_stream.on_recv(self.dispatch_submission, copy=False) | |
246 |
|
235 | |||
247 | def stop_receiving(self): |
|
236 | def stop_receiving(self): | |
248 | """Stop accepting jobs while there are no engines. |
|
237 | """Stop accepting jobs while there are no engines. | |
249 | Leave them in the ZMQ queue.""" |
|
238 | Leave them in the ZMQ queue.""" | |
250 | self.client_stream.on_recv(None) |
|
239 | self.client_stream.on_recv(None) | |
251 |
|
240 | |||
252 | #----------------------------------------------------------------------- |
|
241 | #----------------------------------------------------------------------- | |
253 | # [Un]Registration Handling |
|
242 | # [Un]Registration Handling | |
254 | #----------------------------------------------------------------------- |
|
243 | #----------------------------------------------------------------------- | |
255 |
|
244 | |||
256 |
|
245 | |||
257 | def dispatch_query_reply(self, msg): |
|
246 | def dispatch_query_reply(self, msg): | |
258 | """handle reply to our initial connection request""" |
|
247 | """handle reply to our initial connection request""" | |
259 | try: |
|
248 | try: | |
260 | idents,msg = self.session.feed_identities(msg) |
|
249 | idents,msg = self.session.feed_identities(msg) | |
261 | except ValueError: |
|
250 | except ValueError: | |
262 | self.log.warn("task::Invalid Message: %r",msg) |
|
251 | self.log.warn("task::Invalid Message: %r",msg) | |
263 | return |
|
252 | return | |
264 | try: |
|
253 | try: | |
265 | msg = self.session.unserialize(msg) |
|
254 | msg = self.session.unserialize(msg) | |
266 | except ValueError: |
|
255 | except ValueError: | |
267 | self.log.warn("task::Unauthorized message from: %r"%idents) |
|
256 | self.log.warn("task::Unauthorized message from: %r"%idents) | |
268 | return |
|
257 | return | |
269 |
|
258 | |||
270 | content = msg['content'] |
|
259 | content = msg['content'] | |
271 | for uuid in content.get('engines', {}).values(): |
|
260 | for uuid in content.get('engines', {}).values(): | |
272 | self._register_engine(cast_bytes(uuid)) |
|
261 | self._register_engine(cast_bytes(uuid)) | |
273 |
|
262 | |||
274 |
|
263 | |||
275 | @util.log_errors |
|
264 | @util.log_errors | |
276 | def dispatch_notification(self, msg): |
|
265 | def dispatch_notification(self, msg): | |
277 | """dispatch register/unregister events.""" |
|
266 | """dispatch register/unregister events.""" | |
278 | try: |
|
267 | try: | |
279 | idents,msg = self.session.feed_identities(msg) |
|
268 | idents,msg = self.session.feed_identities(msg) | |
280 | except ValueError: |
|
269 | except ValueError: | |
281 | self.log.warn("task::Invalid Message: %r",msg) |
|
270 | self.log.warn("task::Invalid Message: %r",msg) | |
282 | return |
|
271 | return | |
283 | try: |
|
272 | try: | |
284 | msg = self.session.unserialize(msg) |
|
273 | msg = self.session.unserialize(msg) | |
285 | except ValueError: |
|
274 | except ValueError: | |
286 | self.log.warn("task::Unauthorized message from: %r"%idents) |
|
275 | self.log.warn("task::Unauthorized message from: %r"%idents) | |
287 | return |
|
276 | return | |
288 |
|
277 | |||
289 | msg_type = msg['header']['msg_type'] |
|
278 | msg_type = msg['header']['msg_type'] | |
290 |
|
279 | |||
291 | handler = self._notification_handlers.get(msg_type, None) |
|
280 | handler = self._notification_handlers.get(msg_type, None) | |
292 | if handler is None: |
|
281 | if handler is None: | |
293 | self.log.error("Unhandled message type: %r"%msg_type) |
|
282 | self.log.error("Unhandled message type: %r"%msg_type) | |
294 | else: |
|
283 | else: | |
295 | try: |
|
284 | try: | |
296 | handler(cast_bytes(msg['content']['uuid'])) |
|
285 | handler(cast_bytes(msg['content']['uuid'])) | |
297 | except Exception: |
|
286 | except Exception: | |
298 | self.log.error("task::Invalid notification msg: %r", msg, exc_info=True) |
|
287 | self.log.error("task::Invalid notification msg: %r", msg, exc_info=True) | |
299 |
|
288 | |||
300 | def _register_engine(self, uid): |
|
289 | def _register_engine(self, uid): | |
301 | """New engine with ident `uid` became available.""" |
|
290 | """New engine with ident `uid` became available.""" | |
302 | # head of the line: |
|
291 | # head of the line: | |
303 | self.targets.insert(0,uid) |
|
292 | self.targets.insert(0,uid) | |
304 | self.loads.insert(0,0) |
|
293 | self.loads.insert(0,0) | |
305 |
|
294 | |||
306 | # initialize sets |
|
295 | # initialize sets | |
307 | self.completed[uid] = set() |
|
296 | self.completed[uid] = set() | |
308 | self.failed[uid] = set() |
|
297 | self.failed[uid] = set() | |
309 | self.pending[uid] = {} |
|
298 | self.pending[uid] = {} | |
310 |
|
299 | |||
311 | # rescan the graph: |
|
300 | # rescan the graph: | |
312 | self.update_graph(None) |
|
301 | self.update_graph(None) | |
313 |
|
302 | |||
314 | def _unregister_engine(self, uid): |
|
303 | def _unregister_engine(self, uid): | |
315 | """Existing engine with ident `uid` became unavailable.""" |
|
304 | """Existing engine with ident `uid` became unavailable.""" | |
316 | if len(self.targets) == 1: |
|
305 | if len(self.targets) == 1: | |
317 | # this was our only engine |
|
306 | # this was our only engine | |
318 | pass |
|
307 | pass | |
319 |
|
308 | |||
320 | # handle any potentially finished tasks: |
|
309 | # handle any potentially finished tasks: | |
321 | self.engine_stream.flush() |
|
310 | self.engine_stream.flush() | |
322 |
|
311 | |||
323 | # don't pop destinations, because they might be used later |
|
312 | # don't pop destinations, because they might be used later | |
324 | # map(self.destinations.pop, self.completed.pop(uid)) |
|
313 | # map(self.destinations.pop, self.completed.pop(uid)) | |
325 | # map(self.destinations.pop, self.failed.pop(uid)) |
|
314 | # map(self.destinations.pop, self.failed.pop(uid)) | |
326 |
|
315 | |||
327 | # prevent this engine from receiving work |
|
316 | # prevent this engine from receiving work | |
328 | idx = self.targets.index(uid) |
|
317 | idx = self.targets.index(uid) | |
329 | self.targets.pop(idx) |
|
318 | self.targets.pop(idx) | |
330 | self.loads.pop(idx) |
|
319 | self.loads.pop(idx) | |
331 |
|
320 | |||
332 | # wait 5 seconds before cleaning up pending jobs, since the results might |
|
321 | # wait 5 seconds before cleaning up pending jobs, since the results might | |
333 | # still be incoming |
|
322 | # still be incoming | |
334 | if self.pending[uid]: |
|
323 | if self.pending[uid]: | |
335 | dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop) |
|
324 | dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop) | |
336 | dc.start() |
|
325 | dc.start() | |
337 | else: |
|
326 | else: | |
338 | self.completed.pop(uid) |
|
327 | self.completed.pop(uid) | |
339 | self.failed.pop(uid) |
|
328 | self.failed.pop(uid) | |
340 |
|
329 | |||
341 |
|
330 | |||
342 | def handle_stranded_tasks(self, engine): |
|
331 | def handle_stranded_tasks(self, engine): | |
343 | """Deal with jobs resident in an engine that died.""" |
|
332 | """Deal with jobs resident in an engine that died.""" | |
344 | lost = self.pending[engine] |
|
333 | lost = self.pending[engine] | |
345 | for msg_id in lost.keys(): |
|
334 | for msg_id in lost.keys(): | |
346 | if msg_id not in self.pending[engine]: |
|
335 | if msg_id not in self.pending[engine]: | |
347 | # prevent double-handling of messages |
|
336 | # prevent double-handling of messages | |
348 | continue |
|
337 | continue | |
349 |
|
338 | |||
350 | raw_msg = lost[msg_id].raw_msg |
|
339 | raw_msg = lost[msg_id].raw_msg | |
351 | idents,msg = self.session.feed_identities(raw_msg, copy=False) |
|
340 | idents,msg = self.session.feed_identities(raw_msg, copy=False) | |
352 | parent = self.session.unpack(msg[1].bytes) |
|
341 | parent = self.session.unpack(msg[1].bytes) | |
353 | idents = [engine, idents[0]] |
|
342 | idents = [engine, idents[0]] | |
354 |
|
343 | |||
355 | # build fake error reply |
|
344 | # build fake error reply | |
356 | try: |
|
345 | try: | |
357 | raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id)) |
|
346 | raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id)) | |
358 | except: |
|
347 | except: | |
359 | content = error.wrap_exception() |
|
348 | content = error.wrap_exception() | |
360 | # build fake metadata |
|
349 | # build fake metadata | |
361 | md = dict( |
|
350 | md = dict( | |
362 | status=u'error', |
|
351 | status=u'error', | |
363 | engine=engine.decode('ascii'), |
|
352 | engine=engine.decode('ascii'), | |
364 | date=datetime.now(), |
|
353 | date=datetime.now(), | |
365 | ) |
|
354 | ) | |
366 | msg = self.session.msg('apply_reply', content, parent=parent, metadata=md) |
|
355 | msg = self.session.msg('apply_reply', content, parent=parent, metadata=md) | |
367 | raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents))) |
|
356 | raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents))) | |
368 | # and dispatch it |
|
357 | # and dispatch it | |
369 | self.dispatch_result(raw_reply) |
|
358 | self.dispatch_result(raw_reply) | |
370 |
|
359 | |||
371 | # finally scrub completed/failed lists |
|
360 | # finally scrub completed/failed lists | |
372 | self.completed.pop(engine) |
|
361 | self.completed.pop(engine) | |
373 | self.failed.pop(engine) |
|
362 | self.failed.pop(engine) | |
374 |
|
363 | |||
375 |
|
364 | |||
376 | #----------------------------------------------------------------------- |
|
365 | #----------------------------------------------------------------------- | |
377 | # Job Submission |
|
366 | # Job Submission | |
378 | #----------------------------------------------------------------------- |
|
367 | #----------------------------------------------------------------------- | |
379 |
|
368 | |||
380 |
|
369 | |||
381 | @util.log_errors |
|
370 | @util.log_errors | |
382 | def dispatch_submission(self, raw_msg): |
|
371 | def dispatch_submission(self, raw_msg): | |
383 | """Dispatch job submission to appropriate handlers.""" |
|
372 | """Dispatch job submission to appropriate handlers.""" | |
384 | # ensure targets up to date: |
|
373 | # ensure targets up to date: | |
385 | self.notifier_stream.flush() |
|
374 | self.notifier_stream.flush() | |
386 | try: |
|
375 | try: | |
387 | idents, msg = self.session.feed_identities(raw_msg, copy=False) |
|
376 | idents, msg = self.session.feed_identities(raw_msg, copy=False) | |
388 | msg = self.session.unserialize(msg, content=False, copy=False) |
|
377 | msg = self.session.unserialize(msg, content=False, copy=False) | |
389 | except Exception: |
|
378 | except Exception: | |
390 | self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True) |
|
379 | self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True) | |
391 | return |
|
380 | return | |
392 |
|
381 | |||
393 |
|
382 | |||
394 | # send to monitor |
|
383 | # send to monitor | |
395 | self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False) |
|
384 | self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False) | |
396 |
|
385 | |||
397 | header = msg['header'] |
|
386 | header = msg['header'] | |
398 | md = msg['metadata'] |
|
387 | md = msg['metadata'] | |
399 | msg_id = header['msg_id'] |
|
388 | msg_id = header['msg_id'] | |
400 | self.all_ids.add(msg_id) |
|
389 | self.all_ids.add(msg_id) | |
401 |
|
390 | |||
402 | # get targets as a set of bytes objects |
|
391 | # get targets as a set of bytes objects | |
403 | # from a list of unicode objects |
|
392 | # from a list of unicode objects | |
404 | targets = md.get('targets', []) |
|
393 | targets = md.get('targets', []) | |
405 | targets = set(map(cast_bytes, targets)) |
|
394 | targets = set(map(cast_bytes, targets)) | |
406 |
|
395 | |||
407 | retries = md.get('retries', 0) |
|
396 | retries = md.get('retries', 0) | |
408 | self.retries[msg_id] = retries |
|
397 | self.retries[msg_id] = retries | |
409 |
|
398 | |||
410 | # time dependencies |
|
399 | # time dependencies | |
411 | after = md.get('after', None) |
|
400 | after = md.get('after', None) | |
412 | if after: |
|
401 | if after: | |
413 | after = Dependency(after) |
|
402 | after = Dependency(after) | |
414 | if after.all: |
|
403 | if after.all: | |
415 | if after.success: |
|
404 | if after.success: | |
416 | after = Dependency(after.difference(self.all_completed), |
|
405 | after = Dependency(after.difference(self.all_completed), | |
417 | success=after.success, |
|
406 | success=after.success, | |
418 | failure=after.failure, |
|
407 | failure=after.failure, | |
419 | all=after.all, |
|
408 | all=after.all, | |
420 | ) |
|
409 | ) | |
421 | if after.failure: |
|
410 | if after.failure: | |
422 | after = Dependency(after.difference(self.all_failed), |
|
411 | after = Dependency(after.difference(self.all_failed), | |
423 | success=after.success, |
|
412 | success=after.success, | |
424 | failure=after.failure, |
|
413 | failure=after.failure, | |
425 | all=after.all, |
|
414 | all=after.all, | |
426 | ) |
|
415 | ) | |
427 | if after.check(self.all_completed, self.all_failed): |
|
416 | if after.check(self.all_completed, self.all_failed): | |
428 | # recast as empty set, if `after` already met, |
|
417 | # recast as empty set, if `after` already met, | |
429 | # to prevent unnecessary set comparisons |
|
418 | # to prevent unnecessary set comparisons | |
430 | after = MET |
|
419 | after = MET | |
431 | else: |
|
420 | else: | |
432 | after = MET |
|
421 | after = MET | |
433 |
|
422 | |||
434 | # location dependencies |
|
423 | # location dependencies | |
435 | follow = Dependency(md.get('follow', [])) |
|
424 | follow = Dependency(md.get('follow', [])) | |
436 |
|
425 | |||
437 | timeout = md.get('timeout', None) |
|
426 | timeout = md.get('timeout', None) | |
438 | if timeout: |
|
427 | if timeout: | |
439 | timeout = float(timeout) |
|
428 | timeout = float(timeout) | |
440 |
|
429 | |||
441 | job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg, |
|
430 | job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg, | |
442 | header=header, targets=targets, after=after, follow=follow, |
|
431 | header=header, targets=targets, after=after, follow=follow, | |
443 | timeout=timeout, metadata=md, |
|
432 | timeout=timeout, metadata=md, | |
444 | ) |
|
433 | ) | |
445 | # validate and reduce dependencies: |
|
434 | # validate and reduce dependencies: | |
446 | for dep in after,follow: |
|
435 | for dep in after,follow: | |
447 | if not dep: # empty dependency |
|
436 | if not dep: # empty dependency | |
448 | continue |
|
437 | continue | |
449 | # check valid: |
|
438 | # check valid: | |
450 | if msg_id in dep or dep.difference(self.all_ids): |
|
439 | if msg_id in dep or dep.difference(self.all_ids): | |
451 | self.queue_map[msg_id] = job |
|
440 | self.queue_map[msg_id] = job | |
452 | return self.fail_unreachable(msg_id, error.InvalidDependency) |
|
441 | return self.fail_unreachable(msg_id, error.InvalidDependency) | |
453 | # check if unreachable: |
|
442 | # check if unreachable: | |
454 | if dep.unreachable(self.all_completed, self.all_failed): |
|
443 | if dep.unreachable(self.all_completed, self.all_failed): | |
455 | self.queue_map[msg_id] = job |
|
444 | self.queue_map[msg_id] = job | |
456 | return self.fail_unreachable(msg_id) |
|
445 | return self.fail_unreachable(msg_id) | |
457 |
|
446 | |||
458 | if after.check(self.all_completed, self.all_failed): |
|
447 | if after.check(self.all_completed, self.all_failed): | |
459 | # time deps already met, try to run |
|
448 | # time deps already met, try to run | |
460 | if not self.maybe_run(job): |
|
449 | if not self.maybe_run(job): | |
461 | # can't run yet |
|
450 | # can't run yet | |
462 | if msg_id not in self.all_failed: |
|
451 | if msg_id not in self.all_failed: | |
463 | # could have failed as unreachable |
|
452 | # could have failed as unreachable | |
464 | self.save_unmet(job) |
|
453 | self.save_unmet(job) | |
465 | else: |
|
454 | else: | |
466 | self.save_unmet(job) |
|
455 | self.save_unmet(job) | |
467 |
|
456 | |||
468 | def job_timeout(self, job, timeout_id): |
|
457 | def job_timeout(self, job, timeout_id): | |
469 | """callback for a job's timeout. |
|
458 | """callback for a job's timeout. | |
470 |
|
459 | |||
471 | The job may or may not have been run at this point. |
|
460 | The job may or may not have been run at this point. | |
472 | """ |
|
461 | """ | |
473 | if job.timeout_id != timeout_id: |
|
462 | if job.timeout_id != timeout_id: | |
474 | # not the most recent call |
|
463 | # not the most recent call | |
475 | return |
|
464 | return | |
476 | now = time.time() |
|
465 | now = time.time() | |
477 | if job.timeout >= (now + 1): |
|
466 | if job.timeout >= (now + 1): | |
478 | self.log.warn("task %s timeout fired prematurely: %s > %s", |
|
467 | self.log.warn("task %s timeout fired prematurely: %s > %s", | |
479 | job.msg_id, job.timeout, now |
|
468 | job.msg_id, job.timeout, now | |
480 | ) |
|
469 | ) | |
481 | if job.msg_id in self.queue_map: |
|
470 | if job.msg_id in self.queue_map: | |
482 | # still waiting, but ran out of time |
|
471 | # still waiting, but ran out of time | |
483 | self.log.info("task %r timed out", job.msg_id) |
|
472 | self.log.info("task %r timed out", job.msg_id) | |
484 | self.fail_unreachable(job.msg_id, error.TaskTimeout) |
|
473 | self.fail_unreachable(job.msg_id, error.TaskTimeout) | |
485 |
|
474 | |||
486 | def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): |
|
475 | def fail_unreachable(self, msg_id, why=error.ImpossibleDependency): | |
487 | """a task has become unreachable, send a reply with an ImpossibleDependency |
|
476 | """a task has become unreachable, send a reply with an ImpossibleDependency | |
488 | error.""" |
|
477 | error.""" | |
489 | if msg_id not in self.queue_map: |
|
478 | if msg_id not in self.queue_map: | |
490 | self.log.error("task %r already failed!", msg_id) |
|
479 | self.log.error("task %r already failed!", msg_id) | |
491 | return |
|
480 | return | |
492 | job = self.queue_map.pop(msg_id) |
|
481 | job = self.queue_map.pop(msg_id) | |
493 | # lazy-delete from the queue |
|
482 | # lazy-delete from the queue | |
494 | job.removed = True |
|
483 | job.removed = True | |
495 | for mid in job.dependents: |
|
484 | for mid in job.dependents: | |
496 | if mid in self.graph: |
|
485 | if mid in self.graph: | |
497 | self.graph[mid].remove(msg_id) |
|
486 | self.graph[mid].remove(msg_id) | |
498 |
|
487 | |||
499 | try: |
|
488 | try: | |
500 | raise why() |
|
489 | raise why() | |
501 | except: |
|
490 | except: | |
502 | content = error.wrap_exception() |
|
491 | content = error.wrap_exception() | |
503 | self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename']) |
|
492 | self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename']) | |
504 |
|
493 | |||
505 | self.all_done.add(msg_id) |
|
494 | self.all_done.add(msg_id) | |
506 | self.all_failed.add(msg_id) |
|
495 | self.all_failed.add(msg_id) | |
507 |
|
496 | |||
508 | msg = self.session.send(self.client_stream, 'apply_reply', content, |
|
497 | msg = self.session.send(self.client_stream, 'apply_reply', content, | |
509 | parent=job.header, ident=job.idents) |
|
498 | parent=job.header, ident=job.idents) | |
510 | self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents) |
|
499 | self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents) | |
511 |
|
500 | |||
512 | self.update_graph(msg_id, success=False) |
|
501 | self.update_graph(msg_id, success=False) | |
513 |
|
502 | |||
514 | def available_engines(self): |
|
503 | def available_engines(self): | |
515 | """return a list of available engine indices based on HWM""" |
|
504 | """return a list of available engine indices based on HWM""" | |
516 | if not self.hwm: |
|
505 | if not self.hwm: | |
517 | return list(range(len(self.targets))) |
|
506 | return list(range(len(self.targets))) | |
518 | available = [] |
|
507 | available = [] | |
519 | for idx in range(len(self.targets)): |
|
508 | for idx in range(len(self.targets)): | |
520 | if self.loads[idx] < self.hwm: |
|
509 | if self.loads[idx] < self.hwm: | |
521 | available.append(idx) |
|
510 | available.append(idx) | |
522 | return available |
|
511 | return available | |
523 |
|
512 | |||
524 | def maybe_run(self, job): |
|
513 | def maybe_run(self, job): | |
525 | """check location dependencies, and run if they are met.""" |
|
514 | """check location dependencies, and run if they are met.""" | |
526 | msg_id = job.msg_id |
|
515 | msg_id = job.msg_id | |
527 | self.log.debug("Attempting to assign task %s", msg_id) |
|
516 | self.log.debug("Attempting to assign task %s", msg_id) | |
528 | available = self.available_engines() |
|
517 | available = self.available_engines() | |
529 | if not available: |
|
518 | if not available: | |
530 | # no engines, definitely can't run |
|
519 | # no engines, definitely can't run | |
531 | return False |
|
520 | return False | |
532 |
|
521 | |||
533 | if job.follow or job.targets or job.blacklist or self.hwm: |
|
522 | if job.follow or job.targets or job.blacklist or self.hwm: | |
534 | # we need a can_run filter |
|
523 | # we need a can_run filter | |
535 | def can_run(idx): |
|
524 | def can_run(idx): | |
536 | # check hwm |
|
525 | # check hwm | |
537 | if self.hwm and self.loads[idx] == self.hwm: |
|
526 | if self.hwm and self.loads[idx] == self.hwm: | |
538 | return False |
|
527 | return False | |
539 | target = self.targets[idx] |
|
528 | target = self.targets[idx] | |
540 | # check blacklist |
|
529 | # check blacklist | |
541 | if target in job.blacklist: |
|
530 | if target in job.blacklist: | |
542 | return False |
|
531 | return False | |
543 | # check targets |
|
532 | # check targets | |
544 | if job.targets and target not in job.targets: |
|
533 | if job.targets and target not in job.targets: | |
545 | return False |
|
534 | return False | |
546 | # check follow |
|
535 | # check follow | |
547 | return job.follow.check(self.completed[target], self.failed[target]) |
|
536 | return job.follow.check(self.completed[target], self.failed[target]) | |
548 |
|
537 | |||
549 | indices = list(filter(can_run, available)) |
|
538 | indices = list(filter(can_run, available)) | |
550 |
|
539 | |||
551 | if not indices: |
|
540 | if not indices: | |
552 | # couldn't run |
|
541 | # couldn't run | |
553 | if job.follow.all: |
|
542 | if job.follow.all: | |
554 | # check follow for impossibility |
|
543 | # check follow for impossibility | |
555 | dests = set() |
|
544 | dests = set() | |
556 | relevant = set() |
|
545 | relevant = set() | |
557 | if job.follow.success: |
|
546 | if job.follow.success: | |
558 | relevant = self.all_completed |
|
547 | relevant = self.all_completed | |
559 | if job.follow.failure: |
|
548 | if job.follow.failure: | |
560 | relevant = relevant.union(self.all_failed) |
|
549 | relevant = relevant.union(self.all_failed) | |
561 | for m in job.follow.intersection(relevant): |
|
550 | for m in job.follow.intersection(relevant): | |
562 | dests.add(self.destinations[m]) |
|
551 | dests.add(self.destinations[m]) | |
563 | if len(dests) > 1: |
|
552 | if len(dests) > 1: | |
564 | self.queue_map[msg_id] = job |
|
553 | self.queue_map[msg_id] = job | |
565 | self.fail_unreachable(msg_id) |
|
554 | self.fail_unreachable(msg_id) | |
566 | return False |
|
555 | return False | |
567 | if job.targets: |
|
556 | if job.targets: | |
568 | # check blacklist+targets for impossibility |
|
557 | # check blacklist+targets for impossibility | |
569 | job.targets.difference_update(job.blacklist) |
|
558 | job.targets.difference_update(job.blacklist) | |
570 | if not job.targets or not job.targets.intersection(self.targets): |
|
559 | if not job.targets or not job.targets.intersection(self.targets): | |
571 | self.queue_map[msg_id] = job |
|
560 | self.queue_map[msg_id] = job | |
572 | self.fail_unreachable(msg_id) |
|
561 | self.fail_unreachable(msg_id) | |
573 | return False |
|
562 | return False | |
574 | return False |
|
563 | return False | |
575 | else: |
|
564 | else: | |
576 | indices = None |
|
565 | indices = None | |
577 |
|
566 | |||
578 | self.submit_task(job, indices) |
|
567 | self.submit_task(job, indices) | |
579 | return True |
|
568 | return True | |
580 |
|
569 | |||
581 | def save_unmet(self, job): |
|
570 | def save_unmet(self, job): | |
582 | """Save a message for later submission when its dependencies are met.""" |
|
571 | """Save a message for later submission when its dependencies are met.""" | |
583 | msg_id = job.msg_id |
|
572 | msg_id = job.msg_id | |
584 | self.log.debug("Adding task %s to the queue", msg_id) |
|
573 | self.log.debug("Adding task %s to the queue", msg_id) | |
585 | self.queue_map[msg_id] = job |
|
574 | self.queue_map[msg_id] = job | |
586 | self.queue.append(job) |
|
575 | self.queue.append(job) | |
587 | # track the ids in follow or after, but not those already finished |
|
576 | # track the ids in follow or after, but not those already finished | |
588 | for dep_id in job.after.union(job.follow).difference(self.all_done): |
|
577 | for dep_id in job.after.union(job.follow).difference(self.all_done): | |
589 | if dep_id not in self.graph: |
|
578 | if dep_id not in self.graph: | |
590 | self.graph[dep_id] = set() |
|
579 | self.graph[dep_id] = set() | |
591 | self.graph[dep_id].add(msg_id) |
|
580 | self.graph[dep_id].add(msg_id) | |
592 |
|
581 | |||
593 | # schedule timeout callback |
|
582 | # schedule timeout callback | |
594 | if job.timeout: |
|
583 | if job.timeout: | |
595 | timeout_id = job.timeout_id = job.timeout_id + 1 |
|
584 | timeout_id = job.timeout_id = job.timeout_id + 1 | |
596 | self.loop.add_timeout(time.time() + job.timeout, |
|
585 | self.loop.add_timeout(time.time() + job.timeout, | |
597 | lambda : self.job_timeout(job, timeout_id) |
|
586 | lambda : self.job_timeout(job, timeout_id) | |
598 | ) |
|
587 | ) | |
599 |
|
588 | |||
600 |
|
589 | |||
601 | def submit_task(self, job, indices=None): |
|
590 | def submit_task(self, job, indices=None): | |
602 | """Submit a task to any of a subset of our targets.""" |
|
591 | """Submit a task to any of a subset of our targets.""" | |
603 | if indices: |
|
592 | if indices: | |
604 | loads = [self.loads[i] for i in indices] |
|
593 | loads = [self.loads[i] for i in indices] | |
605 | else: |
|
594 | else: | |
606 | loads = self.loads |
|
595 | loads = self.loads | |
607 | idx = self.scheme(loads) |
|
596 | idx = self.scheme(loads) | |
608 | if indices: |
|
597 | if indices: | |
609 | idx = indices[idx] |
|
598 | idx = indices[idx] | |
610 | target = self.targets[idx] |
|
599 | target = self.targets[idx] | |
611 | # print (target, map(str, msg[:3])) |
|
600 | # print (target, map(str, msg[:3])) | |
612 | # send job to the engine |
|
601 | # send job to the engine | |
613 | self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) |
|
602 | self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False) | |
614 | self.engine_stream.send_multipart(job.raw_msg, copy=False) |
|
603 | self.engine_stream.send_multipart(job.raw_msg, copy=False) | |
615 | # update load |
|
604 | # update load | |
616 | self.add_job(idx) |
|
605 | self.add_job(idx) | |
617 | self.pending[target][job.msg_id] = job |
|
606 | self.pending[target][job.msg_id] = job | |
618 | # notify Hub |
|
607 | # notify Hub | |
619 | content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii')) |
|
608 | content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii')) | |
620 | self.session.send(self.mon_stream, 'task_destination', content=content, |
|
609 | self.session.send(self.mon_stream, 'task_destination', content=content, | |
621 | ident=[b'tracktask',self.ident]) |
|
610 | ident=[b'tracktask',self.ident]) | |
622 |
|
611 | |||
623 |
|
612 | |||
624 | #----------------------------------------------------------------------- |
|
613 | #----------------------------------------------------------------------- | |
625 | # Result Handling |
|
614 | # Result Handling | |
626 | #----------------------------------------------------------------------- |
|
615 | #----------------------------------------------------------------------- | |
627 |
|
616 | |||
628 |
|
617 | |||
629 | @util.log_errors |
|
618 | @util.log_errors | |
630 | def dispatch_result(self, raw_msg): |
|
619 | def dispatch_result(self, raw_msg): | |
631 | """dispatch method for result replies""" |
|
620 | """dispatch method for result replies""" | |
632 | try: |
|
621 | try: | |
633 | idents,msg = self.session.feed_identities(raw_msg, copy=False) |
|
622 | idents,msg = self.session.feed_identities(raw_msg, copy=False) | |
634 | msg = self.session.unserialize(msg, content=False, copy=False) |
|
623 | msg = self.session.unserialize(msg, content=False, copy=False) | |
635 | engine = idents[0] |
|
624 | engine = idents[0] | |
636 | try: |
|
625 | try: | |
637 | idx = self.targets.index(engine) |
|
626 | idx = self.targets.index(engine) | |
638 | except ValueError: |
|
627 | except ValueError: | |
639 | pass # skip load-update for dead engines |
|
628 | pass # skip load-update for dead engines | |
640 | else: |
|
629 | else: | |
641 | self.finish_job(idx) |
|
630 | self.finish_job(idx) | |
642 | except Exception: |
|
631 | except Exception: | |
643 | self.log.error("task::Invalid result: %r", raw_msg, exc_info=True) |
|
632 | self.log.error("task::Invalid result: %r", raw_msg, exc_info=True) | |
644 | return |
|
633 | return | |
645 |
|
634 | |||
646 | md = msg['metadata'] |
|
635 | md = msg['metadata'] | |
647 | parent = msg['parent_header'] |
|
636 | parent = msg['parent_header'] | |
648 | if md.get('dependencies_met', True): |
|
637 | if md.get('dependencies_met', True): | |
649 | success = (md['status'] == 'ok') |
|
638 | success = (md['status'] == 'ok') | |
650 | msg_id = parent['msg_id'] |
|
639 | msg_id = parent['msg_id'] | |
651 | retries = self.retries[msg_id] |
|
640 | retries = self.retries[msg_id] | |
652 | if not success and retries > 0: |
|
641 | if not success and retries > 0: | |
653 | # failed |
|
642 | # failed | |
654 | self.retries[msg_id] = retries - 1 |
|
643 | self.retries[msg_id] = retries - 1 | |
655 | self.handle_unmet_dependency(idents, parent) |
|
644 | self.handle_unmet_dependency(idents, parent) | |
656 | else: |
|
645 | else: | |
657 | del self.retries[msg_id] |
|
646 | del self.retries[msg_id] | |
658 | # relay to client and update graph |
|
647 | # relay to client and update graph | |
659 | self.handle_result(idents, parent, raw_msg, success) |
|
648 | self.handle_result(idents, parent, raw_msg, success) | |
660 | # send to Hub monitor |
|
649 | # send to Hub monitor | |
661 | self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False) |
|
650 | self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False) | |
662 | else: |
|
651 | else: | |
663 | self.handle_unmet_dependency(idents, parent) |
|
652 | self.handle_unmet_dependency(idents, parent) | |
664 |
|
653 | |||
665 | def handle_result(self, idents, parent, raw_msg, success=True): |
|
654 | def handle_result(self, idents, parent, raw_msg, success=True): | |
666 | """handle a real task result, either success or failure""" |
|
655 | """handle a real task result, either success or failure""" | |
667 | # first, relay result to client |
|
656 | # first, relay result to client | |
668 | engine = idents[0] |
|
657 | engine = idents[0] | |
669 | client = idents[1] |
|
658 | client = idents[1] | |
670 | # swap_ids for ROUTER-ROUTER mirror |
|
659 | # swap_ids for ROUTER-ROUTER mirror | |
671 | raw_msg[:2] = [client,engine] |
|
660 | raw_msg[:2] = [client,engine] | |
672 | # print (map(str, raw_msg[:4])) |
|
661 | # print (map(str, raw_msg[:4])) | |
673 | self.client_stream.send_multipart(raw_msg, copy=False) |
|
662 | self.client_stream.send_multipart(raw_msg, copy=False) | |
674 | # now, update our data structures |
|
663 | # now, update our data structures | |
675 | msg_id = parent['msg_id'] |
|
664 | msg_id = parent['msg_id'] | |
676 | self.pending[engine].pop(msg_id) |
|
665 | self.pending[engine].pop(msg_id) | |
677 | if success: |
|
666 | if success: | |
678 | self.completed[engine].add(msg_id) |
|
667 | self.completed[engine].add(msg_id) | |
679 | self.all_completed.add(msg_id) |
|
668 | self.all_completed.add(msg_id) | |
680 | else: |
|
669 | else: | |
681 | self.failed[engine].add(msg_id) |
|
670 | self.failed[engine].add(msg_id) | |
682 | self.all_failed.add(msg_id) |
|
671 | self.all_failed.add(msg_id) | |
683 | self.all_done.add(msg_id) |
|
672 | self.all_done.add(msg_id) | |
684 | self.destinations[msg_id] = engine |
|
673 | self.destinations[msg_id] = engine | |
685 |
|
674 | |||
686 | self.update_graph(msg_id, success) |
|
675 | self.update_graph(msg_id, success) | |
687 |
|
676 | |||
688 | def handle_unmet_dependency(self, idents, parent): |
|
677 | def handle_unmet_dependency(self, idents, parent): | |
689 | """handle an unmet dependency""" |
|
678 | """handle an unmet dependency""" | |
690 | engine = idents[0] |
|
679 | engine = idents[0] | |
691 | msg_id = parent['msg_id'] |
|
680 | msg_id = parent['msg_id'] | |
692 |
|
681 | |||
693 | job = self.pending[engine].pop(msg_id) |
|
682 | job = self.pending[engine].pop(msg_id) | |
694 | job.blacklist.add(engine) |
|
683 | job.blacklist.add(engine) | |
695 |
|
684 | |||
696 | if job.blacklist == job.targets: |
|
685 | if job.blacklist == job.targets: | |
697 | self.queue_map[msg_id] = job |
|
686 | self.queue_map[msg_id] = job | |
698 | self.fail_unreachable(msg_id) |
|
687 | self.fail_unreachable(msg_id) | |
699 | elif not self.maybe_run(job): |
|
688 | elif not self.maybe_run(job): | |
700 | # resubmit failed |
|
689 | # resubmit failed | |
701 | if msg_id not in self.all_failed: |
|
690 | if msg_id not in self.all_failed: | |
702 | # put it back in our dependency tree |
|
691 | # put it back in our dependency tree | |
703 | self.save_unmet(job) |
|
692 | self.save_unmet(job) | |
704 |
|
693 | |||
705 | if self.hwm: |
|
694 | if self.hwm: | |
706 | try: |
|
695 | try: | |
707 | idx = self.targets.index(engine) |
|
696 | idx = self.targets.index(engine) | |
708 | except ValueError: |
|
697 | except ValueError: | |
709 | pass # skip load-update for dead engines |
|
698 | pass # skip load-update for dead engines | |
710 | else: |
|
699 | else: | |
711 | if self.loads[idx] == self.hwm-1: |
|
700 | if self.loads[idx] == self.hwm-1: | |
712 | self.update_graph(None) |
|
701 | self.update_graph(None) | |
713 |
|
702 | |||
714 | def update_graph(self, dep_id=None, success=True): |
|
703 | def update_graph(self, dep_id=None, success=True): | |
715 | """dep_id just finished. Update our dependency |
|
704 | """dep_id just finished. Update our dependency | |
716 | graph and submit any jobs that just became runnable. |
|
705 | graph and submit any jobs that just became runnable. | |
717 |
|
706 | |||
718 | Called with dep_id=None to update entire graph for hwm, but without finishing a task. |
|
707 | Called with dep_id=None to update entire graph for hwm, but without finishing a task. | |
719 | """ |
|
708 | """ | |
720 | # print ("\n\n***********") |
|
709 | # print ("\n\n***********") | |
721 | # pprint (dep_id) |
|
710 | # pprint (dep_id) | |
722 | # pprint (self.graph) |
|
711 | # pprint (self.graph) | |
723 | # pprint (self.queue_map) |
|
712 | # pprint (self.queue_map) | |
724 | # pprint (self.all_completed) |
|
713 | # pprint (self.all_completed) | |
725 | # pprint (self.all_failed) |
|
714 | # pprint (self.all_failed) | |
726 | # print ("\n\n***********\n\n") |
|
715 | # print ("\n\n***********\n\n") | |
727 | # update any jobs that depended on the dependency |
|
716 | # update any jobs that depended on the dependency | |
728 | msg_ids = self.graph.pop(dep_id, []) |
|
717 | msg_ids = self.graph.pop(dep_id, []) | |
729 |
|
718 | |||
730 | # recheck *all* jobs if |
|
719 | # recheck *all* jobs if | |
731 | # a) we have HWM and an engine just become no longer full |
|
720 | # a) we have HWM and an engine just become no longer full | |
732 | # or b) dep_id was given as None |
|
721 | # or b) dep_id was given as None | |
733 |
|
722 | |||
734 | if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]): |
|
723 | if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]): | |
735 | jobs = self.queue |
|
724 | jobs = self.queue | |
736 | using_queue = True |
|
725 | using_queue = True | |
737 | else: |
|
726 | else: | |
738 | using_queue = False |
|
727 | using_queue = False | |
739 | jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids )) |
|
728 | jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids )) | |
740 |
|
729 | |||
741 | to_restore = [] |
|
730 | to_restore = [] | |
742 | while jobs: |
|
731 | while jobs: | |
743 | job = jobs.popleft() |
|
732 | job = jobs.popleft() | |
744 | if job.removed: |
|
733 | if job.removed: | |
745 | continue |
|
734 | continue | |
746 | msg_id = job.msg_id |
|
735 | msg_id = job.msg_id | |
747 |
|
736 | |||
748 | put_it_back = True |
|
737 | put_it_back = True | |
749 |
|
738 | |||
750 | if job.after.unreachable(self.all_completed, self.all_failed)\ |
|
739 | if job.after.unreachable(self.all_completed, self.all_failed)\ | |
751 | or job.follow.unreachable(self.all_completed, self.all_failed): |
|
740 | or job.follow.unreachable(self.all_completed, self.all_failed): | |
752 | self.fail_unreachable(msg_id) |
|
741 | self.fail_unreachable(msg_id) | |
753 | put_it_back = False |
|
742 | put_it_back = False | |
754 |
|
743 | |||
755 | elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run |
|
744 | elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run | |
756 | if self.maybe_run(job): |
|
745 | if self.maybe_run(job): | |
757 | put_it_back = False |
|
746 | put_it_back = False | |
758 | self.queue_map.pop(msg_id) |
|
747 | self.queue_map.pop(msg_id) | |
759 | for mid in job.dependents: |
|
748 | for mid in job.dependents: | |
760 | if mid in self.graph: |
|
749 | if mid in self.graph: | |
761 | self.graph[mid].remove(msg_id) |
|
750 | self.graph[mid].remove(msg_id) | |
762 |
|
751 | |||
763 | # abort the loop if we just filled up all of our engines. |
|
752 | # abort the loop if we just filled up all of our engines. | |
764 | # avoids an O(N) operation in situation of full queue, |
|
753 | # avoids an O(N) operation in situation of full queue, | |
765 | # where graph update is triggered as soon as an engine becomes |
|
754 | # where graph update is triggered as soon as an engine becomes | |
766 | # non-full, and all tasks after the first are checked, |
|
755 | # non-full, and all tasks after the first are checked, | |
767 | # even though they can't run. |
|
756 | # even though they can't run. | |
768 | if not self.available_engines(): |
|
757 | if not self.available_engines(): | |
769 | break |
|
758 | break | |
770 |
|
759 | |||
771 | if using_queue and put_it_back: |
|
760 | if using_queue and put_it_back: | |
772 | # popped a job from the queue but it neither ran nor failed, |
|
761 | # popped a job from the queue but it neither ran nor failed, | |
773 | # so we need to put it back when we are done |
|
762 | # so we need to put it back when we are done | |
774 | # make sure to_restore preserves the same ordering |
|
763 | # make sure to_restore preserves the same ordering | |
775 | to_restore.append(job) |
|
764 | to_restore.append(job) | |
776 |
|
765 | |||
777 | # put back any tasks we popped but didn't run |
|
766 | # put back any tasks we popped but didn't run | |
778 | if using_queue: |
|
767 | if using_queue: | |
779 | self.queue.extendleft(to_restore) |
|
768 | self.queue.extendleft(to_restore) | |
780 |
|
769 | |||
781 | #---------------------------------------------------------------------- |
|
770 | #---------------------------------------------------------------------- | |
782 | # methods to be overridden by subclasses |
|
771 | # methods to be overridden by subclasses | |
783 | #---------------------------------------------------------------------- |
|
772 | #---------------------------------------------------------------------- | |
784 |
|
773 | |||
785 | def add_job(self, idx): |
|
774 | def add_job(self, idx): | |
786 | """Called after self.targets[idx] just got the job with header. |
|
775 | """Called after self.targets[idx] just got the job with header. | |
787 | Override with subclasses. The default ordering is simple LRU. |
|
776 | Override with subclasses. The default ordering is simple LRU. | |
788 | The default loads are the number of outstanding jobs.""" |
|
777 | The default loads are the number of outstanding jobs.""" | |
789 | self.loads[idx] += 1 |
|
778 | self.loads[idx] += 1 | |
790 | for lis in (self.targets, self.loads): |
|
779 | for lis in (self.targets, self.loads): | |
791 | lis.append(lis.pop(idx)) |
|
780 | lis.append(lis.pop(idx)) | |
792 |
|
781 | |||
793 |
|
782 | |||
794 | def finish_job(self, idx): |
|
783 | def finish_job(self, idx): | |
795 | """Called after self.targets[idx] just finished a job. |
|
784 | """Called after self.targets[idx] just finished a job. | |
796 | Override with subclasses.""" |
|
785 | Override with subclasses.""" | |
797 | self.loads[idx] -= 1 |
|
786 | self.loads[idx] -= 1 | |
798 |
|
787 | |||
799 |
|
788 | |||
800 |
|
789 | |||
801 | def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None, |
|
790 | def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None, | |
802 | logname='root', log_url=None, loglevel=logging.DEBUG, |
|
791 | logname='root', log_url=None, loglevel=logging.DEBUG, | |
803 | identity=b'task', in_thread=False): |
|
792 | identity=b'task', in_thread=False): | |
804 |
|
793 | |||
805 | ZMQStream = zmqstream.ZMQStream |
|
794 | ZMQStream = zmqstream.ZMQStream | |
806 |
|
795 | |||
807 | if config: |
|
796 | if config: | |
808 | # unwrap dict back into Config |
|
797 | # unwrap dict back into Config | |
809 | config = Config(config) |
|
798 | config = Config(config) | |
810 |
|
799 | |||
811 | if in_thread: |
|
800 | if in_thread: | |
812 | # use instance() to get the same Context/Loop as our parent |
|
801 | # use instance() to get the same Context/Loop as our parent | |
813 | ctx = zmq.Context.instance() |
|
802 | ctx = zmq.Context.instance() | |
814 | loop = ioloop.IOLoop.instance() |
|
803 | loop = ioloop.IOLoop.instance() | |
815 | else: |
|
804 | else: | |
816 | # in a process, don't use instance() |
|
805 | # in a process, don't use instance() | |
817 | # for safety with multiprocessing |
|
806 | # for safety with multiprocessing | |
818 | ctx = zmq.Context() |
|
807 | ctx = zmq.Context() | |
819 | loop = ioloop.IOLoop() |
|
808 | loop = ioloop.IOLoop() | |
820 | ins = ZMQStream(ctx.socket(zmq.ROUTER),loop) |
|
809 | ins = ZMQStream(ctx.socket(zmq.ROUTER),loop) | |
821 | util.set_hwm(ins, 0) |
|
810 | util.set_hwm(ins, 0) | |
822 | ins.setsockopt(zmq.IDENTITY, identity + b'_in') |
|
811 | ins.setsockopt(zmq.IDENTITY, identity + b'_in') | |
823 | ins.bind(in_addr) |
|
812 | ins.bind(in_addr) | |
824 |
|
813 | |||
825 | outs = ZMQStream(ctx.socket(zmq.ROUTER),loop) |
|
814 | outs = ZMQStream(ctx.socket(zmq.ROUTER),loop) | |
826 | util.set_hwm(outs, 0) |
|
815 | util.set_hwm(outs, 0) | |
827 | outs.setsockopt(zmq.IDENTITY, identity + b'_out') |
|
816 | outs.setsockopt(zmq.IDENTITY, identity + b'_out') | |
828 | outs.bind(out_addr) |
|
817 | outs.bind(out_addr) | |
829 | mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop) |
|
818 | mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop) | |
830 | util.set_hwm(mons, 0) |
|
819 | util.set_hwm(mons, 0) | |
831 | mons.connect(mon_addr) |
|
820 | mons.connect(mon_addr) | |
832 | nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop) |
|
821 | nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop) | |
833 | nots.setsockopt(zmq.SUBSCRIBE, b'') |
|
822 | nots.setsockopt(zmq.SUBSCRIBE, b'') | |
834 | nots.connect(not_addr) |
|
823 | nots.connect(not_addr) | |
835 |
|
824 | |||
836 | querys = ZMQStream(ctx.socket(zmq.DEALER),loop) |
|
825 | querys = ZMQStream(ctx.socket(zmq.DEALER),loop) | |
837 | querys.connect(reg_addr) |
|
826 | querys.connect(reg_addr) | |
838 |
|
827 | |||
839 | # setup logging. |
|
828 | # setup logging. | |
840 | if in_thread: |
|
829 | if in_thread: | |
841 | log = Application.instance().log |
|
830 | log = Application.instance().log | |
842 | else: |
|
831 | else: | |
843 | if log_url: |
|
832 | if log_url: | |
844 | log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel) |
|
833 | log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel) | |
845 | else: |
|
834 | else: | |
846 | log = local_logger(logname, loglevel) |
|
835 | log = local_logger(logname, loglevel) | |
847 |
|
836 | |||
848 | scheduler = TaskScheduler(client_stream=ins, engine_stream=outs, |
|
837 | scheduler = TaskScheduler(client_stream=ins, engine_stream=outs, | |
849 | mon_stream=mons, notifier_stream=nots, |
|
838 | mon_stream=mons, notifier_stream=nots, | |
850 | query_stream=querys, |
|
839 | query_stream=querys, | |
851 | loop=loop, log=log, |
|
840 | loop=loop, log=log, | |
852 | config=config) |
|
841 | config=config) | |
853 | scheduler.start() |
|
842 | scheduler.start() | |
854 | if not in_thread: |
|
843 | if not in_thread: | |
855 | try: |
|
844 | try: | |
856 | loop.start() |
|
845 | loop.start() | |
857 | except KeyboardInterrupt: |
|
846 | except KeyboardInterrupt: | |
858 | scheduler.log.critical("Interrupted, exiting...") |
|
847 | scheduler.log.critical("Interrupted, exiting...") | |
859 |
|
848 |
@@ -1,388 +1,389 b'' | |||||
1 | """Some generic utilities for dealing with classes, urls, and serialization.""" |
|
1 | """Some generic utilities for dealing with classes, urls, and serialization.""" | |
2 |
|
2 | |||
3 | # Copyright (c) IPython Development Team. |
|
3 | # Copyright (c) IPython Development Team. | |
4 | # Distributed under the terms of the Modified BSD License. |
|
4 | # Distributed under the terms of the Modified BSD License. | |
5 |
|
5 | |||
6 | import logging |
|
6 | import logging | |
7 | import os |
|
7 | import os | |
8 | import re |
|
8 | import re | |
9 | import stat |
|
9 | import stat | |
10 | import socket |
|
10 | import socket | |
11 | import sys |
|
11 | import sys | |
12 | import warnings |
|
12 | import warnings | |
13 | from signal import signal, SIGINT, SIGABRT, SIGTERM |
|
13 | from signal import signal, SIGINT, SIGABRT, SIGTERM | |
14 | try: |
|
14 | try: | |
15 | from signal import SIGKILL |
|
15 | from signal import SIGKILL | |
16 | except ImportError: |
|
16 | except ImportError: | |
17 | SIGKILL=None |
|
17 | SIGKILL=None | |
18 | from types import FunctionType |
|
18 | from types import FunctionType | |
19 |
|
19 | |||
20 | try: |
|
20 | try: | |
21 | import cPickle |
|
21 | import cPickle | |
22 | pickle = cPickle |
|
22 | pickle = cPickle | |
23 | except: |
|
23 | except: | |
24 | cPickle = None |
|
24 | cPickle = None | |
25 | import pickle |
|
25 | import pickle | |
26 |
|
26 | |||
27 | import zmq |
|
27 | import zmq | |
28 | from zmq.log import handlers |
|
28 | from zmq.log import handlers | |
29 |
|
29 | |||
|
30 | from IPython.utils.log import get_logger | |||
30 | from IPython.external.decorator import decorator |
|
31 | from IPython.external.decorator import decorator | |
31 |
|
32 | |||
32 | from IPython.config.application import Application |
|
33 | from IPython.config.application import Application | |
33 | from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips |
|
34 | from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips | |
34 | from IPython.utils.py3compat import string_types, iteritems, itervalues |
|
35 | from IPython.utils.py3compat import string_types, iteritems, itervalues | |
35 | from IPython.kernel.zmq.log import EnginePUBHandler |
|
36 | from IPython.kernel.zmq.log import EnginePUBHandler | |
36 |
|
37 | |||
37 |
|
38 | |||
38 | #----------------------------------------------------------------------------- |
|
39 | #----------------------------------------------------------------------------- | |
39 | # Classes |
|
40 | # Classes | |
40 | #----------------------------------------------------------------------------- |
|
41 | #----------------------------------------------------------------------------- | |
41 |
|
42 | |||
42 | class Namespace(dict): |
|
43 | class Namespace(dict): | |
43 | """Subclass of dict for attribute access to keys.""" |
|
44 | """Subclass of dict for attribute access to keys.""" | |
44 |
|
45 | |||
45 | def __getattr__(self, key): |
|
46 | def __getattr__(self, key): | |
46 | """getattr aliased to getitem""" |
|
47 | """getattr aliased to getitem""" | |
47 | if key in self: |
|
48 | if key in self: | |
48 | return self[key] |
|
49 | return self[key] | |
49 | else: |
|
50 | else: | |
50 | raise NameError(key) |
|
51 | raise NameError(key) | |
51 |
|
52 | |||
52 | def __setattr__(self, key, value): |
|
53 | def __setattr__(self, key, value): | |
53 | """setattr aliased to setitem, with strict""" |
|
54 | """setattr aliased to setitem, with strict""" | |
54 | if hasattr(dict, key): |
|
55 | if hasattr(dict, key): | |
55 | raise KeyError("Cannot override dict keys %r"%key) |
|
56 | raise KeyError("Cannot override dict keys %r"%key) | |
56 | self[key] = value |
|
57 | self[key] = value | |
57 |
|
58 | |||
58 |
|
59 | |||
59 | class ReverseDict(dict): |
|
60 | class ReverseDict(dict): | |
60 | """simple double-keyed subset of dict methods.""" |
|
61 | """simple double-keyed subset of dict methods.""" | |
61 |
|
62 | |||
62 | def __init__(self, *args, **kwargs): |
|
63 | def __init__(self, *args, **kwargs): | |
63 | dict.__init__(self, *args, **kwargs) |
|
64 | dict.__init__(self, *args, **kwargs) | |
64 | self._reverse = dict() |
|
65 | self._reverse = dict() | |
65 | for key, value in iteritems(self): |
|
66 | for key, value in iteritems(self): | |
66 | self._reverse[value] = key |
|
67 | self._reverse[value] = key | |
67 |
|
68 | |||
68 | def __getitem__(self, key): |
|
69 | def __getitem__(self, key): | |
69 | try: |
|
70 | try: | |
70 | return dict.__getitem__(self, key) |
|
71 | return dict.__getitem__(self, key) | |
71 | except KeyError: |
|
72 | except KeyError: | |
72 | return self._reverse[key] |
|
73 | return self._reverse[key] | |
73 |
|
74 | |||
74 | def __setitem__(self, key, value): |
|
75 | def __setitem__(self, key, value): | |
75 | if key in self._reverse: |
|
76 | if key in self._reverse: | |
76 | raise KeyError("Can't have key %r on both sides!"%key) |
|
77 | raise KeyError("Can't have key %r on both sides!"%key) | |
77 | dict.__setitem__(self, key, value) |
|
78 | dict.__setitem__(self, key, value) | |
78 | self._reverse[value] = key |
|
79 | self._reverse[value] = key | |
79 |
|
80 | |||
80 | def pop(self, key): |
|
81 | def pop(self, key): | |
81 | value = dict.pop(self, key) |
|
82 | value = dict.pop(self, key) | |
82 | self._reverse.pop(value) |
|
83 | self._reverse.pop(value) | |
83 | return value |
|
84 | return value | |
84 |
|
85 | |||
85 | def get(self, key, default=None): |
|
86 | def get(self, key, default=None): | |
86 | try: |
|
87 | try: | |
87 | return self[key] |
|
88 | return self[key] | |
88 | except KeyError: |
|
89 | except KeyError: | |
89 | return default |
|
90 | return default | |
90 |
|
91 | |||
91 | #----------------------------------------------------------------------------- |
|
92 | #----------------------------------------------------------------------------- | |
92 | # Functions |
|
93 | # Functions | |
93 | #----------------------------------------------------------------------------- |
|
94 | #----------------------------------------------------------------------------- | |
94 |
|
95 | |||
95 | @decorator |
|
96 | @decorator | |
96 | def log_errors(f, self, *args, **kwargs): |
|
97 | def log_errors(f, self, *args, **kwargs): | |
97 | """decorator to log unhandled exceptions raised in a method. |
|
98 | """decorator to log unhandled exceptions raised in a method. | |
98 |
|
99 | |||
99 | For use wrapping on_recv callbacks, so that exceptions |
|
100 | For use wrapping on_recv callbacks, so that exceptions | |
100 | do not cause the stream to be closed. |
|
101 | do not cause the stream to be closed. | |
101 | """ |
|
102 | """ | |
102 | try: |
|
103 | try: | |
103 | return f(self, *args, **kwargs) |
|
104 | return f(self, *args, **kwargs) | |
104 | except Exception: |
|
105 | except Exception: | |
105 | self.log.error("Uncaught exception in %r" % f, exc_info=True) |
|
106 | self.log.error("Uncaught exception in %r" % f, exc_info=True) | |
106 |
|
107 | |||
107 |
|
108 | |||
108 | def is_url(url): |
|
109 | def is_url(url): | |
109 | """boolean check for whether a string is a zmq url""" |
|
110 | """boolean check for whether a string is a zmq url""" | |
110 | if '://' not in url: |
|
111 | if '://' not in url: | |
111 | return False |
|
112 | return False | |
112 | proto, addr = url.split('://', 1) |
|
113 | proto, addr = url.split('://', 1) | |
113 | if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']: |
|
114 | if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']: | |
114 | return False |
|
115 | return False | |
115 | return True |
|
116 | return True | |
116 |
|
117 | |||
117 | def validate_url(url): |
|
118 | def validate_url(url): | |
118 | """validate a url for zeromq""" |
|
119 | """validate a url for zeromq""" | |
119 | if not isinstance(url, string_types): |
|
120 | if not isinstance(url, string_types): | |
120 | raise TypeError("url must be a string, not %r"%type(url)) |
|
121 | raise TypeError("url must be a string, not %r"%type(url)) | |
121 | url = url.lower() |
|
122 | url = url.lower() | |
122 |
|
123 | |||
123 | proto_addr = url.split('://') |
|
124 | proto_addr = url.split('://') | |
124 | assert len(proto_addr) == 2, 'Invalid url: %r'%url |
|
125 | assert len(proto_addr) == 2, 'Invalid url: %r'%url | |
125 | proto, addr = proto_addr |
|
126 | proto, addr = proto_addr | |
126 | assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto |
|
127 | assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto | |
127 |
|
128 | |||
128 | # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391 |
|
129 | # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391 | |
129 | # author: Remi Sabourin |
|
130 | # author: Remi Sabourin | |
130 | pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$') |
|
131 | pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$') | |
131 |
|
132 | |||
132 | if proto == 'tcp': |
|
133 | if proto == 'tcp': | |
133 | lis = addr.split(':') |
|
134 | lis = addr.split(':') | |
134 | assert len(lis) == 2, 'Invalid url: %r'%url |
|
135 | assert len(lis) == 2, 'Invalid url: %r'%url | |
135 | addr,s_port = lis |
|
136 | addr,s_port = lis | |
136 | try: |
|
137 | try: | |
137 | port = int(s_port) |
|
138 | port = int(s_port) | |
138 | except ValueError: |
|
139 | except ValueError: | |
139 | raise AssertionError("Invalid port %r in url: %r"%(port, url)) |
|
140 | raise AssertionError("Invalid port %r in url: %r"%(port, url)) | |
140 |
|
141 | |||
141 | assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url |
|
142 | assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url | |
142 |
|
143 | |||
143 | else: |
|
144 | else: | |
144 | # only validate tcp urls currently |
|
145 | # only validate tcp urls currently | |
145 | pass |
|
146 | pass | |
146 |
|
147 | |||
147 | return True |
|
148 | return True | |
148 |
|
149 | |||
149 |
|
150 | |||
150 | def validate_url_container(container): |
|
151 | def validate_url_container(container): | |
151 | """validate a potentially nested collection of urls.""" |
|
152 | """validate a potentially nested collection of urls.""" | |
152 | if isinstance(container, string_types): |
|
153 | if isinstance(container, string_types): | |
153 | url = container |
|
154 | url = container | |
154 | return validate_url(url) |
|
155 | return validate_url(url) | |
155 | elif isinstance(container, dict): |
|
156 | elif isinstance(container, dict): | |
156 | container = itervalues(container) |
|
157 | container = itervalues(container) | |
157 |
|
158 | |||
158 | for element in container: |
|
159 | for element in container: | |
159 | validate_url_container(element) |
|
160 | validate_url_container(element) | |
160 |
|
161 | |||
161 |
|
162 | |||
162 | def split_url(url): |
|
163 | def split_url(url): | |
163 | """split a zmq url (tcp://ip:port) into ('tcp','ip','port').""" |
|
164 | """split a zmq url (tcp://ip:port) into ('tcp','ip','port').""" | |
164 | proto_addr = url.split('://') |
|
165 | proto_addr = url.split('://') | |
165 | assert len(proto_addr) == 2, 'Invalid url: %r'%url |
|
166 | assert len(proto_addr) == 2, 'Invalid url: %r'%url | |
166 | proto, addr = proto_addr |
|
167 | proto, addr = proto_addr | |
167 | lis = addr.split(':') |
|
168 | lis = addr.split(':') | |
168 | assert len(lis) == 2, 'Invalid url: %r'%url |
|
169 | assert len(lis) == 2, 'Invalid url: %r'%url | |
169 | addr,s_port = lis |
|
170 | addr,s_port = lis | |
170 | return proto,addr,s_port |
|
171 | return proto,addr,s_port | |
171 |
|
172 | |||
172 |
|
173 | |||
173 | def disambiguate_ip_address(ip, location=None): |
|
174 | def disambiguate_ip_address(ip, location=None): | |
174 | """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address |
|
175 | """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address | |
175 |
|
176 | |||
176 | Explicit IP addresses are returned unmodified. |
|
177 | Explicit IP addresses are returned unmodified. | |
177 |
|
178 | |||
178 | Parameters |
|
179 | Parameters | |
179 | ---------- |
|
180 | ---------- | |
180 |
|
181 | |||
181 | ip : IP address |
|
182 | ip : IP address | |
182 | An IP address, or the special values 0.0.0.0, or * |
|
183 | An IP address, or the special values 0.0.0.0, or * | |
183 | location: IP address, optional |
|
184 | location: IP address, optional | |
184 | A public IP of the target machine. |
|
185 | A public IP of the target machine. | |
185 | If location is an IP of the current machine, |
|
186 | If location is an IP of the current machine, | |
186 | localhost will be returned, |
|
187 | localhost will be returned, | |
187 | otherwise location will be returned. |
|
188 | otherwise location will be returned. | |
188 | """ |
|
189 | """ | |
189 | if ip in {'0.0.0.0', '*'}: |
|
190 | if ip in {'0.0.0.0', '*'}: | |
190 | if not location: |
|
191 | if not location: | |
191 | # unspecified location, localhost is the only choice |
|
192 | # unspecified location, localhost is the only choice | |
192 | ip = localhost() |
|
193 | ip = localhost() | |
193 | elif is_public_ip(location): |
|
194 | elif is_public_ip(location): | |
194 | # location is a public IP on this machine, use localhost |
|
195 | # location is a public IP on this machine, use localhost | |
195 | ip = localhost() |
|
196 | ip = localhost() | |
196 | elif not public_ips(): |
|
197 | elif not public_ips(): | |
197 | # this machine's public IPs cannot be determined, |
|
198 | # this machine's public IPs cannot be determined, | |
198 | # assume `location` is not this machine |
|
199 | # assume `location` is not this machine | |
199 | warnings.warn("IPython could not determine public IPs", RuntimeWarning) |
|
200 | warnings.warn("IPython could not determine public IPs", RuntimeWarning) | |
200 | ip = location |
|
201 | ip = location | |
201 | else: |
|
202 | else: | |
202 | # location is not this machine, do not use loopback |
|
203 | # location is not this machine, do not use loopback | |
203 | ip = location |
|
204 | ip = location | |
204 | return ip |
|
205 | return ip | |
205 |
|
206 | |||
206 |
|
207 | |||
207 | def disambiguate_url(url, location=None): |
|
208 | def disambiguate_url(url, location=None): | |
208 | """turn multi-ip interfaces '0.0.0.0' and '*' into connectable |
|
209 | """turn multi-ip interfaces '0.0.0.0' and '*' into connectable | |
209 | ones, based on the location (default interpretation is localhost). |
|
210 | ones, based on the location (default interpretation is localhost). | |
210 |
|
211 | |||
211 | This is for zeromq urls, such as ``tcp://*:10101``. |
|
212 | This is for zeromq urls, such as ``tcp://*:10101``. | |
212 | """ |
|
213 | """ | |
213 | try: |
|
214 | try: | |
214 | proto,ip,port = split_url(url) |
|
215 | proto,ip,port = split_url(url) | |
215 | except AssertionError: |
|
216 | except AssertionError: | |
216 | # probably not tcp url; could be ipc, etc. |
|
217 | # probably not tcp url; could be ipc, etc. | |
217 | return url |
|
218 | return url | |
218 |
|
219 | |||
219 | ip = disambiguate_ip_address(ip,location) |
|
220 | ip = disambiguate_ip_address(ip,location) | |
220 |
|
221 | |||
221 | return "%s://%s:%s"%(proto,ip,port) |
|
222 | return "%s://%s:%s"%(proto,ip,port) | |
222 |
|
223 | |||
223 |
|
224 | |||
224 | #-------------------------------------------------------------------------- |
|
225 | #-------------------------------------------------------------------------- | |
225 | # helpers for implementing old MEC API via view.apply |
|
226 | # helpers for implementing old MEC API via view.apply | |
226 | #-------------------------------------------------------------------------- |
|
227 | #-------------------------------------------------------------------------- | |
227 |
|
228 | |||
228 | def interactive(f): |
|
229 | def interactive(f): | |
229 | """decorator for making functions appear as interactively defined. |
|
230 | """decorator for making functions appear as interactively defined. | |
230 | This results in the function being linked to the user_ns as globals() |
|
231 | This results in the function being linked to the user_ns as globals() | |
231 | instead of the module globals(). |
|
232 | instead of the module globals(). | |
232 | """ |
|
233 | """ | |
233 |
|
234 | |||
234 | # build new FunctionType, so it can have the right globals |
|
235 | # build new FunctionType, so it can have the right globals | |
235 | # interactive functions never have closures, that's kind of the point |
|
236 | # interactive functions never have closures, that's kind of the point | |
236 | if isinstance(f, FunctionType): |
|
237 | if isinstance(f, FunctionType): | |
237 | mainmod = __import__('__main__') |
|
238 | mainmod = __import__('__main__') | |
238 | f = FunctionType(f.__code__, mainmod.__dict__, |
|
239 | f = FunctionType(f.__code__, mainmod.__dict__, | |
239 | f.__name__, f.__defaults__, |
|
240 | f.__name__, f.__defaults__, | |
240 | ) |
|
241 | ) | |
241 | # associate with __main__ for uncanning |
|
242 | # associate with __main__ for uncanning | |
242 | f.__module__ = '__main__' |
|
243 | f.__module__ = '__main__' | |
243 | return f |
|
244 | return f | |
244 |
|
245 | |||
245 | @interactive |
|
246 | @interactive | |
246 | def _push(**ns): |
|
247 | def _push(**ns): | |
247 | """helper method for implementing `client.push` via `client.apply`""" |
|
248 | """helper method for implementing `client.push` via `client.apply`""" | |
248 | user_ns = globals() |
|
249 | user_ns = globals() | |
249 | tmp = '_IP_PUSH_TMP_' |
|
250 | tmp = '_IP_PUSH_TMP_' | |
250 | while tmp in user_ns: |
|
251 | while tmp in user_ns: | |
251 | tmp = tmp + '_' |
|
252 | tmp = tmp + '_' | |
252 | try: |
|
253 | try: | |
253 | for name, value in ns.items(): |
|
254 | for name, value in ns.items(): | |
254 | user_ns[tmp] = value |
|
255 | user_ns[tmp] = value | |
255 | exec("%s = %s" % (name, tmp), user_ns) |
|
256 | exec("%s = %s" % (name, tmp), user_ns) | |
256 | finally: |
|
257 | finally: | |
257 | user_ns.pop(tmp, None) |
|
258 | user_ns.pop(tmp, None) | |
258 |
|
259 | |||
259 | @interactive |
|
260 | @interactive | |
260 | def _pull(keys): |
|
261 | def _pull(keys): | |
261 | """helper method for implementing `client.pull` via `client.apply`""" |
|
262 | """helper method for implementing `client.pull` via `client.apply`""" | |
262 | if isinstance(keys, (list,tuple, set)): |
|
263 | if isinstance(keys, (list,tuple, set)): | |
263 | return [eval(key, globals()) for key in keys] |
|
264 | return [eval(key, globals()) for key in keys] | |
264 | else: |
|
265 | else: | |
265 | return eval(keys, globals()) |
|
266 | return eval(keys, globals()) | |
266 |
|
267 | |||
267 | @interactive |
|
268 | @interactive | |
268 | def _execute(code): |
|
269 | def _execute(code): | |
269 | """helper method for implementing `client.execute` via `client.apply`""" |
|
270 | """helper method for implementing `client.execute` via `client.apply`""" | |
270 | exec(code, globals()) |
|
271 | exec(code, globals()) | |
271 |
|
272 | |||
272 | #-------------------------------------------------------------------------- |
|
273 | #-------------------------------------------------------------------------- | |
273 | # extra process management utilities |
|
274 | # extra process management utilities | |
274 | #-------------------------------------------------------------------------- |
|
275 | #-------------------------------------------------------------------------- | |
275 |
|
276 | |||
276 | _random_ports = set() |
|
277 | _random_ports = set() | |
277 |
|
278 | |||
278 | def select_random_ports(n): |
|
279 | def select_random_ports(n): | |
279 | """Selects and return n random ports that are available.""" |
|
280 | """Selects and return n random ports that are available.""" | |
280 | ports = [] |
|
281 | ports = [] | |
281 | for i in range(n): |
|
282 | for i in range(n): | |
282 | sock = socket.socket() |
|
283 | sock = socket.socket() | |
283 | sock.bind(('', 0)) |
|
284 | sock.bind(('', 0)) | |
284 | while sock.getsockname()[1] in _random_ports: |
|
285 | while sock.getsockname()[1] in _random_ports: | |
285 | sock.close() |
|
286 | sock.close() | |
286 | sock = socket.socket() |
|
287 | sock = socket.socket() | |
287 | sock.bind(('', 0)) |
|
288 | sock.bind(('', 0)) | |
288 | ports.append(sock) |
|
289 | ports.append(sock) | |
289 | for i, sock in enumerate(ports): |
|
290 | for i, sock in enumerate(ports): | |
290 | port = sock.getsockname()[1] |
|
291 | port = sock.getsockname()[1] | |
291 | sock.close() |
|
292 | sock.close() | |
292 | ports[i] = port |
|
293 | ports[i] = port | |
293 | _random_ports.add(port) |
|
294 | _random_ports.add(port) | |
294 | return ports |
|
295 | return ports | |
295 |
|
296 | |||
296 | def signal_children(children): |
|
297 | def signal_children(children): | |
297 | """Relay interupt/term signals to children, for more solid process cleanup.""" |
|
298 | """Relay interupt/term signals to children, for more solid process cleanup.""" | |
298 | def terminate_children(sig, frame): |
|
299 | def terminate_children(sig, frame): | |
299 | log = Application.instance().log |
|
300 | log = get_logger() | |
300 | log.critical("Got signal %i, terminating children..."%sig) |
|
301 | log.critical("Got signal %i, terminating children..."%sig) | |
301 | for child in children: |
|
302 | for child in children: | |
302 | child.terminate() |
|
303 | child.terminate() | |
303 |
|
304 | |||
304 | sys.exit(sig != SIGINT) |
|
305 | sys.exit(sig != SIGINT) | |
305 | # sys.exit(sig) |
|
306 | # sys.exit(sig) | |
306 | for sig in (SIGINT, SIGABRT, SIGTERM): |
|
307 | for sig in (SIGINT, SIGABRT, SIGTERM): | |
307 | signal(sig, terminate_children) |
|
308 | signal(sig, terminate_children) | |
308 |
|
309 | |||
309 | def generate_exec_key(keyfile): |
|
310 | def generate_exec_key(keyfile): | |
310 | import uuid |
|
311 | import uuid | |
311 | newkey = str(uuid.uuid4()) |
|
312 | newkey = str(uuid.uuid4()) | |
312 | with open(keyfile, 'w') as f: |
|
313 | with open(keyfile, 'w') as f: | |
313 | # f.write('ipython-key ') |
|
314 | # f.write('ipython-key ') | |
314 | f.write(newkey+'\n') |
|
315 | f.write(newkey+'\n') | |
315 | # set user-only RW permissions (0600) |
|
316 | # set user-only RW permissions (0600) | |
316 | # this will have no effect on Windows |
|
317 | # this will have no effect on Windows | |
317 | os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR) |
|
318 | os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR) | |
318 |
|
319 | |||
319 |
|
320 | |||
320 | def integer_loglevel(loglevel): |
|
321 | def integer_loglevel(loglevel): | |
321 | try: |
|
322 | try: | |
322 | loglevel = int(loglevel) |
|
323 | loglevel = int(loglevel) | |
323 | except ValueError: |
|
324 | except ValueError: | |
324 | if isinstance(loglevel, str): |
|
325 | if isinstance(loglevel, str): | |
325 | loglevel = getattr(logging, loglevel) |
|
326 | loglevel = getattr(logging, loglevel) | |
326 | return loglevel |
|
327 | return loglevel | |
327 |
|
328 | |||
328 | def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG): |
|
329 | def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG): | |
329 | logger = logging.getLogger(logname) |
|
330 | logger = logging.getLogger(logname) | |
330 | if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]): |
|
331 | if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]): | |
331 | # don't add a second PUBHandler |
|
332 | # don't add a second PUBHandler | |
332 | return |
|
333 | return | |
333 | loglevel = integer_loglevel(loglevel) |
|
334 | loglevel = integer_loglevel(loglevel) | |
334 | lsock = context.socket(zmq.PUB) |
|
335 | lsock = context.socket(zmq.PUB) | |
335 | lsock.connect(iface) |
|
336 | lsock.connect(iface) | |
336 | handler = handlers.PUBHandler(lsock) |
|
337 | handler = handlers.PUBHandler(lsock) | |
337 | handler.setLevel(loglevel) |
|
338 | handler.setLevel(loglevel) | |
338 | handler.root_topic = root |
|
339 | handler.root_topic = root | |
339 | logger.addHandler(handler) |
|
340 | logger.addHandler(handler) | |
340 | logger.setLevel(loglevel) |
|
341 | logger.setLevel(loglevel) | |
341 |
|
342 | |||
342 | def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG): |
|
343 | def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG): | |
343 | logger = logging.getLogger() |
|
344 | logger = logging.getLogger() | |
344 | if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]): |
|
345 | if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]): | |
345 | # don't add a second PUBHandler |
|
346 | # don't add a second PUBHandler | |
346 | return |
|
347 | return | |
347 | loglevel = integer_loglevel(loglevel) |
|
348 | loglevel = integer_loglevel(loglevel) | |
348 | lsock = context.socket(zmq.PUB) |
|
349 | lsock = context.socket(zmq.PUB) | |
349 | lsock.connect(iface) |
|
350 | lsock.connect(iface) | |
350 | handler = EnginePUBHandler(engine, lsock) |
|
351 | handler = EnginePUBHandler(engine, lsock) | |
351 | handler.setLevel(loglevel) |
|
352 | handler.setLevel(loglevel) | |
352 | logger.addHandler(handler) |
|
353 | logger.addHandler(handler) | |
353 | logger.setLevel(loglevel) |
|
354 | logger.setLevel(loglevel) | |
354 | return logger |
|
355 | return logger | |
355 |
|
356 | |||
356 | def local_logger(logname, loglevel=logging.DEBUG): |
|
357 | def local_logger(logname, loglevel=logging.DEBUG): | |
357 | loglevel = integer_loglevel(loglevel) |
|
358 | loglevel = integer_loglevel(loglevel) | |
358 | logger = logging.getLogger(logname) |
|
359 | logger = logging.getLogger(logname) | |
359 | if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]): |
|
360 | if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]): | |
360 | # don't add a second StreamHandler |
|
361 | # don't add a second StreamHandler | |
361 | return |
|
362 | return | |
362 | handler = logging.StreamHandler() |
|
363 | handler = logging.StreamHandler() | |
363 | handler.setLevel(loglevel) |
|
364 | handler.setLevel(loglevel) | |
364 | formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s", |
|
365 | formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s", | |
365 | datefmt="%Y-%m-%d %H:%M:%S") |
|
366 | datefmt="%Y-%m-%d %H:%M:%S") | |
366 | handler.setFormatter(formatter) |
|
367 | handler.setFormatter(formatter) | |
367 |
|
368 | |||
368 | logger.addHandler(handler) |
|
369 | logger.addHandler(handler) | |
369 | logger.setLevel(loglevel) |
|
370 | logger.setLevel(loglevel) | |
370 | return logger |
|
371 | return logger | |
371 |
|
372 | |||
372 | def set_hwm(sock, hwm=0): |
|
373 | def set_hwm(sock, hwm=0): | |
373 | """set zmq High Water Mark on a socket |
|
374 | """set zmq High Water Mark on a socket | |
374 |
|
375 | |||
375 | in a way that always works for various pyzmq / libzmq versions. |
|
376 | in a way that always works for various pyzmq / libzmq versions. | |
376 | """ |
|
377 | """ | |
377 | import zmq |
|
378 | import zmq | |
378 |
|
379 | |||
379 | for key in ('HWM', 'SNDHWM', 'RCVHWM'): |
|
380 | for key in ('HWM', 'SNDHWM', 'RCVHWM'): | |
380 | opt = getattr(zmq, key, None) |
|
381 | opt = getattr(zmq, key, None) | |
381 | if opt is None: |
|
382 | if opt is None: | |
382 | continue |
|
383 | continue | |
383 | try: |
|
384 | try: | |
384 | sock.setsockopt(opt, hwm) |
|
385 | sock.setsockopt(opt, hwm) | |
385 | except zmq.ZMQError: |
|
386 | except zmq.ZMQError: | |
386 | pass |
|
387 | pass | |
387 |
|
388 | |||
388 |
|
389 |
@@ -1,433 +1,420 b'' | |||||
1 | # encoding: utf-8 |
|
1 | # encoding: utf-8 | |
2 | """Pickle related utilities. Perhaps this should be called 'can'.""" |
|
2 | """Pickle related utilities. Perhaps this should be called 'can'.""" | |
3 |
|
3 | |||
4 | # Copyright (c) IPython Development Team. |
|
4 | # Copyright (c) IPython Development Team. | |
5 | # Distributed under the terms of the Modified BSD License. |
|
5 | # Distributed under the terms of the Modified BSD License. | |
6 |
|
6 | |||
7 | import copy |
|
7 | import copy | |
8 | import logging |
|
8 | import logging | |
9 | import sys |
|
9 | import sys | |
10 | from types import FunctionType |
|
10 | from types import FunctionType | |
11 |
|
11 | |||
12 | try: |
|
12 | try: | |
13 | import cPickle as pickle |
|
13 | import cPickle as pickle | |
14 | except ImportError: |
|
14 | except ImportError: | |
15 | import pickle |
|
15 | import pickle | |
16 |
|
16 | |||
17 | from . import codeutil # This registers a hook when it's imported |
|
17 | from . import codeutil # This registers a hook when it's imported | |
18 | from . import py3compat |
|
18 | from . import py3compat | |
19 | from .importstring import import_item |
|
19 | from .importstring import import_item | |
20 | from .py3compat import string_types, iteritems |
|
20 | from .py3compat import string_types, iteritems | |
21 |
|
21 | |||
22 | from IPython.config import Application |
|
22 | from IPython.config import Application | |
|
23 | from IPython.utils.log import get_logger | |||
23 |
|
24 | |||
24 | if py3compat.PY3: |
|
25 | if py3compat.PY3: | |
25 | buffer = memoryview |
|
26 | buffer = memoryview | |
26 | class_type = type |
|
27 | class_type = type | |
27 | else: |
|
28 | else: | |
28 | from types import ClassType |
|
29 | from types import ClassType | |
29 | class_type = (type, ClassType) |
|
30 | class_type = (type, ClassType) | |
30 |
|
31 | |||
31 | def _get_cell_type(a=None): |
|
32 | def _get_cell_type(a=None): | |
32 | """the type of a closure cell doesn't seem to be importable, |
|
33 | """the type of a closure cell doesn't seem to be importable, | |
33 | so just create one |
|
34 | so just create one | |
34 | """ |
|
35 | """ | |
35 | def inner(): |
|
36 | def inner(): | |
36 | return a |
|
37 | return a | |
37 | return type(py3compat.get_closure(inner)[0]) |
|
38 | return type(py3compat.get_closure(inner)[0]) | |
38 |
|
39 | |||
39 | cell_type = _get_cell_type() |
|
40 | cell_type = _get_cell_type() | |
40 |
|
41 | |||
41 | #------------------------------------------------------------------------------- |
|
42 | #------------------------------------------------------------------------------- | |
42 | # Functions |
|
43 | # Functions | |
43 | #------------------------------------------------------------------------------- |
|
44 | #------------------------------------------------------------------------------- | |
44 |
|
45 | |||
45 |
|
46 | |||
46 | def use_dill(): |
|
47 | def use_dill(): | |
47 | """use dill to expand serialization support |
|
48 | """use dill to expand serialization support | |
48 |
|
49 | |||
49 | adds support for object methods and closures to serialization. |
|
50 | adds support for object methods and closures to serialization. | |
50 | """ |
|
51 | """ | |
51 | # import dill causes most of the magic |
|
52 | # import dill causes most of the magic | |
52 | import dill |
|
53 | import dill | |
53 |
|
54 | |||
54 | # dill doesn't work with cPickle, |
|
55 | # dill doesn't work with cPickle, | |
55 | # tell the two relevant modules to use plain pickle |
|
56 | # tell the two relevant modules to use plain pickle | |
56 |
|
57 | |||
57 | global pickle |
|
58 | global pickle | |
58 | pickle = dill |
|
59 | pickle = dill | |
59 |
|
60 | |||
60 | try: |
|
61 | try: | |
61 | from IPython.kernel.zmq import serialize |
|
62 | from IPython.kernel.zmq import serialize | |
62 | except ImportError: |
|
63 | except ImportError: | |
63 | pass |
|
64 | pass | |
64 | else: |
|
65 | else: | |
65 | serialize.pickle = dill |
|
66 | serialize.pickle = dill | |
66 |
|
67 | |||
67 | # disable special function handling, let dill take care of it |
|
68 | # disable special function handling, let dill take care of it | |
68 | can_map.pop(FunctionType, None) |
|
69 | can_map.pop(FunctionType, None) | |
69 |
|
70 | |||
70 | def use_cloudpickle(): |
|
71 | def use_cloudpickle(): | |
71 | """use cloudpickle to expand serialization support |
|
72 | """use cloudpickle to expand serialization support | |
72 |
|
73 | |||
73 | adds support for object methods and closures to serialization. |
|
74 | adds support for object methods and closures to serialization. | |
74 | """ |
|
75 | """ | |
75 | from cloud.serialization import cloudpickle |
|
76 | from cloud.serialization import cloudpickle | |
76 |
|
77 | |||
77 | global pickle |
|
78 | global pickle | |
78 | pickle = cloudpickle |
|
79 | pickle = cloudpickle | |
79 |
|
80 | |||
80 | try: |
|
81 | try: | |
81 | from IPython.kernel.zmq import serialize |
|
82 | from IPython.kernel.zmq import serialize | |
82 | except ImportError: |
|
83 | except ImportError: | |
83 | pass |
|
84 | pass | |
84 | else: |
|
85 | else: | |
85 | serialize.pickle = cloudpickle |
|
86 | serialize.pickle = cloudpickle | |
86 |
|
87 | |||
87 | # disable special function handling, let cloudpickle take care of it |
|
88 | # disable special function handling, let cloudpickle take care of it | |
88 | can_map.pop(FunctionType, None) |
|
89 | can_map.pop(FunctionType, None) | |
89 |
|
90 | |||
90 |
|
91 | |||
91 | #------------------------------------------------------------------------------- |
|
92 | #------------------------------------------------------------------------------- | |
92 | # Classes |
|
93 | # Classes | |
93 | #------------------------------------------------------------------------------- |
|
94 | #------------------------------------------------------------------------------- | |
94 |
|
95 | |||
95 |
|
96 | |||
96 | class CannedObject(object): |
|
97 | class CannedObject(object): | |
97 | def __init__(self, obj, keys=[], hook=None): |
|
98 | def __init__(self, obj, keys=[], hook=None): | |
98 | """can an object for safe pickling |
|
99 | """can an object for safe pickling | |
99 |
|
100 | |||
100 | Parameters |
|
101 | Parameters | |
101 | ========== |
|
102 | ========== | |
102 |
|
103 | |||
103 | obj: |
|
104 | obj: | |
104 | The object to be canned |
|
105 | The object to be canned | |
105 | keys: list (optional) |
|
106 | keys: list (optional) | |
106 | list of attribute names that will be explicitly canned / uncanned |
|
107 | list of attribute names that will be explicitly canned / uncanned | |
107 | hook: callable (optional) |
|
108 | hook: callable (optional) | |
108 | An optional extra callable, |
|
109 | An optional extra callable, | |
109 | which can do additional processing of the uncanned object. |
|
110 | which can do additional processing of the uncanned object. | |
110 |
|
111 | |||
111 | large data may be offloaded into the buffers list, |
|
112 | large data may be offloaded into the buffers list, | |
112 | used for zero-copy transfers. |
|
113 | used for zero-copy transfers. | |
113 | """ |
|
114 | """ | |
114 | self.keys = keys |
|
115 | self.keys = keys | |
115 | self.obj = copy.copy(obj) |
|
116 | self.obj = copy.copy(obj) | |
116 | self.hook = can(hook) |
|
117 | self.hook = can(hook) | |
117 | for key in keys: |
|
118 | for key in keys: | |
118 | setattr(self.obj, key, can(getattr(obj, key))) |
|
119 | setattr(self.obj, key, can(getattr(obj, key))) | |
119 |
|
120 | |||
120 | self.buffers = [] |
|
121 | self.buffers = [] | |
121 |
|
122 | |||
122 | def get_object(self, g=None): |
|
123 | def get_object(self, g=None): | |
123 | if g is None: |
|
124 | if g is None: | |
124 | g = {} |
|
125 | g = {} | |
125 | obj = self.obj |
|
126 | obj = self.obj | |
126 | for key in self.keys: |
|
127 | for key in self.keys: | |
127 | setattr(obj, key, uncan(getattr(obj, key), g)) |
|
128 | setattr(obj, key, uncan(getattr(obj, key), g)) | |
128 |
|
129 | |||
129 | if self.hook: |
|
130 | if self.hook: | |
130 | self.hook = uncan(self.hook, g) |
|
131 | self.hook = uncan(self.hook, g) | |
131 | self.hook(obj, g) |
|
132 | self.hook(obj, g) | |
132 | return self.obj |
|
133 | return self.obj | |
133 |
|
134 | |||
134 |
|
135 | |||
135 | class Reference(CannedObject): |
|
136 | class Reference(CannedObject): | |
136 | """object for wrapping a remote reference by name.""" |
|
137 | """object for wrapping a remote reference by name.""" | |
137 | def __init__(self, name): |
|
138 | def __init__(self, name): | |
138 | if not isinstance(name, string_types): |
|
139 | if not isinstance(name, string_types): | |
139 | raise TypeError("illegal name: %r"%name) |
|
140 | raise TypeError("illegal name: %r"%name) | |
140 | self.name = name |
|
141 | self.name = name | |
141 | self.buffers = [] |
|
142 | self.buffers = [] | |
142 |
|
143 | |||
143 | def __repr__(self): |
|
144 | def __repr__(self): | |
144 | return "<Reference: %r>"%self.name |
|
145 | return "<Reference: %r>"%self.name | |
145 |
|
146 | |||
146 | def get_object(self, g=None): |
|
147 | def get_object(self, g=None): | |
147 | if g is None: |
|
148 | if g is None: | |
148 | g = {} |
|
149 | g = {} | |
149 |
|
150 | |||
150 | return eval(self.name, g) |
|
151 | return eval(self.name, g) | |
151 |
|
152 | |||
152 |
|
153 | |||
153 | class CannedCell(CannedObject): |
|
154 | class CannedCell(CannedObject): | |
154 | """Can a closure cell""" |
|
155 | """Can a closure cell""" | |
155 | def __init__(self, cell): |
|
156 | def __init__(self, cell): | |
156 | self.cell_contents = can(cell.cell_contents) |
|
157 | self.cell_contents = can(cell.cell_contents) | |
157 |
|
158 | |||
158 | def get_object(self, g=None): |
|
159 | def get_object(self, g=None): | |
159 | cell_contents = uncan(self.cell_contents, g) |
|
160 | cell_contents = uncan(self.cell_contents, g) | |
160 | def inner(): |
|
161 | def inner(): | |
161 | return cell_contents |
|
162 | return cell_contents | |
162 | return py3compat.get_closure(inner)[0] |
|
163 | return py3compat.get_closure(inner)[0] | |
163 |
|
164 | |||
164 |
|
165 | |||
165 | class CannedFunction(CannedObject): |
|
166 | class CannedFunction(CannedObject): | |
166 |
|
167 | |||
167 | def __init__(self, f): |
|
168 | def __init__(self, f): | |
168 | self._check_type(f) |
|
169 | self._check_type(f) | |
169 | self.code = f.__code__ |
|
170 | self.code = f.__code__ | |
170 | if f.__defaults__: |
|
171 | if f.__defaults__: | |
171 | self.defaults = [ can(fd) for fd in f.__defaults__ ] |
|
172 | self.defaults = [ can(fd) for fd in f.__defaults__ ] | |
172 | else: |
|
173 | else: | |
173 | self.defaults = None |
|
174 | self.defaults = None | |
174 |
|
175 | |||
175 | closure = py3compat.get_closure(f) |
|
176 | closure = py3compat.get_closure(f) | |
176 | if closure: |
|
177 | if closure: | |
177 | self.closure = tuple( can(cell) for cell in closure ) |
|
178 | self.closure = tuple( can(cell) for cell in closure ) | |
178 | else: |
|
179 | else: | |
179 | self.closure = None |
|
180 | self.closure = None | |
180 |
|
181 | |||
181 | self.module = f.__module__ or '__main__' |
|
182 | self.module = f.__module__ or '__main__' | |
182 | self.__name__ = f.__name__ |
|
183 | self.__name__ = f.__name__ | |
183 | self.buffers = [] |
|
184 | self.buffers = [] | |
184 |
|
185 | |||
185 | def _check_type(self, obj): |
|
186 | def _check_type(self, obj): | |
186 | assert isinstance(obj, FunctionType), "Not a function type" |
|
187 | assert isinstance(obj, FunctionType), "Not a function type" | |
187 |
|
188 | |||
188 | def get_object(self, g=None): |
|
189 | def get_object(self, g=None): | |
189 | # try to load function back into its module: |
|
190 | # try to load function back into its module: | |
190 | if not self.module.startswith('__'): |
|
191 | if not self.module.startswith('__'): | |
191 | __import__(self.module) |
|
192 | __import__(self.module) | |
192 | g = sys.modules[self.module].__dict__ |
|
193 | g = sys.modules[self.module].__dict__ | |
193 |
|
194 | |||
194 | if g is None: |
|
195 | if g is None: | |
195 | g = {} |
|
196 | g = {} | |
196 | if self.defaults: |
|
197 | if self.defaults: | |
197 | defaults = tuple(uncan(cfd, g) for cfd in self.defaults) |
|
198 | defaults = tuple(uncan(cfd, g) for cfd in self.defaults) | |
198 | else: |
|
199 | else: | |
199 | defaults = None |
|
200 | defaults = None | |
200 | if self.closure: |
|
201 | if self.closure: | |
201 | closure = tuple(uncan(cell, g) for cell in self.closure) |
|
202 | closure = tuple(uncan(cell, g) for cell in self.closure) | |
202 | else: |
|
203 | else: | |
203 | closure = None |
|
204 | closure = None | |
204 | newFunc = FunctionType(self.code, g, self.__name__, defaults, closure) |
|
205 | newFunc = FunctionType(self.code, g, self.__name__, defaults, closure) | |
205 | return newFunc |
|
206 | return newFunc | |
206 |
|
207 | |||
207 | class CannedClass(CannedObject): |
|
208 | class CannedClass(CannedObject): | |
208 |
|
209 | |||
209 | def __init__(self, cls): |
|
210 | def __init__(self, cls): | |
210 | self._check_type(cls) |
|
211 | self._check_type(cls) | |
211 | self.name = cls.__name__ |
|
212 | self.name = cls.__name__ | |
212 | self.old_style = not isinstance(cls, type) |
|
213 | self.old_style = not isinstance(cls, type) | |
213 | self._canned_dict = {} |
|
214 | self._canned_dict = {} | |
214 | for k,v in cls.__dict__.items(): |
|
215 | for k,v in cls.__dict__.items(): | |
215 | if k not in ('__weakref__', '__dict__'): |
|
216 | if k not in ('__weakref__', '__dict__'): | |
216 | self._canned_dict[k] = can(v) |
|
217 | self._canned_dict[k] = can(v) | |
217 | if self.old_style: |
|
218 | if self.old_style: | |
218 | mro = [] |
|
219 | mro = [] | |
219 | else: |
|
220 | else: | |
220 | mro = cls.mro() |
|
221 | mro = cls.mro() | |
221 |
|
222 | |||
222 | self.parents = [ can(c) for c in mro[1:] ] |
|
223 | self.parents = [ can(c) for c in mro[1:] ] | |
223 | self.buffers = [] |
|
224 | self.buffers = [] | |
224 |
|
225 | |||
225 | def _check_type(self, obj): |
|
226 | def _check_type(self, obj): | |
226 | assert isinstance(obj, class_type), "Not a class type" |
|
227 | assert isinstance(obj, class_type), "Not a class type" | |
227 |
|
228 | |||
228 | def get_object(self, g=None): |
|
229 | def get_object(self, g=None): | |
229 | parents = tuple(uncan(p, g) for p in self.parents) |
|
230 | parents = tuple(uncan(p, g) for p in self.parents) | |
230 | return type(self.name, parents, uncan_dict(self._canned_dict, g=g)) |
|
231 | return type(self.name, parents, uncan_dict(self._canned_dict, g=g)) | |
231 |
|
232 | |||
232 | class CannedArray(CannedObject): |
|
233 | class CannedArray(CannedObject): | |
233 | def __init__(self, obj): |
|
234 | def __init__(self, obj): | |
234 | from numpy import ascontiguousarray |
|
235 | from numpy import ascontiguousarray | |
235 | self.shape = obj.shape |
|
236 | self.shape = obj.shape | |
236 | self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str |
|
237 | self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str | |
237 | self.pickled = False |
|
238 | self.pickled = False | |
238 | if sum(obj.shape) == 0: |
|
239 | if sum(obj.shape) == 0: | |
239 | self.pickled = True |
|
240 | self.pickled = True | |
240 | elif obj.dtype == 'O': |
|
241 | elif obj.dtype == 'O': | |
241 | # can't handle object dtype with buffer approach |
|
242 | # can't handle object dtype with buffer approach | |
242 | self.pickled = True |
|
243 | self.pickled = True | |
243 | elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()): |
|
244 | elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()): | |
244 | self.pickled = True |
|
245 | self.pickled = True | |
245 | if self.pickled: |
|
246 | if self.pickled: | |
246 | # just pickle it |
|
247 | # just pickle it | |
247 | self.buffers = [pickle.dumps(obj, -1)] |
|
248 | self.buffers = [pickle.dumps(obj, -1)] | |
248 | else: |
|
249 | else: | |
249 | # ensure contiguous |
|
250 | # ensure contiguous | |
250 | obj = ascontiguousarray(obj, dtype=None) |
|
251 | obj = ascontiguousarray(obj, dtype=None) | |
251 | self.buffers = [buffer(obj)] |
|
252 | self.buffers = [buffer(obj)] | |
252 |
|
253 | |||
253 | def get_object(self, g=None): |
|
254 | def get_object(self, g=None): | |
254 | from numpy import frombuffer |
|
255 | from numpy import frombuffer | |
255 | data = self.buffers[0] |
|
256 | data = self.buffers[0] | |
256 | if self.pickled: |
|
257 | if self.pickled: | |
257 | # no shape, we just pickled it |
|
258 | # no shape, we just pickled it | |
258 | return pickle.loads(data) |
|
259 | return pickle.loads(data) | |
259 | else: |
|
260 | else: | |
260 | return frombuffer(data, dtype=self.dtype).reshape(self.shape) |
|
261 | return frombuffer(data, dtype=self.dtype).reshape(self.shape) | |
261 |
|
262 | |||
262 |
|
263 | |||
263 | class CannedBytes(CannedObject): |
|
264 | class CannedBytes(CannedObject): | |
264 | wrap = bytes |
|
265 | wrap = bytes | |
265 | def __init__(self, obj): |
|
266 | def __init__(self, obj): | |
266 | self.buffers = [obj] |
|
267 | self.buffers = [obj] | |
267 |
|
268 | |||
268 | def get_object(self, g=None): |
|
269 | def get_object(self, g=None): | |
269 | data = self.buffers[0] |
|
270 | data = self.buffers[0] | |
270 | return self.wrap(data) |
|
271 | return self.wrap(data) | |
271 |
|
272 | |||
272 | def CannedBuffer(CannedBytes): |
|
273 | def CannedBuffer(CannedBytes): | |
273 | wrap = buffer |
|
274 | wrap = buffer | |
274 |
|
275 | |||
275 | #------------------------------------------------------------------------------- |
|
276 | #------------------------------------------------------------------------------- | |
276 | # Functions |
|
277 | # Functions | |
277 | #------------------------------------------------------------------------------- |
|
278 | #------------------------------------------------------------------------------- | |
278 |
|
279 | |||
279 | def _logger(): |
|
|||
280 | """get the logger for the current Application |
|
|||
281 |
|
||||
282 | the root logger will be used if no Application is running |
|
|||
283 | """ |
|
|||
284 | if Application.initialized(): |
|
|||
285 | logger = Application.instance().log |
|
|||
286 | else: |
|
|||
287 | logger = logging.getLogger() |
|
|||
288 | if not logger.handlers: |
|
|||
289 | logging.basicConfig() |
|
|||
290 |
|
||||
291 | return logger |
|
|||
292 |
|
||||
293 | def _import_mapping(mapping, original=None): |
|
280 | def _import_mapping(mapping, original=None): | |
294 | """import any string-keys in a type mapping |
|
281 | """import any string-keys in a type mapping | |
295 |
|
282 | |||
296 | """ |
|
283 | """ | |
297 | log = _logger() |
|
284 | log = get_logger() | |
298 | log.debug("Importing canning map") |
|
285 | log.debug("Importing canning map") | |
299 | for key,value in list(mapping.items()): |
|
286 | for key,value in list(mapping.items()): | |
300 | if isinstance(key, string_types): |
|
287 | if isinstance(key, string_types): | |
301 | try: |
|
288 | try: | |
302 | cls = import_item(key) |
|
289 | cls = import_item(key) | |
303 | except Exception: |
|
290 | except Exception: | |
304 | if original and key not in original: |
|
291 | if original and key not in original: | |
305 | # only message on user-added classes |
|
292 | # only message on user-added classes | |
306 | log.error("canning class not importable: %r", key, exc_info=True) |
|
293 | log.error("canning class not importable: %r", key, exc_info=True) | |
307 | mapping.pop(key) |
|
294 | mapping.pop(key) | |
308 | else: |
|
295 | else: | |
309 | mapping[cls] = mapping.pop(key) |
|
296 | mapping[cls] = mapping.pop(key) | |
310 |
|
297 | |||
311 | def istype(obj, check): |
|
298 | def istype(obj, check): | |
312 | """like isinstance(obj, check), but strict |
|
299 | """like isinstance(obj, check), but strict | |
313 |
|
300 | |||
314 | This won't catch subclasses. |
|
301 | This won't catch subclasses. | |
315 | """ |
|
302 | """ | |
316 | if isinstance(check, tuple): |
|
303 | if isinstance(check, tuple): | |
317 | for cls in check: |
|
304 | for cls in check: | |
318 | if type(obj) is cls: |
|
305 | if type(obj) is cls: | |
319 | return True |
|
306 | return True | |
320 | return False |
|
307 | return False | |
321 | else: |
|
308 | else: | |
322 | return type(obj) is check |
|
309 | return type(obj) is check | |
323 |
|
310 | |||
324 | def can(obj): |
|
311 | def can(obj): | |
325 | """prepare an object for pickling""" |
|
312 | """prepare an object for pickling""" | |
326 |
|
313 | |||
327 | import_needed = False |
|
314 | import_needed = False | |
328 |
|
315 | |||
329 | for cls,canner in iteritems(can_map): |
|
316 | for cls,canner in iteritems(can_map): | |
330 | if isinstance(cls, string_types): |
|
317 | if isinstance(cls, string_types): | |
331 | import_needed = True |
|
318 | import_needed = True | |
332 | break |
|
319 | break | |
333 | elif istype(obj, cls): |
|
320 | elif istype(obj, cls): | |
334 | return canner(obj) |
|
321 | return canner(obj) | |
335 |
|
322 | |||
336 | if import_needed: |
|
323 | if import_needed: | |
337 | # perform can_map imports, then try again |
|
324 | # perform can_map imports, then try again | |
338 | # this will usually only happen once |
|
325 | # this will usually only happen once | |
339 | _import_mapping(can_map, _original_can_map) |
|
326 | _import_mapping(can_map, _original_can_map) | |
340 | return can(obj) |
|
327 | return can(obj) | |
341 |
|
328 | |||
342 | return obj |
|
329 | return obj | |
343 |
|
330 | |||
344 | def can_class(obj): |
|
331 | def can_class(obj): | |
345 | if isinstance(obj, class_type) and obj.__module__ == '__main__': |
|
332 | if isinstance(obj, class_type) and obj.__module__ == '__main__': | |
346 | return CannedClass(obj) |
|
333 | return CannedClass(obj) | |
347 | else: |
|
334 | else: | |
348 | return obj |
|
335 | return obj | |
349 |
|
336 | |||
350 | def can_dict(obj): |
|
337 | def can_dict(obj): | |
351 | """can the *values* of a dict""" |
|
338 | """can the *values* of a dict""" | |
352 | if istype(obj, dict): |
|
339 | if istype(obj, dict): | |
353 | newobj = {} |
|
340 | newobj = {} | |
354 | for k, v in iteritems(obj): |
|
341 | for k, v in iteritems(obj): | |
355 | newobj[k] = can(v) |
|
342 | newobj[k] = can(v) | |
356 | return newobj |
|
343 | return newobj | |
357 | else: |
|
344 | else: | |
358 | return obj |
|
345 | return obj | |
359 |
|
346 | |||
360 | sequence_types = (list, tuple, set) |
|
347 | sequence_types = (list, tuple, set) | |
361 |
|
348 | |||
362 | def can_sequence(obj): |
|
349 | def can_sequence(obj): | |
363 | """can the elements of a sequence""" |
|
350 | """can the elements of a sequence""" | |
364 | if istype(obj, sequence_types): |
|
351 | if istype(obj, sequence_types): | |
365 | t = type(obj) |
|
352 | t = type(obj) | |
366 | return t([can(i) for i in obj]) |
|
353 | return t([can(i) for i in obj]) | |
367 | else: |
|
354 | else: | |
368 | return obj |
|
355 | return obj | |
369 |
|
356 | |||
370 | def uncan(obj, g=None): |
|
357 | def uncan(obj, g=None): | |
371 | """invert canning""" |
|
358 | """invert canning""" | |
372 |
|
359 | |||
373 | import_needed = False |
|
360 | import_needed = False | |
374 | for cls,uncanner in iteritems(uncan_map): |
|
361 | for cls,uncanner in iteritems(uncan_map): | |
375 | if isinstance(cls, string_types): |
|
362 | if isinstance(cls, string_types): | |
376 | import_needed = True |
|
363 | import_needed = True | |
377 | break |
|
364 | break | |
378 | elif isinstance(obj, cls): |
|
365 | elif isinstance(obj, cls): | |
379 | return uncanner(obj, g) |
|
366 | return uncanner(obj, g) | |
380 |
|
367 | |||
381 | if import_needed: |
|
368 | if import_needed: | |
382 | # perform uncan_map imports, then try again |
|
369 | # perform uncan_map imports, then try again | |
383 | # this will usually only happen once |
|
370 | # this will usually only happen once | |
384 | _import_mapping(uncan_map, _original_uncan_map) |
|
371 | _import_mapping(uncan_map, _original_uncan_map) | |
385 | return uncan(obj, g) |
|
372 | return uncan(obj, g) | |
386 |
|
373 | |||
387 | return obj |
|
374 | return obj | |
388 |
|
375 | |||
389 | def uncan_dict(obj, g=None): |
|
376 | def uncan_dict(obj, g=None): | |
390 | if istype(obj, dict): |
|
377 | if istype(obj, dict): | |
391 | newobj = {} |
|
378 | newobj = {} | |
392 | for k, v in iteritems(obj): |
|
379 | for k, v in iteritems(obj): | |
393 | newobj[k] = uncan(v,g) |
|
380 | newobj[k] = uncan(v,g) | |
394 | return newobj |
|
381 | return newobj | |
395 | else: |
|
382 | else: | |
396 | return obj |
|
383 | return obj | |
397 |
|
384 | |||
398 | def uncan_sequence(obj, g=None): |
|
385 | def uncan_sequence(obj, g=None): | |
399 | if istype(obj, sequence_types): |
|
386 | if istype(obj, sequence_types): | |
400 | t = type(obj) |
|
387 | t = type(obj) | |
401 | return t([uncan(i,g) for i in obj]) |
|
388 | return t([uncan(i,g) for i in obj]) | |
402 | else: |
|
389 | else: | |
403 | return obj |
|
390 | return obj | |
404 |
|
391 | |||
405 | def _uncan_dependent_hook(dep, g=None): |
|
392 | def _uncan_dependent_hook(dep, g=None): | |
406 | dep.check_dependency() |
|
393 | dep.check_dependency() | |
407 |
|
394 | |||
408 | def can_dependent(obj): |
|
395 | def can_dependent(obj): | |
409 | return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook) |
|
396 | return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook) | |
410 |
|
397 | |||
411 | #------------------------------------------------------------------------------- |
|
398 | #------------------------------------------------------------------------------- | |
412 | # API dictionaries |
|
399 | # API dictionaries | |
413 | #------------------------------------------------------------------------------- |
|
400 | #------------------------------------------------------------------------------- | |
414 |
|
401 | |||
415 | # These dicts can be extended for custom serialization of new objects |
|
402 | # These dicts can be extended for custom serialization of new objects | |
416 |
|
403 | |||
417 | can_map = { |
|
404 | can_map = { | |
418 | 'IPython.parallel.dependent' : can_dependent, |
|
405 | 'IPython.parallel.dependent' : can_dependent, | |
419 | 'numpy.ndarray' : CannedArray, |
|
406 | 'numpy.ndarray' : CannedArray, | |
420 | FunctionType : CannedFunction, |
|
407 | FunctionType : CannedFunction, | |
421 | bytes : CannedBytes, |
|
408 | bytes : CannedBytes, | |
422 | buffer : CannedBuffer, |
|
409 | buffer : CannedBuffer, | |
423 | cell_type : CannedCell, |
|
410 | cell_type : CannedCell, | |
424 | class_type : can_class, |
|
411 | class_type : can_class, | |
425 | } |
|
412 | } | |
426 |
|
413 | |||
427 | uncan_map = { |
|
414 | uncan_map = { | |
428 | CannedObject : lambda obj, g: obj.get_object(g), |
|
415 | CannedObject : lambda obj, g: obj.get_object(g), | |
429 | } |
|
416 | } | |
430 |
|
417 | |||
431 | # for use in _import_mapping: |
|
418 | # for use in _import_mapping: | |
432 | _original_can_map = can_map.copy() |
|
419 | _original_can_map = can_map.copy() | |
433 | _original_uncan_map = uncan_map.copy() |
|
420 | _original_uncan_map = uncan_map.copy() |
General Comments 0
You need to be logged in to leave comments.
Login now