##// END OF EJS Templates
Merge pull request #6044 from minrk/core.log...
Thomas Kluyver -
r17065:0c7348f7 merge
parent child Browse files
Show More
@@ -0,0 +1,25 b''
1 """Grab the global logger instance."""
2
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
5
6 import logging
7
# Module-level cache: the logger is resolved once and reused on later calls.
_logger = None

def get_logger():
    """Grab the global logger instance.

    If a global IPython Application is instantiated, grab its logger.
    Otherwise, grab the root logger.
    """
    global _logger

    if _logger is None:
        # Import here to avoid a circular import at module load time.
        from IPython.config import Application
        if Application.initialized():
            _logger = Application.instance().log
        else:
            # No Application: fall back to a basicConfig'd root logger.
            logging.basicConfig()
            _logger = logging.getLogger()
    return _logger
@@ -1,390 +1,366 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """A base class for objects that are configurable."""
3 A base class for objects that are configurable.
4
5 Inheritance diagram:
6
7 .. inheritance-diagram:: IPython.config.configurable
8 :parts: 3
9
3
10 Authors:
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
11
6
12 * Brian Granger
13 * Fernando Perez
14 * Min RK
15 """
16 from __future__ import print_function
7 from __future__ import print_function
17
8
18 #-----------------------------------------------------------------------------
19 # Copyright (C) 2008-2011 The IPython Development Team
20 #
21 # Distributed under the terms of the BSD License. The full license is in
22 # the file COPYING, distributed as part of this software.
23 #-----------------------------------------------------------------------------
24
25 #-----------------------------------------------------------------------------
26 # Imports
27 #-----------------------------------------------------------------------------
28
29 import logging
9 import logging
30 from copy import deepcopy
10 from copy import deepcopy
31
11
32 from .loader import Config, LazyConfigValue
12 from .loader import Config, LazyConfigValue
33 from IPython.utils.traitlets import HasTraits, Instance
13 from IPython.utils.traitlets import HasTraits, Instance
34 from IPython.utils.text import indent, wrap_paragraphs
14 from IPython.utils.text import indent, wrap_paragraphs
35 from IPython.utils.py3compat import iteritems
15 from IPython.utils.py3compat import iteritems
36
16
37
17
38 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
39 # Helper classes for Configurables
19 # Helper classes for Configurables
40 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
41
21
42
22
class ConfigurableError(Exception):
    """Base class for exceptions raised by configurable machinery."""
    pass
45
25
46
26
class MultipleInstanceError(ConfigurableError):
    """Raised when conflicting singleton subclass instances are created."""
    pass
49
29
50 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
51 # Configurable implementation
31 # Configurable implementation
52 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
53
33
class Configurable(HasTraits):

    # The Config object this instance was configured from; defaults to empty.
    config = Instance(Config, (), {})
    # Optional parent Configurable; used for nested config-section lookup.
    parent = Instance('IPython.config.configurable.Configurable')

    def __init__(self, **kwargs):
        """Create a configurable given a config config.

        Parameters
        ----------
        config : Config
            If this is empty, default values are used. If config is a
            :class:`Config` instance, it will be used to configure the
            instance.
        parent : Configurable instance, optional
            The parent Configurable instance of this object.

        Notes
        -----
        Subclasses of Configurable must call the :meth:`__init__` method of
        :class:`Configurable` *before* doing anything else and using
        :func:`super`::

            class MyConfigurable(Configurable):
                def __init__(self, config=None):
                    super(MyConfigurable, self).__init__(config=config)
                    # Then any other code you need to finish initialization.

        This ensures that instances will be configured properly.
        """
        parent = kwargs.pop('parent', None)
        if parent is not None:
            # config is implied from parent
            if kwargs.get('config', None) is None:
                kwargs['config'] = parent.config
            self.parent = parent

        config = kwargs.pop('config', None)
        if config is not None:
            # We used to deepcopy, but for now we are trying to just save
            # by reference.  This *could* have side effects as all components
            # will share config. In fact, I did find such a side effect in
            # _config_changed below. If a config attribute value was a mutable type
            # all instances of a component were getting the same copy, effectively
            # making that a class attribute.
            # self.config = deepcopy(config)
            self.config = config
        # This should go second so individual keyword arguments override
        # the values in config.
        super(Configurable, self).__init__(**kwargs)

    #-------------------------------------------------------------------------
    # Static trait notifications
    #-------------------------------------------------------------------------

    @classmethod
    def section_names(cls):
        """return section names as a list"""
        return [c.__name__ for c in reversed(cls.__mro__) if
                issubclass(c, Configurable) and issubclass(cls, c)
                ]

    def _find_my_config(self, cfg):
        """extract my config from a global Config object

        will construct a Config object of only the config values that apply to me
        based on my mro(), as well as those of my parent(s) if they exist.

        If I am Bar and my parent is Foo, and their parent is Tim,
        this will return merge following config sections, in this order::

            [Bar, Foo.bar, Tim.Foo.Bar]

        With the last item being the highest priority.
        """
        cfgs = [cfg]
        if self.parent:
            cfgs.append(self.parent._find_my_config(cfg))
        my_config = Config()
        for c in cfgs:
            for sname in self.section_names():
                # Don't do a blind getattr as that would cause the config to
                # dynamically create the section with name Class.__name__.
                if c._has_section(sname):
                    my_config.merge(c[sname])
        return my_config

    def _load_config(self, cfg, section_names=None, traits=None):
        """load traits from a Config object"""

        if traits is None:
            traits = self.traits(config=True)
        if section_names is None:
            section_names = self.section_names()

        my_config = self._find_my_config(cfg)
        for name, config_value in iteritems(my_config):
            if name in traits:
                if isinstance(config_value, LazyConfigValue):
                    # ConfigValue is a wrapper for using append / update on containers
                    # without having to copy the initial value
                    initial = getattr(self, name)
                    config_value = config_value.get_value(initial)
                # We have to do a deepcopy here if we don't deepcopy the entire
                # config object. If we don't, a mutable config_value will be
                # shared by all instances, effectively making it a class attribute.
                setattr(self, name, deepcopy(config_value))

    def _config_changed(self, name, old, new):
        """Update all the class traits having ``config=True`` as metadata.

        For any class trait with a ``config`` metadata attribute that is
        ``True``, we update the trait with the value of the corresponding
        config entry.
        """
        # Get all traits with a config metadata entry that is True
        traits = self.traits(config=True)

        # We auto-load config section for this class as well as any parent
        # classes that are Configurable subclasses.  This starts with Configurable
        # and works down the mro loading the config for each section.
        section_names = self.section_names()
        self._load_config(new, traits=traits, section_names=section_names)

    def update_config(self, config):
        """Fire the traits events when the config is updated."""
        # Save a copy of the current config.
        newconfig = deepcopy(self.config)
        # Merge the new config into the current one.
        newconfig.merge(config)
        # Save the combined config as self.config, which triggers the traits
        # events.
        self.config = newconfig

    @classmethod
    def class_get_help(cls, inst=None):
        """Get the help string for this class in ReST format.

        If `inst` is given, its current trait values will be used in place of
        class defaults.
        """
        assert inst is None or isinstance(inst, cls)
        final_help = []
        final_help.append(u'%s options' % cls.__name__)
        final_help.append(len(final_help[0]) * u'-')
        for k, v in sorted(cls.class_traits(config=True).items()):
            help = cls.class_get_trait_help(v, inst)
            final_help.append(help)
        return '\n'.join(final_help)

    @classmethod
    def class_get_trait_help(cls, trait, inst=None):
        """Get the help string for a single trait.

        If `inst` is given, its current trait values will be used in place of
        the class default.
        """
        assert inst is None or isinstance(inst, cls)
        lines = []
        header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__)
        lines.append(header)
        if inst is not None:
            lines.append(indent('Current: %r' % getattr(inst, trait.name), 4))
        else:
            try:
                dvr = repr(trait.get_default_value())
            except Exception:
                dvr = None  # ignore defaults we can't construct
            if dvr is not None:
                if len(dvr) > 64:
                    dvr = dvr[:61] + '...'
                lines.append(indent('Default: %s' % dvr, 4))
        if 'Enum' in trait.__class__.__name__:
            # include Enum choices
            lines.append(indent('Choices: %r' % (trait.values,)))

        help = trait.get_metadata('help')
        if help is not None:
            help = '\n'.join(wrap_paragraphs(help, 76))
            lines.append(indent(help, 4))
        return '\n'.join(lines)

    @classmethod
    def class_print_help(cls, inst=None):
        """Get the help string for a single trait and print it."""
        print(cls.class_get_help(inst))

    @classmethod
    def class_config_section(cls):
        """Get the config class config section"""
        def c(s):
            """return a commented, wrapped block."""
            s = '\n\n'.join(wrap_paragraphs(s, 78))

            return '# ' + s.replace('\n', '\n# ')

        # section header
        breaker = '#' + '-' * 78
        s = "# %s configuration" % cls.__name__
        lines = [breaker, s, breaker, '']
        # get the description trait
        desc = cls.class_traits().get('description')
        if desc:
            desc = desc.default_value
        else:
            # no description trait, use __doc__
            desc = getattr(cls, '__doc__', '')
        if desc:
            lines.append(c(desc))
            lines.append('')

        parents = []
        for parent in cls.mro():
            # only include parents that are not base classes
            # and are not the class itself
            # and have some configurable traits to inherit
            if parent is not cls and issubclass(parent, Configurable) and \
                    parent.class_traits(config=True):
                parents.append(parent)

        if parents:
            pstr = ', '.join([p.__name__ for p in parents])
            lines.append(c('%s will inherit config from: %s' % (cls.__name__, pstr)))
            lines.append('')

        for name, trait in iteritems(cls.class_traits(config=True)):
            help = trait.get_metadata('help') or ''
            lines.append(c(help))
            lines.append('# c.%s.%s = %r' % (cls.__name__, name, trait.get_default_value()))
            lines.append('')
        return '\n'.join(lines)
285
265
286
266
287
267
class SingletonConfigurable(Configurable):
    """A configurable that only allows one instance.

    This class is for classes that should only have one instance of itself
    or *any* subclass. To create and retrieve such a class use the
    :meth:`SingletonConfigurable.instance` method.
    """

    # The shared instance; None until instance() is first called.
    _instance = None

    @classmethod
    def _walk_mro(cls):
        """Walk the cls.mro() for parent classes that are also singletons

        For use in instance()
        """

        for subclass in cls.mro():
            if issubclass(cls, subclass) and \
                    issubclass(subclass, SingletonConfigurable) and \
                    subclass != SingletonConfigurable:
                yield subclass

    @classmethod
    def clear_instance(cls):
        """unset _instance for this class and singleton parents.
        """
        if not cls.initialized():
            return
        for subclass in cls._walk_mro():
            if isinstance(subclass._instance, cls):
                # only clear instances that are instances
                # of the calling class
                subclass._instance = None

    @classmethod
    def instance(cls, *args, **kwargs):
        """Returns a global instance of this class.

        This method creates a new instance if none has previously been created
        and returns a previously created instance if one already exists.

        The arguments and keyword arguments passed to this method are passed
        on to the :meth:`__init__` method of the class upon instantiation.

        Examples
        --------

        Create a singleton class using instance, and retrieve it::

            >>> from IPython.config.configurable import SingletonConfigurable
            >>> class Foo(SingletonConfigurable): pass
            >>> foo = Foo.instance()
            >>> foo == Foo.instance()
            True

        Create a subclass that is retrieved using the base class instance::

            >>> class Bar(SingletonConfigurable): pass
            >>> class Bam(Bar): pass
            >>> bam = Bam.instance()
            >>> bam == Bar.instance()
            True
        """
        # Create and save the instance
        if cls._instance is None:
            inst = cls(*args, **kwargs)
            # Now make sure that the instance will also be returned by
            # parent classes' _instance attribute.
            for subclass in cls._walk_mro():
                subclass._instance = inst

        if isinstance(cls._instance, cls):
            return cls._instance
        else:
            raise MultipleInstanceError(
                'Multiple incompatible subclass instances of '
                '%s are being created.' % cls.__name__
            )

    @classmethod
    def initialized(cls):
        """Has an instance been created?"""
        return hasattr(cls, "_instance") and cls._instance is not None
372
352
373
353
class LoggingConfigurable(Configurable):
    """A parent class for Configurables that log.

    Subclasses have a log trait, and the default behavior
    is to get the logger from the currently running Application.
    """

    log = Instance('logging.Logger')

    def _log_default(self):
        # Delegate to the shared helper so all Configurables resolve
        # the same global logger (Application's log, or the root logger).
        from IPython.utils import log
        return log.get_logger()
@@ -1,846 +1,824 b''
1 """A simple configuration system.
1 # encoding: utf-8
2 """A simple configuration system."""
2
3
3 Inheritance diagram:
4 # Copyright (c) IPython Development Team.
4
5 # Distributed under the terms of the Modified BSD License.
5 .. inheritance-diagram:: IPython.config.loader
6 :parts: 3
7
8 Authors
9 -------
10 * Brian Granger
11 * Fernando Perez
12 * Min RK
13 """
14
15 #-----------------------------------------------------------------------------
16 # Copyright (C) 2008-2011 The IPython Development Team
17 #
18 # Distributed under the terms of the BSD License. The full license is in
19 # the file COPYING, distributed as part of this software.
20 #-----------------------------------------------------------------------------
21
22 #-----------------------------------------------------------------------------
23 # Imports
24 #-----------------------------------------------------------------------------
25
6
26 import argparse
7 import argparse
27 import copy
8 import copy
28 import logging
9 import logging
29 import os
10 import os
30 import re
11 import re
31 import sys
12 import sys
32 import json
13 import json
33
14
34 from IPython.utils.path import filefind, get_ipython_dir
15 from IPython.utils.path import filefind, get_ipython_dir
35 from IPython.utils import py3compat
16 from IPython.utils import py3compat
36 from IPython.utils.encoding import DEFAULT_ENCODING
17 from IPython.utils.encoding import DEFAULT_ENCODING
37 from IPython.utils.py3compat import unicode_type, iteritems
18 from IPython.utils.py3compat import unicode_type, iteritems
38 from IPython.utils.traitlets import HasTraits, List, Any
19 from IPython.utils.traitlets import HasTraits, List, Any
39
20
40 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
41 # Exceptions
22 # Exceptions
42 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
43
24
44
25
45 class ConfigError(Exception):
26 class ConfigError(Exception):
46 pass
27 pass
47
28
48 class ConfigLoaderError(ConfigError):
29 class ConfigLoaderError(ConfigError):
49 pass
30 pass
50
31
51 class ConfigFileNotFound(ConfigError):
32 class ConfigFileNotFound(ConfigError):
52 pass
33 pass
53
34
54 class ArgumentError(ConfigLoaderError):
35 class ArgumentError(ConfigLoaderError):
55 pass
36 pass
56
37
57 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
58 # Argparse fix
39 # Argparse fix
59 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
60
41
61 # Unfortunately argparse by default prints help messages to stderr instead of
42 # Unfortunately argparse by default prints help messages to stderr instead of
62 # stdout. This makes it annoying to capture long help screens at the command
43 # stdout. This makes it annoying to capture long help screens at the command
63 # line, since one must know how to pipe stderr, which many users don't know how
44 # line, since one must know how to pipe stderr, which many users don't know how
64 # to do. So we override the print_help method with one that defaults to
45 # to do. So we override the print_help method with one that defaults to
65 # stdout and use our class instead.
46 # stdout and use our class instead.
66
47
67 class ArgumentParser(argparse.ArgumentParser):
48 class ArgumentParser(argparse.ArgumentParser):
68 """Simple argparse subclass that prints help to stdout by default."""
49 """Simple argparse subclass that prints help to stdout by default."""
69
50
70 def print_help(self, file=None):
51 def print_help(self, file=None):
71 if file is None:
52 if file is None:
72 file = sys.stdout
53 file = sys.stdout
73 return super(ArgumentParser, self).print_help(file)
54 return super(ArgumentParser, self).print_help(file)
74
55
75 print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
56 print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
76
57
77 #-----------------------------------------------------------------------------
58 #-----------------------------------------------------------------------------
78 # Config class for holding config information
59 # Config class for holding config information
79 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
80
61
81 class LazyConfigValue(HasTraits):
62 class LazyConfigValue(HasTraits):
82 """Proxy object for exposing methods on configurable containers
63 """Proxy object for exposing methods on configurable containers
83
64
84 Exposes:
65 Exposes:
85
66
86 - append, extend, insert on lists
67 - append, extend, insert on lists
87 - update on dicts
68 - update on dicts
88 - update, add on sets
69 - update, add on sets
89 """
70 """
90
71
91 _value = None
72 _value = None
92
73
93 # list methods
74 # list methods
94 _extend = List()
75 _extend = List()
95 _prepend = List()
76 _prepend = List()
96
77
97 def append(self, obj):
78 def append(self, obj):
98 self._extend.append(obj)
79 self._extend.append(obj)
99
80
100 def extend(self, other):
81 def extend(self, other):
101 self._extend.extend(other)
82 self._extend.extend(other)
102
83
103 def prepend(self, other):
84 def prepend(self, other):
104 """like list.extend, but for the front"""
85 """like list.extend, but for the front"""
105 self._prepend[:0] = other
86 self._prepend[:0] = other
106
87
107 _inserts = List()
88 _inserts = List()
108 def insert(self, index, other):
89 def insert(self, index, other):
109 if not isinstance(index, int):
90 if not isinstance(index, int):
110 raise TypeError("An integer is required")
91 raise TypeError("An integer is required")
111 self._inserts.append((index, other))
92 self._inserts.append((index, other))
112
93
113 # dict methods
94 # dict methods
114 # update is used for both dict and set
95 # update is used for both dict and set
115 _update = Any()
96 _update = Any()
116 def update(self, other):
97 def update(self, other):
117 if self._update is None:
98 if self._update is None:
118 if isinstance(other, dict):
99 if isinstance(other, dict):
119 self._update = {}
100 self._update = {}
120 else:
101 else:
121 self._update = set()
102 self._update = set()
122 self._update.update(other)
103 self._update.update(other)
123
104
124 # set methods
105 # set methods
125 def add(self, obj):
106 def add(self, obj):
126 self.update({obj})
107 self.update({obj})
127
108
128 def get_value(self, initial):
109 def get_value(self, initial):
129 """construct the value from the initial one
110 """construct the value from the initial one
130
111
131 after applying any insert / extend / update changes
112 after applying any insert / extend / update changes
132 """
113 """
133 if self._value is not None:
114 if self._value is not None:
134 return self._value
115 return self._value
135 value = copy.deepcopy(initial)
116 value = copy.deepcopy(initial)
136 if isinstance(value, list):
117 if isinstance(value, list):
137 for idx, obj in self._inserts:
118 for idx, obj in self._inserts:
138 value.insert(idx, obj)
119 value.insert(idx, obj)
139 value[:0] = self._prepend
120 value[:0] = self._prepend
140 value.extend(self._extend)
121 value.extend(self._extend)
141
122
142 elif isinstance(value, dict):
123 elif isinstance(value, dict):
143 if self._update:
124 if self._update:
144 value.update(self._update)
125 value.update(self._update)
145 elif isinstance(value, set):
126 elif isinstance(value, set):
146 if self._update:
127 if self._update:
147 value.update(self._update)
128 value.update(self._update)
148 self._value = value
129 self._value = value
149 return value
130 return value
150
131
151 def to_dict(self):
132 def to_dict(self):
152 """return JSONable dict form of my data
133 """return JSONable dict form of my data
153
134
154 Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples.
135 Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples.
155 """
136 """
156 d = {}
137 d = {}
157 if self._update:
138 if self._update:
158 d['update'] = self._update
139 d['update'] = self._update
159 if self._extend:
140 if self._extend:
160 d['extend'] = self._extend
141 d['extend'] = self._extend
161 if self._prepend:
142 if self._prepend:
162 d['prepend'] = self._prepend
143 d['prepend'] = self._prepend
163 elif self._inserts:
144 elif self._inserts:
164 d['inserts'] = self._inserts
145 d['inserts'] = self._inserts
165 return d
146 return d
166
147
167
148
168 def _is_section_key(key):
149 def _is_section_key(key):
169 """Is a Config key a section name (does it start with a capital)?"""
150 """Is a Config key a section name (does it start with a capital)?"""
170 if key and key[0].upper()==key[0] and not key.startswith('_'):
151 if key and key[0].upper()==key[0] and not key.startswith('_'):
171 return True
152 return True
172 else:
153 else:
173 return False
154 return False
174
155
175
156
176 class Config(dict):
157 class Config(dict):
177 """An attribute based dict that can do smart merges."""
158 """An attribute based dict that can do smart merges."""
178
159
179 def __init__(self, *args, **kwds):
160 def __init__(self, *args, **kwds):
180 dict.__init__(self, *args, **kwds)
161 dict.__init__(self, *args, **kwds)
181 self._ensure_subconfig()
162 self._ensure_subconfig()
182
163
183 def _ensure_subconfig(self):
164 def _ensure_subconfig(self):
184 """ensure that sub-dicts that should be Config objects are
165 """ensure that sub-dicts that should be Config objects are
185
166
186 casts dicts that are under section keys to Config objects,
167 casts dicts that are under section keys to Config objects,
187 which is necessary for constructing Config objects from dict literals.
168 which is necessary for constructing Config objects from dict literals.
188 """
169 """
189 for key in self:
170 for key in self:
190 obj = self[key]
171 obj = self[key]
191 if _is_section_key(key) \
172 if _is_section_key(key) \
192 and isinstance(obj, dict) \
173 and isinstance(obj, dict) \
193 and not isinstance(obj, Config):
174 and not isinstance(obj, Config):
194 setattr(self, key, Config(obj))
175 setattr(self, key, Config(obj))
195
176
196 def _merge(self, other):
177 def _merge(self, other):
197 """deprecated alias, use Config.merge()"""
178 """deprecated alias, use Config.merge()"""
198 self.merge(other)
179 self.merge(other)
199
180
200 def merge(self, other):
181 def merge(self, other):
201 """merge another config object into this one"""
182 """merge another config object into this one"""
202 to_update = {}
183 to_update = {}
203 for k, v in iteritems(other):
184 for k, v in iteritems(other):
204 if k not in self:
185 if k not in self:
205 to_update[k] = copy.deepcopy(v)
186 to_update[k] = copy.deepcopy(v)
206 else: # I have this key
187 else: # I have this key
207 if isinstance(v, Config) and isinstance(self[k], Config):
188 if isinstance(v, Config) and isinstance(self[k], Config):
208 # Recursively merge common sub Configs
189 # Recursively merge common sub Configs
209 self[k].merge(v)
190 self[k].merge(v)
210 else:
191 else:
211 # Plain updates for non-Configs
192 # Plain updates for non-Configs
212 to_update[k] = copy.deepcopy(v)
193 to_update[k] = copy.deepcopy(v)
213
194
214 self.update(to_update)
195 self.update(to_update)
215
196
216 def __contains__(self, key):
197 def __contains__(self, key):
217 # allow nested contains of the form `"Section.key" in config`
198 # allow nested contains of the form `"Section.key" in config`
218 if '.' in key:
199 if '.' in key:
219 first, remainder = key.split('.', 1)
200 first, remainder = key.split('.', 1)
220 if first not in self:
201 if first not in self:
221 return False
202 return False
222 return remainder in self[first]
203 return remainder in self[first]
223
204
224 return super(Config, self).__contains__(key)
205 return super(Config, self).__contains__(key)
225
206
226 # .has_key is deprecated for dictionaries.
207 # .has_key is deprecated for dictionaries.
227 has_key = __contains__
208 has_key = __contains__
228
209
229 def _has_section(self, key):
210 def _has_section(self, key):
230 return _is_section_key(key) and key in self
211 return _is_section_key(key) and key in self
231
212
232 def copy(self):
213 def copy(self):
233 return type(self)(dict.copy(self))
214 return type(self)(dict.copy(self))
234
215
235 def __copy__(self):
216 def __copy__(self):
236 return self.copy()
217 return self.copy()
237
218
238 def __deepcopy__(self, memo):
219 def __deepcopy__(self, memo):
239 import copy
220 import copy
240 return type(self)(copy.deepcopy(list(self.items())))
221 return type(self)(copy.deepcopy(list(self.items())))
241
222
242 def __getitem__(self, key):
223 def __getitem__(self, key):
243 try:
224 try:
244 return dict.__getitem__(self, key)
225 return dict.__getitem__(self, key)
245 except KeyError:
226 except KeyError:
246 if _is_section_key(key):
227 if _is_section_key(key):
247 c = Config()
228 c = Config()
248 dict.__setitem__(self, key, c)
229 dict.__setitem__(self, key, c)
249 return c
230 return c
250 elif not key.startswith('_'):
231 elif not key.startswith('_'):
251 # undefined, create lazy value, used for container methods
232 # undefined, create lazy value, used for container methods
252 v = LazyConfigValue()
233 v = LazyConfigValue()
253 dict.__setitem__(self, key, v)
234 dict.__setitem__(self, key, v)
254 return v
235 return v
255 else:
236 else:
256 raise KeyError
237 raise KeyError
257
238
258 def __setitem__(self, key, value):
239 def __setitem__(self, key, value):
259 if _is_section_key(key):
240 if _is_section_key(key):
260 if not isinstance(value, Config):
241 if not isinstance(value, Config):
261 raise ValueError('values whose keys begin with an uppercase '
242 raise ValueError('values whose keys begin with an uppercase '
262 'char must be Config instances: %r, %r' % (key, value))
243 'char must be Config instances: %r, %r' % (key, value))
263 dict.__setitem__(self, key, value)
244 dict.__setitem__(self, key, value)
264
245
265 def __getattr__(self, key):
246 def __getattr__(self, key):
266 if key.startswith('__'):
247 if key.startswith('__'):
267 return dict.__getattr__(self, key)
248 return dict.__getattr__(self, key)
268 try:
249 try:
269 return self.__getitem__(key)
250 return self.__getitem__(key)
270 except KeyError as e:
251 except KeyError as e:
271 raise AttributeError(e)
252 raise AttributeError(e)
272
253
273 def __setattr__(self, key, value):
254 def __setattr__(self, key, value):
274 if key.startswith('__'):
255 if key.startswith('__'):
275 return dict.__setattr__(self, key, value)
256 return dict.__setattr__(self, key, value)
276 try:
257 try:
277 self.__setitem__(key, value)
258 self.__setitem__(key, value)
278 except KeyError as e:
259 except KeyError as e:
279 raise AttributeError(e)
260 raise AttributeError(e)
280
261
281 def __delattr__(self, key):
262 def __delattr__(self, key):
282 if key.startswith('__'):
263 if key.startswith('__'):
283 return dict.__delattr__(self, key)
264 return dict.__delattr__(self, key)
284 try:
265 try:
285 dict.__delitem__(self, key)
266 dict.__delitem__(self, key)
286 except KeyError as e:
267 except KeyError as e:
287 raise AttributeError(e)
268 raise AttributeError(e)
288
269
289
270
290 #-----------------------------------------------------------------------------
271 #-----------------------------------------------------------------------------
291 # Config loading classes
272 # Config loading classes
292 #-----------------------------------------------------------------------------
273 #-----------------------------------------------------------------------------
293
274
294
275
295 class ConfigLoader(object):
276 class ConfigLoader(object):
296 """An object for loading configurations from just about anywhere.
277 """An object for loading configurations from just about anywhere.
297
278
298 The resulting configuration is packaged as a :class:`Config`.
279 The resulting configuration is packaged as a :class:`Config`.
299
280
300 Notes
281 Notes
301 -----
282 -----
302 A :class:`ConfigLoader` does one thing: load a config from a source
283 A :class:`ConfigLoader` does one thing: load a config from a source
303 (file, command line arguments) and returns the data as a :class:`Config` object.
284 (file, command line arguments) and returns the data as a :class:`Config` object.
304 There are lots of things that :class:`ConfigLoader` does not do. It does
285 There are lots of things that :class:`ConfigLoader` does not do. It does
305 not implement complex logic for finding config files. It does not handle
286 not implement complex logic for finding config files. It does not handle
306 default values or merge multiple configs. These things need to be
287 default values or merge multiple configs. These things need to be
307 handled elsewhere.
288 handled elsewhere.
308 """
289 """
309
290
310 def _log_default(self):
291 def _log_default(self):
311 from IPython.config.application import Application
292 from IPython.utils.log import get_logger
312 if Application.initialized():
293 return get_logger()
313 return Application.instance().log
314 else:
315 return logging.getLogger()
316
294
317 def __init__(self, log=None):
295 def __init__(self, log=None):
318 """A base class for config loaders.
296 """A base class for config loaders.
319
297
320 log : instance of :class:`logging.Logger` to use.
298 log : instance of :class:`logging.Logger` to use.
321 By default logger of :meth:`IPython.config.application.Application.instance()`
299 By default logger of :meth:`IPython.config.application.Application.instance()`
322 will be used
300 will be used
323
301
324 Examples
302 Examples
325 --------
303 --------
326
304
327 >>> cl = ConfigLoader()
305 >>> cl = ConfigLoader()
328 >>> config = cl.load_config()
306 >>> config = cl.load_config()
329 >>> config
307 >>> config
330 {}
308 {}
331 """
309 """
332 self.clear()
310 self.clear()
333 if log is None:
311 if log is None:
334 self.log = self._log_default()
312 self.log = self._log_default()
335 self.log.debug('Using default logger')
313 self.log.debug('Using default logger')
336 else:
314 else:
337 self.log = log
315 self.log = log
338
316
339 def clear(self):
317 def clear(self):
340 self.config = Config()
318 self.config = Config()
341
319
342 def load_config(self):
320 def load_config(self):
343 """Load a config from somewhere, return a :class:`Config` instance.
321 """Load a config from somewhere, return a :class:`Config` instance.
344
322
345 Usually, this will cause self.config to be set and then returned.
323 Usually, this will cause self.config to be set and then returned.
346 However, in most cases, :meth:`ConfigLoader.clear` should be called
324 However, in most cases, :meth:`ConfigLoader.clear` should be called
347 to erase any previous state.
325 to erase any previous state.
348 """
326 """
349 self.clear()
327 self.clear()
350 return self.config
328 return self.config
351
329
352
330
353 class FileConfigLoader(ConfigLoader):
331 class FileConfigLoader(ConfigLoader):
354 """A base class for file based configurations.
332 """A base class for file based configurations.
355
333
356 As we add more file based config loaders, the common logic should go
334 As we add more file based config loaders, the common logic should go
357 here.
335 here.
358 """
336 """
359
337
360 def __init__(self, filename, path=None, **kw):
338 def __init__(self, filename, path=None, **kw):
361 """Build a config loader for a filename and path.
339 """Build a config loader for a filename and path.
362
340
363 Parameters
341 Parameters
364 ----------
342 ----------
365 filename : str
343 filename : str
366 The file name of the config file.
344 The file name of the config file.
367 path : str, list, tuple
345 path : str, list, tuple
368 The path to search for the config file on, or a sequence of
346 The path to search for the config file on, or a sequence of
369 paths to try in order.
347 paths to try in order.
370 """
348 """
371 super(FileConfigLoader, self).__init__(**kw)
349 super(FileConfigLoader, self).__init__(**kw)
372 self.filename = filename
350 self.filename = filename
373 self.path = path
351 self.path = path
374 self.full_filename = ''
352 self.full_filename = ''
375
353
376 def _find_file(self):
354 def _find_file(self):
377 """Try to find the file by searching the paths."""
355 """Try to find the file by searching the paths."""
378 self.full_filename = filefind(self.filename, self.path)
356 self.full_filename = filefind(self.filename, self.path)
379
357
380 class JSONFileConfigLoader(FileConfigLoader):
358 class JSONFileConfigLoader(FileConfigLoader):
381 """A JSON file loader for config"""
359 """A JSON file loader for config"""
382
360
383 def load_config(self):
361 def load_config(self):
384 """Load the config from a file and return it as a Config object."""
362 """Load the config from a file and return it as a Config object."""
385 self.clear()
363 self.clear()
386 try:
364 try:
387 self._find_file()
365 self._find_file()
388 except IOError as e:
366 except IOError as e:
389 raise ConfigFileNotFound(str(e))
367 raise ConfigFileNotFound(str(e))
390 dct = self._read_file_as_dict()
368 dct = self._read_file_as_dict()
391 self.config = self._convert_to_config(dct)
369 self.config = self._convert_to_config(dct)
392 return self.config
370 return self.config
393
371
394 def _read_file_as_dict(self):
372 def _read_file_as_dict(self):
395 with open(self.full_filename) as f:
373 with open(self.full_filename) as f:
396 return json.load(f)
374 return json.load(f)
397
375
398 def _convert_to_config(self, dictionary):
376 def _convert_to_config(self, dictionary):
399 if 'version' in dictionary:
377 if 'version' in dictionary:
400 version = dictionary.pop('version')
378 version = dictionary.pop('version')
401 else:
379 else:
402 version = 1
380 version = 1
403 self.log.warn("Unrecognized JSON config file version, assuming version {}".format(version))
381 self.log.warn("Unrecognized JSON config file version, assuming version {}".format(version))
404
382
405 if version == 1:
383 if version == 1:
406 return Config(dictionary)
384 return Config(dictionary)
407 else:
385 else:
408 raise ValueError('Unknown version of JSON config file: {version}'.format(version=version))
386 raise ValueError('Unknown version of JSON config file: {version}'.format(version=version))
409
387
410
388
411 class PyFileConfigLoader(FileConfigLoader):
389 class PyFileConfigLoader(FileConfigLoader):
412 """A config loader for pure python files.
390 """A config loader for pure python files.
413
391
414 This is responsible for locating a Python config file by filename and
392 This is responsible for locating a Python config file by filename and
415 path, then executing it to construct a Config object.
393 path, then executing it to construct a Config object.
416 """
394 """
417
395
418 def load_config(self):
396 def load_config(self):
419 """Load the config from a file and return it as a Config object."""
397 """Load the config from a file and return it as a Config object."""
420 self.clear()
398 self.clear()
421 try:
399 try:
422 self._find_file()
400 self._find_file()
423 except IOError as e:
401 except IOError as e:
424 raise ConfigFileNotFound(str(e))
402 raise ConfigFileNotFound(str(e))
425 self._read_file_as_dict()
403 self._read_file_as_dict()
426 return self.config
404 return self.config
427
405
428
406
429 def _read_file_as_dict(self):
407 def _read_file_as_dict(self):
430 """Load the config file into self.config, with recursive loading."""
408 """Load the config file into self.config, with recursive loading."""
431 # This closure is made available in the namespace that is used
409 # This closure is made available in the namespace that is used
432 # to exec the config file. It allows users to call
410 # to exec the config file. It allows users to call
433 # load_subconfig('myconfig.py') to load config files recursively.
411 # load_subconfig('myconfig.py') to load config files recursively.
434 # It needs to be a closure because it has references to self.path
412 # It needs to be a closure because it has references to self.path
435 # and self.config. The sub-config is loaded with the same path
413 # and self.config. The sub-config is loaded with the same path
436 # as the parent, but it uses an empty config which is then merged
414 # as the parent, but it uses an empty config which is then merged
437 # with the parents.
415 # with the parents.
438
416
439 # If a profile is specified, the config file will be loaded
417 # If a profile is specified, the config file will be loaded
440 # from that profile
418 # from that profile
441
419
442 def load_subconfig(fname, profile=None):
420 def load_subconfig(fname, profile=None):
443 # import here to prevent circular imports
421 # import here to prevent circular imports
444 from IPython.core.profiledir import ProfileDir, ProfileDirError
422 from IPython.core.profiledir import ProfileDir, ProfileDirError
445 if profile is not None:
423 if profile is not None:
446 try:
424 try:
447 profile_dir = ProfileDir.find_profile_dir_by_name(
425 profile_dir = ProfileDir.find_profile_dir_by_name(
448 get_ipython_dir(),
426 get_ipython_dir(),
449 profile,
427 profile,
450 )
428 )
451 except ProfileDirError:
429 except ProfileDirError:
452 return
430 return
453 path = profile_dir.location
431 path = profile_dir.location
454 else:
432 else:
455 path = self.path
433 path = self.path
456 loader = PyFileConfigLoader(fname, path)
434 loader = PyFileConfigLoader(fname, path)
457 try:
435 try:
458 sub_config = loader.load_config()
436 sub_config = loader.load_config()
459 except ConfigFileNotFound:
437 except ConfigFileNotFound:
460 # Pass silently if the sub config is not there. This happens
438 # Pass silently if the sub config is not there. This happens
461 # when a user is using a profile, but not the default config.
439 # when a user is using a profile, but not the default config.
462 pass
440 pass
463 else:
441 else:
464 self.config.merge(sub_config)
442 self.config.merge(sub_config)
465
443
466 # Again, this needs to be a closure and should be used in config
444 # Again, this needs to be a closure and should be used in config
467 # files to get the config being loaded.
445 # files to get the config being loaded.
468 def get_config():
446 def get_config():
469 return self.config
447 return self.config
470
448
471 namespace = dict(
449 namespace = dict(
472 load_subconfig=load_subconfig,
450 load_subconfig=load_subconfig,
473 get_config=get_config,
451 get_config=get_config,
474 __file__=self.full_filename,
452 __file__=self.full_filename,
475 )
453 )
476 fs_encoding = sys.getfilesystemencoding() or 'ascii'
454 fs_encoding = sys.getfilesystemencoding() or 'ascii'
477 conf_filename = self.full_filename.encode(fs_encoding)
455 conf_filename = self.full_filename.encode(fs_encoding)
478 py3compat.execfile(conf_filename, namespace)
456 py3compat.execfile(conf_filename, namespace)
479
457
480
458
481 class CommandLineConfigLoader(ConfigLoader):
459 class CommandLineConfigLoader(ConfigLoader):
482 """A config loader for command line arguments.
460 """A config loader for command line arguments.
483
461
484 As we add more command line based loaders, the common logic should go
462 As we add more command line based loaders, the common logic should go
485 here.
463 here.
486 """
464 """
487
465
488 def _exec_config_str(self, lhs, rhs):
466 def _exec_config_str(self, lhs, rhs):
489 """execute self.config.<lhs> = <rhs>
467 """execute self.config.<lhs> = <rhs>
490
468
491 * expands ~ with expanduser
469 * expands ~ with expanduser
492 * tries to assign with raw eval, otherwise assigns with just the string,
470 * tries to assign with raw eval, otherwise assigns with just the string,
493 allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not*
471 allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not*
494 equivalent are `--C.a=4` and `--C.a='4'`.
472 equivalent are `--C.a=4` and `--C.a='4'`.
495 """
473 """
496 rhs = os.path.expanduser(rhs)
474 rhs = os.path.expanduser(rhs)
497 try:
475 try:
498 # Try to see if regular Python syntax will work. This
476 # Try to see if regular Python syntax will work. This
499 # won't handle strings as the quote marks are removed
477 # won't handle strings as the quote marks are removed
500 # by the system shell.
478 # by the system shell.
501 value = eval(rhs)
479 value = eval(rhs)
502 except (NameError, SyntaxError):
480 except (NameError, SyntaxError):
503 # This case happens if the rhs is a string.
481 # This case happens if the rhs is a string.
504 value = rhs
482 value = rhs
505
483
506 exec(u'self.config.%s = value' % lhs)
484 exec(u'self.config.%s = value' % lhs)
507
485
508 def _load_flag(self, cfg):
486 def _load_flag(self, cfg):
509 """update self.config from a flag, which can be a dict or Config"""
487 """update self.config from a flag, which can be a dict or Config"""
510 if isinstance(cfg, (dict, Config)):
488 if isinstance(cfg, (dict, Config)):
511 # don't clobber whole config sections, update
489 # don't clobber whole config sections, update
512 # each section from config:
490 # each section from config:
513 for sec,c in iteritems(cfg):
491 for sec,c in iteritems(cfg):
514 self.config[sec].update(c)
492 self.config[sec].update(c)
515 else:
493 else:
516 raise TypeError("Invalid flag: %r" % cfg)
494 raise TypeError("Invalid flag: %r" % cfg)
517
495
518 # raw --identifier=value pattern
496 # raw --identifier=value pattern
519 # but *also* accept '-' as wordsep, for aliases
497 # but *also* accept '-' as wordsep, for aliases
520 # accepts: --foo=a
498 # accepts: --foo=a
521 # --Class.trait=value
499 # --Class.trait=value
522 # --alias-name=value
500 # --alias-name=value
523 # rejects: -foo=value
501 # rejects: -foo=value
524 # --foo
502 # --foo
525 # --Class.trait
503 # --Class.trait
526 kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*')
504 kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*')
527
505
528 # just flags, no assignments, with two *or one* leading '-'
506 # just flags, no assignments, with two *or one* leading '-'
529 # accepts: --foo
507 # accepts: --foo
530 # -foo-bar-again
508 # -foo-bar-again
531 # rejects: --anything=anything
509 # rejects: --anything=anything
532 # --two.word
510 # --two.word
533
511
534 flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$')
512 flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$')
535
513
536 class KeyValueConfigLoader(CommandLineConfigLoader):
514 class KeyValueConfigLoader(CommandLineConfigLoader):
537 """A config loader that loads key value pairs from the command line.
515 """A config loader that loads key value pairs from the command line.
538
516
539 This allows command line options to be given in the following form::
517 This allows command line options to be given in the following form::
540
518
541 ipython --profile="foo" --InteractiveShell.autocall=False
519 ipython --profile="foo" --InteractiveShell.autocall=False
542 """
520 """
543
521
544 def __init__(self, argv=None, aliases=None, flags=None, **kw):
522 def __init__(self, argv=None, aliases=None, flags=None, **kw):
545 """Create a key value pair config loader.
523 """Create a key value pair config loader.
546
524
547 Parameters
525 Parameters
548 ----------
526 ----------
549 argv : list
527 argv : list
550 A list that has the form of sys.argv[1:] which has unicode
528 A list that has the form of sys.argv[1:] which has unicode
551 elements of the form u"key=value". If this is None (default),
529 elements of the form u"key=value". If this is None (default),
552 then sys.argv[1:] will be used.
530 then sys.argv[1:] will be used.
553 aliases : dict
531 aliases : dict
554 A dict of aliases for configurable traits.
532 A dict of aliases for configurable traits.
555 Keys are the short aliases, Values are the resolved trait.
533 Keys are the short aliases, Values are the resolved trait.
556 Of the form: `{'alias' : 'Configurable.trait'}`
534 Of the form: `{'alias' : 'Configurable.trait'}`
557 flags : dict
535 flags : dict
558 A dict of flags, keyed by str name. Vaues can be Config objects,
536 A dict of flags, keyed by str name. Vaues can be Config objects,
559 dicts, or "key=value" strings. If Config or dict, when the flag
537 dicts, or "key=value" strings. If Config or dict, when the flag
560 is triggered, The flag is loaded as `self.config.update(m)`.
538 is triggered, The flag is loaded as `self.config.update(m)`.
561
539
562 Returns
540 Returns
563 -------
541 -------
564 config : Config
542 config : Config
565 The resulting Config object.
543 The resulting Config object.
566
544
567 Examples
545 Examples
568 --------
546 --------
569
547
570 >>> from IPython.config.loader import KeyValueConfigLoader
548 >>> from IPython.config.loader import KeyValueConfigLoader
571 >>> cl = KeyValueConfigLoader()
549 >>> cl = KeyValueConfigLoader()
572 >>> d = cl.load_config(["--A.name='brian'","--B.number=0"])
550 >>> d = cl.load_config(["--A.name='brian'","--B.number=0"])
573 >>> sorted(d.items())
551 >>> sorted(d.items())
574 [('A', {'name': 'brian'}), ('B', {'number': 0})]
552 [('A', {'name': 'brian'}), ('B', {'number': 0})]
575 """
553 """
576 super(KeyValueConfigLoader, self).__init__(**kw)
554 super(KeyValueConfigLoader, self).__init__(**kw)
577 if argv is None:
555 if argv is None:
578 argv = sys.argv[1:]
556 argv = sys.argv[1:]
579 self.argv = argv
557 self.argv = argv
580 self.aliases = aliases or {}
558 self.aliases = aliases or {}
581 self.flags = flags or {}
559 self.flags = flags or {}
582
560
583
561
584 def clear(self):
562 def clear(self):
585 super(KeyValueConfigLoader, self).clear()
563 super(KeyValueConfigLoader, self).clear()
586 self.extra_args = []
564 self.extra_args = []
587
565
588
566
589 def _decode_argv(self, argv, enc=None):
567 def _decode_argv(self, argv, enc=None):
590 """decode argv if bytes, using stin.encoding, falling back on default enc"""
568 """decode argv if bytes, using stin.encoding, falling back on default enc"""
591 uargv = []
569 uargv = []
592 if enc is None:
570 if enc is None:
593 enc = DEFAULT_ENCODING
571 enc = DEFAULT_ENCODING
594 for arg in argv:
572 for arg in argv:
595 if not isinstance(arg, unicode_type):
573 if not isinstance(arg, unicode_type):
596 # only decode if not already decoded
574 # only decode if not already decoded
597 arg = arg.decode(enc)
575 arg = arg.decode(enc)
598 uargv.append(arg)
576 uargv.append(arg)
599 return uargv
577 return uargv
600
578
601
579
602 def load_config(self, argv=None, aliases=None, flags=None):
580 def load_config(self, argv=None, aliases=None, flags=None):
603 """Parse the configuration and generate the Config object.
581 """Parse the configuration and generate the Config object.
604
582
605 After loading, any arguments that are not key-value or
583 After loading, any arguments that are not key-value or
606 flags will be stored in self.extra_args - a list of
584 flags will be stored in self.extra_args - a list of
607 unparsed command-line arguments. This is used for
585 unparsed command-line arguments. This is used for
608 arguments such as input files or subcommands.
586 arguments such as input files or subcommands.
609
587
610 Parameters
588 Parameters
611 ----------
589 ----------
612 argv : list, optional
590 argv : list, optional
613 A list that has the form of sys.argv[1:] which has unicode
591 A list that has the form of sys.argv[1:] which has unicode
614 elements of the form u"key=value". If this is None (default),
592 elements of the form u"key=value". If this is None (default),
615 then self.argv will be used.
593 then self.argv will be used.
616 aliases : dict
594 aliases : dict
617 A dict of aliases for configurable traits.
595 A dict of aliases for configurable traits.
618 Keys are the short aliases, Values are the resolved trait.
596 Keys are the short aliases, Values are the resolved trait.
619 Of the form: `{'alias' : 'Configurable.trait'}`
597 Of the form: `{'alias' : 'Configurable.trait'}`
620 flags : dict
598 flags : dict
621 A dict of flags, keyed by str name. Values can be Config objects
599 A dict of flags, keyed by str name. Values can be Config objects
622 or dicts. When the flag is triggered, The config is loaded as
600 or dicts. When the flag is triggered, The config is loaded as
623 `self.config.update(cfg)`.
601 `self.config.update(cfg)`.
624 """
602 """
625 self.clear()
603 self.clear()
626 if argv is None:
604 if argv is None:
627 argv = self.argv
605 argv = self.argv
628 if aliases is None:
606 if aliases is None:
629 aliases = self.aliases
607 aliases = self.aliases
630 if flags is None:
608 if flags is None:
631 flags = self.flags
609 flags = self.flags
632
610
633 # ensure argv is a list of unicode strings:
611 # ensure argv is a list of unicode strings:
634 uargv = self._decode_argv(argv)
612 uargv = self._decode_argv(argv)
635 for idx,raw in enumerate(uargv):
613 for idx,raw in enumerate(uargv):
636 # strip leading '-'
614 # strip leading '-'
637 item = raw.lstrip('-')
615 item = raw.lstrip('-')
638
616
639 if raw == '--':
617 if raw == '--':
640 # don't parse arguments after '--'
618 # don't parse arguments after '--'
641 # this is useful for relaying arguments to scripts, e.g.
619 # this is useful for relaying arguments to scripts, e.g.
642 # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py
620 # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py
643 self.extra_args.extend(uargv[idx+1:])
621 self.extra_args.extend(uargv[idx+1:])
644 break
622 break
645
623
646 if kv_pattern.match(raw):
624 if kv_pattern.match(raw):
647 lhs,rhs = item.split('=',1)
625 lhs,rhs = item.split('=',1)
648 # Substitute longnames for aliases.
626 # Substitute longnames for aliases.
649 if lhs in aliases:
627 if lhs in aliases:
650 lhs = aliases[lhs]
628 lhs = aliases[lhs]
651 if '.' not in lhs:
629 if '.' not in lhs:
652 # probably a mistyped alias, but not technically illegal
630 # probably a mistyped alias, but not technically illegal
653 self.log.warn("Unrecognized alias: '%s', it will probably have no effect.", raw)
631 self.log.warn("Unrecognized alias: '%s', it will probably have no effect.", raw)
654 try:
632 try:
655 self._exec_config_str(lhs, rhs)
633 self._exec_config_str(lhs, rhs)
656 except Exception:
634 except Exception:
657 raise ArgumentError("Invalid argument: '%s'" % raw)
635 raise ArgumentError("Invalid argument: '%s'" % raw)
658
636
659 elif flag_pattern.match(raw):
637 elif flag_pattern.match(raw):
660 if item in flags:
638 if item in flags:
661 cfg,help = flags[item]
639 cfg,help = flags[item]
662 self._load_flag(cfg)
640 self._load_flag(cfg)
663 else:
641 else:
664 raise ArgumentError("Unrecognized flag: '%s'"%raw)
642 raise ArgumentError("Unrecognized flag: '%s'"%raw)
665 elif raw.startswith('-'):
643 elif raw.startswith('-'):
666 kv = '--'+item
644 kv = '--'+item
667 if kv_pattern.match(kv):
645 if kv_pattern.match(kv):
668 raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv))
646 raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv))
669 else:
647 else:
670 raise ArgumentError("Invalid argument: '%s'"%raw)
648 raise ArgumentError("Invalid argument: '%s'"%raw)
671 else:
649 else:
672 # keep all args that aren't valid in a list,
650 # keep all args that aren't valid in a list,
673 # in case our parent knows what to do with them.
651 # in case our parent knows what to do with them.
674 self.extra_args.append(item)
652 self.extra_args.append(item)
675 return self.config
653 return self.config
676
654
class ArgParseConfigLoader(CommandLineConfigLoader):
    """A loader that uses the argparse module to load from the command line."""

    def __init__(self, argv=None, aliases=None, flags=None, log=None, *parser_args, **parser_kw):
        """Create a config loader for use with argparse.

        Parameters
        ----------

        argv : optional, list
            If given, used to read command-line arguments from, otherwise
            sys.argv[1:] is used.

        parser_args : tuple
            A tuple of positional arguments that will be passed to the
            constructor of :class:`argparse.ArgumentParser`.

        parser_kw : dict
            A dict of keyword arguments that will be passed to the
            constructor of :class:`argparse.ArgumentParser`.

        Returns
        -------
        config : Config
            The resulting Config object.
        """
        # NOTE(review): super() is given CommandLineConfigLoader here, which
        # skips CommandLineConfigLoader.__init__ itself -- presumably
        # intentional, but worth confirming.
        super(CommandLineConfigLoader, self).__init__(log=log)
        self.clear()
        self.argv = sys.argv[1:] if argv is None else argv
        self.aliases = aliases or {}
        self.flags = flags or {}

        self.parser_args = parser_args
        # 'version' is consumed here rather than forwarded to ArgumentParser
        self.version = parser_kw.pop("version", None)
        parser_kwargs = dict(argument_default=argparse.SUPPRESS)
        parser_kwargs.update(parser_kw)
        self.parser_kw = parser_kwargs

    def load_config(self, argv=None, aliases=None, flags=None):
        """Parse command line arguments and return as a Config object.

        Parameters
        ----------

        args : optional, list
            If given, a list with the structure of sys.argv[1:] to parse
            arguments from. If not given, the instance's self.argv attribute
            (given at construction time) is used."""
        self.clear()
        argv = self.argv if argv is None else argv
        aliases = self.aliases if aliases is None else aliases
        flags = self.flags if flags is None else flags
        self._create_parser(aliases, flags)
        self._parse_args(argv)
        self._convert_to_config()
        return self.config

    def get_extra_args(self):
        """Return unparsed extra arguments, or [] if parsing hasn't run yet."""
        return getattr(self, 'extra_args', [])

    def _create_parser(self, aliases=None, flags=None):
        """Build self.parser and let the subclass populate its arguments."""
        self.parser = ArgumentParser(*self.parser_args, **self.parser_kw)
        self._add_arguments(aliases, flags)

    def _add_arguments(self, aliases=None, flags=None):
        raise NotImplementedError("subclasses must implement _add_arguments")

    def _parse_args(self, args):
        """self.parser->self.parsed_data"""
        # decode sys.argv to support unicode command-line options
        decoded = [py3compat.cast_unicode(a, DEFAULT_ENCODING) for a in args]
        self.parsed_data, self.extra_args = self.parser.parse_known_args(decoded)

    def _convert_to_config(self):
        """self.parsed_data->self.config"""
        for k, v in iteritems(vars(self.parsed_data)):
            # k may be a dotted name ('Class.trait'); exec handles the chain.
            # NOTE(review): the mapping order (locals, globals) looks swapped
            # relative to exec's (globals, locals) signature but resolves
            # `self`/`v` correctly via the fallback lookup -- confirm before
            # touching.
            exec("self.config.%s = v"%k, locals(), globals())
763
741
class KVArgParseConfigLoader(ArgParseConfigLoader):
    """A config loader that loads aliases and flags with argparse,
    but will use KVLoader for the rest.  This allows better parsing
    of common args, such as `ipython -c 'print 5'`, but still gets
    arbitrary config with `ipython --InteractiveShell.use_readline=False`"""

    def _add_arguments(self, aliases=None, flags=None):
        """Register alias and flag arguments on self.parser."""
        self.alias_flags = {}
        if aliases is None:
            aliases = self.aliases
        if flags is None:
            flags = self.flags
        paa = self.parser.add_argument
        for key,value in iteritems(aliases):
            if key in flags:
                # alias that is also a flag: value is optional
                nargs = '?'
            else:
                nargs = None
            # `== 1`, not `is 1`: identity comparison on ints is
            # implementation-dependent and a SyntaxWarning on modern CPython
            if len(key) == 1:
                paa('-'+key, '--'+key, type=unicode_type, dest=value, nargs=nargs)
            else:
                paa('--'+key, type=unicode_type, dest=value, nargs=nargs)
        for key, (value, help) in iteritems(flags):
            if key in self.aliases:
                # flag that shares its name with an alias: record it so
                # _convert_to_config can apply it when no value was given
                self.alias_flags[self.aliases[key]] = value
                continue
            if len(key) == 1:
                paa('-'+key, '--'+key, action='append_const', dest='_flags', const=value)
            else:
                paa('--'+key, action='append_const', dest='_flags', const=value)

    def _convert_to_config(self):
        """self.parsed_data->self.config, parse unrecognized extra args via KVLoader."""
        # remove subconfigs list from namespace before transforming the Namespace
        if '_flags' in self.parsed_data:
            subcs = self.parsed_data._flags
            del self.parsed_data._flags
        else:
            subcs = []

        for k, v in iteritems(vars(self.parsed_data)):
            if v is None:
                # it was a flag that shares the name of an alias
                subcs.append(self.alias_flags[k])
            else:
                # eval the KV assignment
                self._exec_config_str(k, v)

        for subc in subcs:
            self._load_flag(subc)

        if self.extra_args:
            sub_parser = KeyValueConfigLoader(log=self.log)
            sub_parser.load_config(self.extra_args)
            self.config.merge(sub_parser.config)
            self.extra_args = sub_parser.extra_args
823
801
824
802
def load_pyconfig_files(config_files, path):
    """Load multiple Python config files, merging each of them in turn.

    Parameters
    ==========
    config_files : list of str
        List of config files names to load and merge into the config.
    path : unicode
        The full path to the location of the config files.
    """
    config = Config()
    for cf in config_files:
        loader = PyFileConfigLoader(cf, path=path)
        try:
            next_config = loader.load_config()
        except ConfigFileNotFound:
            # missing config files are not an error; skip them
            # (the previous bare `except: raise` clause was a no-op and
            # has been removed -- any other exception still propagates)
            pass
        else:
            config.merge(next_config)
    return config
@@ -1,390 +1,376 b''
1 """Base Tornado handlers for the notebook.
1 """Base Tornado handlers for the notebook."""
2
3 Authors:
4
5 * Brian Granger
6 """
7
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
18
2
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
19
5
20 import functools
6 import functools
21 import json
7 import json
22 import logging
8 import logging
23 import os
9 import os
24 import re
10 import re
25 import sys
11 import sys
26 import traceback
12 import traceback
27 try:
13 try:
28 # py3
14 # py3
29 from http.client import responses
15 from http.client import responses
30 except ImportError:
16 except ImportError:
31 from httplib import responses
17 from httplib import responses
32
18
33 from jinja2 import TemplateNotFound
19 from jinja2 import TemplateNotFound
34 from tornado import web
20 from tornado import web
35
21
36 try:
22 try:
37 from tornado.log import app_log
23 from tornado.log import app_log
38 except ImportError:
24 except ImportError:
39 app_log = logging.getLogger()
25 app_log = logging.getLogger()
40
26
41 from IPython.config import Application
27 from IPython.config import Application
42 from IPython.utils.path import filefind
28 from IPython.utils.path import filefind
43 from IPython.utils.py3compat import string_types
29 from IPython.utils.py3compat import string_types
44 from IPython.html.utils import is_hidden
30 from IPython.html.utils import is_hidden
45
31
#-----------------------------------------------------------------------------
# Top-level handlers
#-----------------------------------------------------------------------------

# matches any single character that is not an ASCII letter or digit
non_alphanum = re.compile(r'[^A-Za-z0-9]')
50
36
class AuthenticatedHandler(web.RequestHandler):
    """A RequestHandler with an authenticated user."""

    def set_default_headers(self):
        """Apply any extra headers configured in settings['headers']."""
        for name, value in self.settings.get('headers', {}).items():
            try:
                self.set_header(name, value)
            except Exception:
                # tornado raises a plain Exception (not a subclass) when the
                # header is unsupported for this request type (websocket and
                # Access-Control-Allow-Origin for example), so just ignore it
                pass

    def clear_login_cookie(self):
        """Remove the login cookie for this handler's cookie name."""
        self.clear_cookie(self.cookie_name)

    def get_current_user(self):
        """Resolve the current user from the secure login cookie."""
        user_id = self.get_secure_cookie(self.cookie_name)
        # For now the user_id should not return empty, but it could eventually
        if user_id == '':
            user_id = 'anonymous'
        if user_id is None:
            # prevent extra Invalid cookie sig warnings:
            self.clear_login_cookie()
            if not self.login_available:
                user_id = 'anonymous'
        return user_id

    @property
    def cookie_name(self):
        """Cookie name: configured value, else derived from the request host."""
        default = non_alphanum.sub('-', 'username-{}'.format(self.request.host))
        return self.settings.get('cookie_name', default)

    @property
    def password(self):
        """our password"""
        return self.settings.get('password', '')

    @property
    def logged_in(self):
        """Is a user currently logged in?"""
        user = self.get_current_user()
        return (user and user != 'anonymous')

    @property
    def login_available(self):
        """May a user proceed to log in?

        This returns True if login capability is available, irrespective of
        whether the user is already logged in or not.
        """
        return bool(self.settings.get('password', ''))
109
95
110
96
class IPythonHandler(AuthenticatedHandler):
    """IPython-specific extensions to authenticated handling

    Mostly property shortcuts to IPython-specific settings.
    """

    @property
    def config(self):
        return self.settings.get('config', None)

    @property
    def log(self):
        """use the IPython log by default, falling back on tornado's logger"""
        if Application.initialized():
            return Application.instance().log
        return app_log

    #---------------------------------------------------------------
    # URLs
    #---------------------------------------------------------------

    @property
    def mathjax_url(self):
        return self.settings.get('mathjax_url', '')

    @property
    def base_url(self):
        return self.settings.get('base_url', '/')

    #---------------------------------------------------------------
    # Manager objects
    #---------------------------------------------------------------

    @property
    def kernel_manager(self):
        return self.settings['kernel_manager']

    @property
    def notebook_manager(self):
        return self.settings['notebook_manager']

    @property
    def cluster_manager(self):
        return self.settings['cluster_manager']

    @property
    def session_manager(self):
        return self.settings['session_manager']

    @property
    def kernel_spec_manager(self):
        return self.settings['kernel_spec_manager']

    @property
    def project_dir(self):
        return self.notebook_manager.notebook_dir

    #---------------------------------------------------------------
    # template rendering
    #---------------------------------------------------------------

    def get_template(self, name):
        """Return the jinja template object for a given name"""
        return self.settings['jinja2_env'].get_template(name)

    def render_template(self, name, **ns):
        """Render the named template; template_namespace overrides **ns."""
        ns.update(self.template_namespace)
        return self.get_template(name).render(**ns)

    @property
    def template_namespace(self):
        """Default variables made available to every rendered template."""
        return dict(
            base_url=self.base_url,
            logged_in=self.logged_in,
            login_available=self.login_available,
            static_url=self.static_url,
        )

    def get_json_body(self):
        """Return the body of the request as JSON data."""
        if not self.request.body:
            return None
        # Do we need to call body.decode('utf-8') here?
        body = self.request.body.strip().decode(u'utf-8')
        try:
            model = json.loads(body)
        except Exception:
            self.log.debug("Bad JSON: %r", body)
            self.log.error("Couldn't parse JSON", exc_info=True)
            raise web.HTTPError(400, u'Invalid JSON in body of request')
        return model

    def get_error_html(self, status_code, **kwargs):
        """render custom error pages"""
        exception = kwargs.get('exception')
        message = ''
        status_message = responses.get(status_code, 'Unknown HTTP Error')
        if exception:
            # get the custom message, if defined
            try:
                message = exception.log_message % exception.args
            except Exception:
                pass

            # construct the custom reason, if defined
            reason = getattr(exception, 'reason', '')
            if reason:
                status_message = reason

        # build template namespace
        ns = dict(
            status_code=status_code,
            status_message=status_message,
            message=message,
            exception=exception,
        )

        # render the template, falling back on the generic error page
        try:
            html = self.render_template('%s.html' % status_code, **ns)
        except TemplateNotFound:
            self.log.debug("No template for %d", status_code)
            html = self.render_template('error.html', **ns)
        return html
237
223
238
224
class Template404(IPythonHandler):
    """Render our 404 template"""
    def prepare(self):
        # short-circuit every request to this handler with a 404
        raise web.HTTPError(404)
243
229
244
230
class AuthenticatedFileHandler(IPythonHandler, web.StaticFileHandler):
    """Serve static files only to authenticated users."""

    @web.authenticated
    def get(self, path):
        # notebook files are offered as JSON downloads rather than rendered
        if os.path.splitext(path)[1] == '.ipynb':
            name = os.path.basename(path)
            self.set_header('Content-Type', 'application/json')
            self.set_header('Content-Disposition','attachment; filename="%s"' % name)

        return web.StaticFileHandler.get(self, path)

    def compute_etag(self):
        # disable ETag-based caching
        return None

    def validate_absolute_path(self, root, absolute_path):
        """Validate and return the absolute path.

        Requires tornado 3.1.

        Beyond tornado's own checks, refuses to serve hidden files
        (reported as a 404).
        """
        parent = super(AuthenticatedFileHandler, self)
        abs_path = parent.validate_absolute_path(root, absolute_path)
        if is_hidden(abs_path, os.path.abspath(root)):
            self.log.info("Refusing to serve hidden file, via 404 Error")
            raise web.HTTPError(404)
        return abs_path
273
259
274
260
def json_errors(method):
    """Decorate methods with this to return GitHub style JSON errors.

    This should be used on any JSON API on any handler method that can
    raise HTTPErrors.

    The latest exception is grabbed via sys.exc_info, then:

    1. the HTTP status code is set from the HTTPError, and
    2. a JSON body with a human-readable ``message`` field is returned
       (plus a ``traceback`` field for unexpected errors).
    """
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        try:
            return method(self, *args, **kwargs)
        except web.HTTPError as e:
            # expected error: report its status and message, no traceback
            self.log.warn(e.log_message)
            self.set_status(e.status_code)
            self.finish(json.dumps(dict(message=e.log_message)))
        except Exception:
            # unexpected error: log it and include the traceback in the reply
            self.log.error("Unhandled error in API request", exc_info=True)
            self.set_status(500)
            t, value, tb = sys.exc_info()
            tb_text = ''.join(traceback.format_exception(t, value, tb))
            reply = dict(message="Unknown server error", traceback=tb_text)
            self.finish(json.dumps(reply))
    return wrapper
309
295
310
296
311
297
312 #-----------------------------------------------------------------------------
298 #-----------------------------------------------------------------------------
313 # File handler
299 # File handler
314 #-----------------------------------------------------------------------------
300 #-----------------------------------------------------------------------------
315
301
# to minimize subclass changes:
# (re-exported so modules importing from here need not import tornado.web)
HTTPError = web.HTTPError
318
304
class FileFindHandler(web.StaticFileHandler):
    """StaticFileHandler subclass that serves files from a search path."""

    # cache of path -> absolute location, so each file is searched at most once
    _static_paths = {}

    def initialize(self, path, default_filename=None):
        # accept a single root or a list of roots
        if isinstance(path, string_types):
            path = [path]

        expanded = (os.path.abspath(os.path.expanduser(p)) + os.sep
                    for p in path)
        self.root = tuple(expanded)
        self.default_filename = default_filename

    def compute_etag(self):
        # disable ETag-based caching
        return None

    @classmethod
    def get_absolute_path(cls, roots, path):
        """Locate a file to serve on our static file search path."""
        # NOTE(review): cls._lock is not defined in this class; presumably
        # inherited from tornado's StaticFileHandler -- confirm.
        with cls._lock:
            if path in cls._static_paths:
                return cls._static_paths[path]
            try:
                abspath = os.path.abspath(filefind(path, roots))
            except IOError:
                # IOError means not found
                return ''

            cls._static_paths[path] = abspath
            return abspath

    def validate_absolute_path(self, root, absolute_path):
        """Check if the file should be served (raises 404, 403, etc.)"""
        if absolute_path == '':
            raise web.HTTPError(404)

        # pick the search root that actually contains the file
        for root in self.root:
            if (absolute_path + os.sep).startswith(root):
                break

        return super(FileFindHandler, self).validate_absolute_path(root, absolute_path)
362
348
363
349
class TrailingSlashHandler(web.RequestHandler):
    """Redirect handler that strips trailing slashes from the request URI.

    This should be the first, highest priority handler.
    """

    SUPPORTED_METHODS = ['GET']

    def get(self):
        target = self.request.uri.rstrip('/')
        self.redirect(target)
374
360
375 #-----------------------------------------------------------------------------
361 #-----------------------------------------------------------------------------
376 # URL pattern fragments for re-use
362 # URL pattern fragments for re-use
377 #-----------------------------------------------------------------------------
363 #-----------------------------------------------------------------------------
378
364
# URL fragment matching the directory portion of a path (may be empty)
path_regex = r"(?P<path>(?:/.*)*)"
# URL fragment matching a notebook filename (no slashes, .ipynb suffix)
notebook_name_regex = r"(?P<name>[^/]+\.ipynb)"
# full notebook location: directory + filename
notebook_path_regex = "%s/%s" % (path_regex, notebook_name_regex)
382
368
383 #-----------------------------------------------------------------------------
369 #-----------------------------------------------------------------------------
384 # URL to handler mappings
370 # URL to handler mappings
385 #-----------------------------------------------------------------------------
371 #-----------------------------------------------------------------------------
386
372
387
373
default_handlers = [
    # strip trailing slashes first, before any other routing
    (r".*/", TrailingSlashHandler)
]
@@ -1,238 +1,217 b''
1 """The official API for working with notebooks in the current format version.
1 """The official API for working with notebooks in the current format version."""
2
3 Authors:
4
5 * Brian Granger
6 * Jonathan Frederic
7 """
8
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2008-2011 The IPython Development Team
11 #
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
15
16 #-----------------------------------------------------------------------------
17 # Imports
18 #-----------------------------------------------------------------------------
19
2
20 from __future__ import print_function
3 from __future__ import print_function
21
4
22 from xml.etree import ElementTree as ET
5 from xml.etree import ElementTree as ET
23 import re
6 import re
24
7
25 from IPython.utils.py3compat import unicode_type
8 from IPython.utils.py3compat import unicode_type
26
9
27 from IPython.nbformat.v3 import (
10 from IPython.nbformat.v3 import (
28 NotebookNode,
11 NotebookNode,
29 new_code_cell, new_text_cell, new_notebook, new_output, new_worksheet,
12 new_code_cell, new_text_cell, new_notebook, new_output, new_worksheet,
30 parse_filename, new_metadata, new_author, new_heading_cell, nbformat,
13 parse_filename, new_metadata, new_author, new_heading_cell, nbformat,
31 nbformat_minor, nbformat_schema, to_notebook_json
14 nbformat_minor, nbformat_schema, to_notebook_json
32 )
15 )
33 from IPython.nbformat import v3 as _v_latest
16 from IPython.nbformat import v3 as _v_latest
34
17
35 from .reader import reads as reader_reads
18 from .reader import reads as reader_reads
36 from .reader import versions
19 from .reader import versions
37 from .convert import convert
20 from .convert import convert
38 from .validator import validate
21 from .validator import validate
39
22
40 import logging
23 from IPython.utils.log import get_logger
41 logger = logging.getLogger('NotebookApp')
42
24
43 #-----------------------------------------------------------------------------
44 # Code
45 #-----------------------------------------------------------------------------
46
25
# aliases describing the most recent notebook format supported here
current_nbformat = nbformat
current_nbformat_minor = nbformat_minor
current_nbformat_module = _v_latest.__name__
50
29
51
30
def docstring_nbformat_mod(func):
    """Decorator for docstrings referring to classes/functions accessed through
    nbformat.current.

    Put {nbformat_mod} in the docstring in place of 'IPython.nbformat.v3'.
    """
    # Guard: __doc__ is None when running under ``python -OO`` (or when the
    # function simply has no docstring); calling .format on None would raise.
    if func.__doc__ is not None:
        func.__doc__ = func.__doc__.format(nbformat_mod=current_nbformat_module)
    return func
60
39
61
40
class NBFormatError(ValueError):
    """Raised when a notebook format or version is unsupported."""
64
43
65
44
def parse_py(s, **kwargs):
    """Parse a string into a (nbformat, nbformat_minor, string) tuple.

    The version is taken from an embedded ``# <nbformat>N.M</nbformat>``
    marker when present; otherwise the current defaults are used.
    """
    major, minor = current_nbformat, current_nbformat_minor

    pattern = r'# <nbformat>(?P<nbformat>\d+[\.\d+]*)</nbformat>'
    match = re.search(pattern,s)
    if match is not None:
        version_parts = match.group('nbformat').split('.')
        major = int(version_parts[0])
        if len(version_parts) > 1:
            minor = int(version_parts[1])

    return major, minor, s
80
59
81
60
def reads_json(nbjson, **kwargs):
    """Read a JSON notebook from a string and return the NotebookNode object.

    The notebook is converted to the current format; any validation errors
    found after conversion are logged (the notebook is still returned).
    """
    notebook = reader_reads(nbjson, **kwargs)
    converted = convert(notebook, current_nbformat)
    errors = validate(converted)
    if errors:
        get_logger().error(
            "Notebook JSON is invalid (%d errors detected during read)",
            len(errors))
    return converted
95
74
96
75
def writes_json(nb, **kwargs):
    """Serialize a NotebookNode to a JSON string.

    Validation errors are logged before writing (the JSON is still produced).
    """
    errors = validate(nb)
    if errors:
        get_logger().error(
            "Notebook JSON is invalid (%d errors detected during write)",
            len(errors))
    return versions[current_nbformat].writes_json(nb, **kwargs)
109
88
110
89
def reads_py(s, **kwargs):
    """Read a .py notebook from a string and return the NotebookNode object.

    Raises NBFormatError for versions other than 2 and 3.
    """
    nbf, nbm, s = parse_py(s, **kwargs)
    if nbf not in (2, 3):
        raise NBFormatError('Unsupported PY nbformat version: %i' % nbf)
    return versions[nbf].to_notebook_py(s, **kwargs)
119
98
120
99
def writes_py(nb, **kwargs):
    """Write a notebook to a .py string.

    nbformat 3 is the latest format that supports py, so it is always used.
    """
    return versions[3].writes_py(nb, **kwargs)
124
103
125
104
126 # High level API
105 # High level API
127
106
128
107
def reads(s, format, **kwargs):
    """Read a notebook from a string and return the NotebookNode object.

    This function properly handles notebooks of any version. The notebook
    returned will always be in the current version's format.

    Parameters
    ----------
    s : unicode
        The raw unicode string to read the notebook from.
    format : (u'json', u'ipynb', u'py')
        The format that the string is in.

    Returns
    -------
    nb : NotebookNode
        The notebook that was read.
    """
    format = unicode_type(format)
    if format in (u'json', u'ipynb'):
        return reads_json(s, **kwargs)
    if format == u'py':
        return reads_py(s, **kwargs)
    raise NBFormatError('Unsupported format: %s' % format)
154
133
155
134
def writes(nb, format, **kwargs):
    """Write a notebook to a string in a given format in the current nbformat version.

    This function always writes the notebook in the current nbformat version.

    Parameters
    ----------
    nb : NotebookNode
        The notebook to write.
    format : (u'json', u'ipynb', u'py')
        The format to write the notebook in.

    Returns
    -------
    s : unicode
        The notebook string.
    """
    format = unicode_type(format)
    if format in (u'json', u'ipynb'):
        return writes_json(nb, **kwargs)
    if format == u'py':
        return writes_py(nb, **kwargs)
    raise NBFormatError('Unsupported format: %s' % format)
180
159
181
160
def read(fp, format, **kwargs):
    """Read a notebook from a file and return the NotebookNode object.

    Thin wrapper around :func:`reads`; handles notebooks of any version
    and always returns the current version's format.

    Parameters
    ----------
    fp : file
        Any file-like object with a read method.
    format : (u'json', u'ipynb', u'py')
        The format that the string is in.

    Returns
    -------
    nb : NotebookNode
        The notebook that was read.
    """
    return reads(fp.read(), format, **kwargs)
201
180
202
181
def write(nb, fp, format, **kwargs):
    """Write a notebook to a file in a given format in the current nbformat version.

    Thin wrapper around :func:`writes`; always writes in the current
    nbformat version.

    Parameters
    ----------
    nb : NotebookNode
        The notebook to write.
    fp : file
        Any file-like object with a write method.
    format : (u'json', u'ipynb', u'py')
        The format to write the notebook in.

    Returns
    -------
    s : unicode
        The notebook string.
    """
    return fp.write(writes(nb, format, **kwargs))
223
202
def _convert_to_metadata():
    """Convert every .ipynb file in the cwd to a notebook with notebook metadata.

    Moves a top-level ``name`` field, when present, into the metadata node,
    rewriting each file in place.
    """
    import glob
    for path in glob.glob('*.ipynb'):
        print('Converting file:', path)
        with open(path,'r') as f:
            nb = read(f,u'json')
        md = new_metadata()
        if u'name' in nb:
            md.name = nb.name
            del nb[u'name']
        nb.metadata = md
        with open(path,'w') as f:
            write(nb, f, u'json')
238
217
@@ -1,859 +1,848 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
7 Authors:
8
9 * Min RK
10 """
6 """
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
17
7
18 #----------------------------------------------------------------------
8 # Copyright (c) IPython Development Team.
19 # Imports
9 # Distributed under the terms of the Modified BSD License.
20 #----------------------------------------------------------------------
21
10
22 import logging
11 import logging
23 import sys
12 import sys
24 import time
13 import time
25
14
26 from collections import deque
15 from collections import deque
27 from datetime import datetime
16 from datetime import datetime
28 from random import randint, random
17 from random import randint, random
29 from types import FunctionType
18 from types import FunctionType
30
19
31 try:
20 try:
32 import numpy
21 import numpy
33 except ImportError:
22 except ImportError:
34 numpy = None
23 numpy = None
35
24
36 import zmq
25 import zmq
37 from zmq.eventloop import ioloop, zmqstream
26 from zmq.eventloop import ioloop, zmqstream
38
27
39 # local imports
28 # local imports
40 from IPython.external.decorator import decorator
29 from IPython.external.decorator import decorator
41 from IPython.config.application import Application
30 from IPython.config.application import Application
42 from IPython.config.loader import Config
31 from IPython.config.loader import Config
43 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 from IPython.utils.py3compat import cast_bytes
33 from IPython.utils.py3compat import cast_bytes
45
34
46 from IPython.parallel import error, util
35 from IPython.parallel import error, util
47 from IPython.parallel.factory import SessionFactory
36 from IPython.parallel.factory import SessionFactory
48 from IPython.parallel.util import connect_logger, local_logger
37 from IPython.parallel.util import connect_logger, local_logger
49
38
50 from .dependency import Dependency
39 from .dependency import Dependency
51
40
@decorator
def logged(f,self,*args,**kwargs):
    """Log each call to the wrapped scheduler method at DEBUG level."""
    self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
    return f(self,*args, **kwargs)
58
47
59 #----------------------------------------------------------------------
48 #----------------------------------------------------------------------
60 # Chooser functions
49 # Chooser functions
61 #----------------------------------------------------------------------
50 #----------------------------------------------------------------------
62
51
def plainrandom(loads):
    """Plain random pick of an engine index."""
    last = len(loads) - 1
    return randint(0, last)
67
56
def lru(loads):
    """Always pick the front of the line.

    The content of `loads` is ignored; assumes LRU ordering of loads,
    with oldest first.
    """
    return 0
76
65
def twobin(loads):
    """Pick two indices at random and use the LRU (smaller) of the two.

    The content of `loads` is ignored; assumes LRU ordering of loads,
    with oldest first.
    """
    last = len(loads) - 1
    first_pick = randint(0, last)
    second_pick = randint(0, last)
    return min(first_pick, second_pick)
88
77
def weighted(loads):
    """Pick two at random using inverse load as weight.

    Return the less loaded of the two.
    """
    # weight 0 a million times more than 1:
    weights = 1./(1e-6+numpy.array(loads))
    sums = weights.cumsum()
    total = sums[-1]

    # draw two indices by inverse-load-weighted roulette selection
    picks = []
    for _ in range(2):
        threshold = random()*total
        i = 0
        while sums[i] < threshold:
            i += 1
        picks.append(i)

    idx, idy = picks
    return idy if weights[idy] > weights[idx] else idx
110
99
def leastload(loads):
    """Always choose the lowest load.

    If the lowest load occurs more than once, the first occurrence is
    used.  If loads has LRU ordering, this means the LRU of those with
    the lowest load is chosen.
    """
    lowest = min(loads)
    return loads.index(lowest)
119
108
120 #---------------------------------------------------------------------
109 #---------------------------------------------------------------------
121 # Classes
110 # Classes
122 #---------------------------------------------------------------------
111 #---------------------------------------------------------------------
123
112
124
113
# store empty default dependency:
# NOTE(review): name suggests "MET" == already-satisfied dependency; confirm
# against TaskScheduler usage.
MET = Dependency([])
127
116
128
117
class Job(object):
    """Simple container for a job"""

    def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
                targets, after, follow, timeout):
        # message identity and raw wire content
        self.msg_id = msg_id
        self.raw_msg = raw_msg
        self.idents = idents
        self.msg = msg
        self.header = header
        self.metadata = metadata
        # scheduling constraints
        self.targets = targets
        self.after = after
        self.follow = follow
        self.timeout = timeout

        # bookkeeping state
        self.removed = False # used for lazy-delete from sorted queue
        self.timestamp = time.time()
        self.timeout_id = 0
        self.blacklist = set()

    def __lt__(self, other):
        # older jobs sort first
        return self.timestamp < other.timestamp

    def __cmp__(self, other):
        # Python 2 only; __lt__ above covers Python 3 ordering
        return cmp(self.timestamp, other.timestamp)

    @property
    def dependents(self):
        # every msg_id this job waits on, in either mode
        return self.follow.union(self.after)
158
147
159
148
160 class TaskScheduler(SessionFactory):
149 class TaskScheduler(SessionFactory):
161 """Python TaskScheduler object.
150 """Python TaskScheduler object.
162
151
163 This is the simplest object that supports msg_id based
152 This is the simplest object that supports msg_id based
164 DAG dependencies. *Only* task msg_ids are checked, not
153 DAG dependencies. *Only* task msg_ids are checked, not
165 msg_ids of jobs submitted via the MUX queue.
154 msg_ids of jobs submitted via the MUX queue.
166
155
167 """
156 """
168
157
169 hwm = Integer(1, config=True,
158 hwm = Integer(1, config=True,
170 help="""specify the High Water Mark (HWM) for the downstream
159 help="""specify the High Water Mark (HWM) for the downstream
171 socket in the Task scheduler. This is the maximum number
160 socket in the Task scheduler. This is the maximum number
172 of allowed outstanding tasks on each engine.
161 of allowed outstanding tasks on each engine.
173
162
174 The default (1) means that only one task can be outstanding on each
163 The default (1) means that only one task can be outstanding on each
175 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
176 engines continue to be assigned tasks while they are working,
165 engines continue to be assigned tasks while they are working,
177 effectively hiding network latency behind computation, but can result
166 effectively hiding network latency behind computation, but can result
178 in an imbalance of work when submitting many heterogenous tasks all at
167 in an imbalance of work when submitting many heterogenous tasks all at
179 once. Any positive value greater than one is a compromise between the
168 once. Any positive value greater than one is a compromise between the
180 two.
169 two.
181
170
182 """
171 """
183 )
172 )
184 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
185 'leastload', config=True, allow_none=False,
174 'leastload', config=True, allow_none=False,
186 help="""select the task scheduler scheme [default: Python LRU]
175 help="""select the task scheduler scheme [default: Python LRU]
187 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
188 )
177 )
def _scheme_name_changed(self, old, new):
    # Look up the scheduling function matching the newly selected name
    # and swap it in.
    self.log.debug("Using scheme %r" % new)
    self.scheme = globals()[new]
192
181
193 # input arguments:
182 # input arguments:
194 scheme = Instance(FunctionType) # function for determining the destination
183 scheme = Instance(FunctionType) # function for determining the destination
def _scheme_default(self):
    """Default scheduling scheme: assign to the least-loaded engine."""
    return leastload
197 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
198 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
199 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
200 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
201 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
202
191
203 # internals:
192 # internals:
204 queue = Instance(deque) # sorted list of Jobs
193 queue = Instance(deque) # sorted list of Jobs
def _queue_default(self):
    # A deque keeps FIFO append/pop at both ends O(1).
    return deque()
207 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
208 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
209 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
210 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
211 pending = Dict() # dict by engine_uuid of submitted tasks
200 pending = Dict() # dict by engine_uuid of submitted tasks
212 completed = Dict() # dict by engine_uuid of completed tasks
201 completed = Dict() # dict by engine_uuid of completed tasks
213 failed = Dict() # dict by engine_uuid of failed tasks
202 failed = Dict() # dict by engine_uuid of failed tasks
214 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
215 clients = Dict() # dict by msg_id for who submitted the task
204 clients = Dict() # dict by msg_id for who submitted the task
216 targets = List() # list of target IDENTs
205 targets = List() # list of target IDENTs
217 loads = List() # list of engine loads
206 loads = List() # list of engine loads
218 # full = Set() # set of IDENTs that have HWM outstanding tasks
207 # full = Set() # set of IDENTs that have HWM outstanding tasks
219 all_completed = Set() # set of all completed tasks
208 all_completed = Set() # set of all completed tasks
220 all_failed = Set() # set of all failed tasks
209 all_failed = Set() # set of all failed tasks
221 all_done = Set() # set of all finished tasks=union(completed,failed)
210 all_done = Set() # set of all finished tasks=union(completed,failed)
222 all_ids = Set() # set of all submitted task IDs
211 all_ids = Set() # set of all submitted task IDs
223
212
224 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
225 # but ensure Bytes
214 # but ensure Bytes
def _ident_default(self):
    # Default ZMQ identity: the session's bytes identity.
    return self.session.bsession
228
217
def start(self):
    """Wire up all streams and announce ourselves to the Hub."""
    self.query_stream.on_recv(self.dispatch_query_reply)
    self.session.send(self.query_stream, "connection_request", {})

    self.engine_stream.on_recv(self.dispatch_result, copy=False)
    self.client_stream.on_recv(self.dispatch_submission, copy=False)

    # map notification msg_type -> handler
    self._notification_handlers = {
        'registration_notification': self._register_engine,
        'unregistration_notification': self._unregister_engine,
    }
    self.notifier_stream.on_recv(self.dispatch_notification)
    self.log.info("Scheduler started [%s]" % self.scheme_name)
242
231
def resume_receiving(self):
    """Start accepting job submissions again."""
    self.client_stream.on_recv(self.dispatch_submission, copy=False)
246
235
def stop_receiving(self):
    """Stop accepting jobs (e.g. while there are no engines).

    Unconsumed submissions simply wait in the ZMQ queue.
    """
    self.client_stream.on_recv(None)
251
240
252 #-----------------------------------------------------------------------
241 #-----------------------------------------------------------------------
253 # [Un]Registration Handling
242 # [Un]Registration Handling
254 #-----------------------------------------------------------------------
243 #-----------------------------------------------------------------------
255
244
256
245
def dispatch_query_reply(self, msg):
    """Handle the reply to our initial connection request.

    Registers every engine the Hub already knows about.
    """
    try:
        routing, msg = self.session.feed_identities(msg)
    except ValueError:
        self.log.warn("task::Invalid Message: %r", msg)
        return
    try:
        msg = self.session.unserialize(msg)
    except ValueError:
        self.log.warn("task::Unauthorized message from: %r" % routing)
        return

    engines = msg['content'].get('engines', {})
    for uuid in engines.values():
        self._register_engine(cast_bytes(uuid))
273
262
274
263
@util.log_errors
def dispatch_notification(self, msg):
    """Dispatch engine register/unregister events to their handlers."""
    try:
        routing, msg = self.session.feed_identities(msg)
    except ValueError:
        self.log.warn("task::Invalid Message: %r", msg)
        return
    try:
        msg = self.session.unserialize(msg)
    except ValueError:
        self.log.warn("task::Unauthorized message from: %r" % routing)
        return

    msg_type = msg['header']['msg_type']
    handler = self._notification_handlers.get(msg_type, None)
    if handler is None:
        self.log.error("Unhandled message type: %r" % msg_type)
        return
    try:
        handler(cast_bytes(msg['content']['uuid']))
    except Exception:
        self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299
288
def _register_engine(self, uid):
    """A new engine with ident `uid` became available."""
    # newcomers go to the head of the line
    self.targets.insert(0, uid)
    self.loads.insert(0, 0)

    # fresh bookkeeping for this engine
    self.completed[uid] = set()
    self.failed[uid] = set()
    self.pending[uid] = {}

    # rescan the dependency graph: queued jobs may now be runnable
    self.update_graph(None)
313
302
def _unregister_engine(self, uid):
    """An existing engine with ident `uid` became unavailable.

    Removes the engine from the scheduling pool; pending jobs on it are
    cleaned up after a grace period (results may still be in flight).
    """
    # NOTE: the original contained a no-op `if len(self.targets) == 1: pass`
    # placeholder branch; removed as dead code.

    # flush the engine stream to handle any results already in flight
    self.engine_stream.flush()

    # don't pop destinations, because they might be used later
    # map(self.destinations.pop, self.completed.pop(uid))
    # map(self.destinations.pop, self.failed.pop(uid))

    # prevent this engine from receiving further work
    idx = self.targets.index(uid)
    self.targets.pop(idx)
    self.loads.pop(idx)

    # wait 5 seconds before cleaning up pending jobs, since the results
    # might still be incoming
    if self.pending[uid]:
        dc = ioloop.DelayedCallback(lambda: self.handle_stranded_tasks(uid), 5000, self.loop)
        dc.start()
    else:
        self.completed.pop(uid)
        self.failed.pop(uid)
340
329
341
330
def handle_stranded_tasks(self, engine):
    """Deal with jobs resident in an engine that died.

    For each task still pending on the dead engine, fabricates an
    `apply_reply` carrying an EngineError and feeds it back through
    `dispatch_result`, then scrubs the engine's bookkeeping.
    """
    lost = self.pending[engine]
    # Iterate over a snapshot of the keys: dispatch_result -> handle_result
    # pops entries from self.pending[engine] (this same dict), and mutating
    # a dict during iteration raises RuntimeError on Python 3.
    for msg_id in list(lost.keys()):
        if msg_id not in self.pending[engine]:
            # prevent double-handling of messages
            continue

        raw_msg = lost[msg_id].raw_msg
        idents, msg = self.session.feed_identities(raw_msg, copy=False)
        parent = self.session.unpack(msg[1].bytes)
        idents = [engine, idents[0]]

        # build fake error reply
        try:
            raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
        except:
            content = error.wrap_exception()
        # build fake metadata
        md = dict(
            status=u'error',
            engine=engine.decode('ascii'),
            date=datetime.now(),
        )
        msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
        raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
        # and dispatch it
        self.dispatch_result(raw_reply)

    # finally scrub completed/failed lists
    self.completed.pop(engine)
    self.failed.pop(engine)
374
363
375
364
376 #-----------------------------------------------------------------------
365 #-----------------------------------------------------------------------
377 # Job Submission
366 # Job Submission
378 #-----------------------------------------------------------------------
367 #-----------------------------------------------------------------------
379
368
380
369
@util.log_errors
def dispatch_submission(self, raw_msg):
    """Dispatch a job submission: run it now if dependencies are met,
    otherwise queue it (or fail it as unreachable)."""
    # ensure targets are up to date:
    self.notifier_stream.flush()
    try:
        idents, msg = self.session.feed_identities(raw_msg, copy=False)
        msg = self.session.unserialize(msg, content=False, copy=False)
    except Exception:
        # fix: original message read "Invaid"
        self.log.error("task::Invalid task msg: %r" % raw_msg, exc_info=True)
        return

    # send to monitor
    self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)

    header = msg['header']
    md = msg['metadata']
    msg_id = header['msg_id']
    self.all_ids.add(msg_id)

    # get targets as a set of bytes objects
    # from a list of unicode objects
    targets = md.get('targets', [])
    targets = set(map(cast_bytes, targets))

    retries = md.get('retries', 0)
    self.retries[msg_id] = retries

    # time dependencies
    after = md.get('after', None)
    if after:
        after = Dependency(after)
        if after.all:
            # drop already-finished ids so later checks are cheaper
            if after.success:
                after = Dependency(after.difference(self.all_completed),
                            success=after.success,
                            failure=after.failure,
                            all=after.all,
                )
            if after.failure:
                after = Dependency(after.difference(self.all_failed),
                            success=after.success,
                            failure=after.failure,
                            all=after.all,
                )
        if after.check(self.all_completed, self.all_failed):
            # recast as empty set, if `after` already met,
            # to prevent unnecessary set comparisons
            after = MET
    else:
        after = MET

    # location dependencies
    follow = Dependency(md.get('follow', []))

    timeout = md.get('timeout', None)
    if timeout:
        timeout = float(timeout)

    job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
            header=header, targets=targets, after=after, follow=follow,
            timeout=timeout, metadata=md,
    )
    # validate and reduce dependencies:
    for dep in after, follow:
        if not dep:  # empty dependency
            continue
        # check valid:
        if msg_id in dep or dep.difference(self.all_ids):
            self.queue_map[msg_id] = job
            return self.fail_unreachable(msg_id, error.InvalidDependency)
        # check if unreachable:
        if dep.unreachable(self.all_completed, self.all_failed):
            self.queue_map[msg_id] = job
            return self.fail_unreachable(msg_id)

    if after.check(self.all_completed, self.all_failed):
        # time deps already met, try to run
        if not self.maybe_run(job):
            # can't run yet
            if msg_id not in self.all_failed:
                # could have failed as unreachable
                self.save_unmet(job)
    else:
        self.save_unmet(job)
467
456
def job_timeout(self, job, timeout_id):
    """Callback fired when a job's timeout expires.

    The job may or may not have been run by the time this fires.
    """
    if job.timeout_id != timeout_id:
        # a newer timeout was scheduled for this job; ignore the stale one
        return
    now = time.time()
    if job.timeout >= (now + 1):
        self.log.warn("task %s timeout fired prematurely: %s > %s",
            job.msg_id, job.timeout, now
        )
    if job.msg_id in self.queue_map:
        # still waiting in the queue, but its time is up
        self.log.info("task %r timed out", job.msg_id)
        self.fail_unreachable(job.msg_id, error.TaskTimeout)
485
474
def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
    """A task has become unreachable; reply to the client with `why`
    (an ImpossibleDependency by default)."""
    if msg_id not in self.queue_map:
        self.log.error("task %r already failed!", msg_id)
        return
    job = self.queue_map.pop(msg_id)
    # mark for lazy deletion from the queue deque
    job.removed = True
    # drop this task from its dependents' edges
    for dep_id in job.dependents:
        if dep_id in self.graph:
            self.graph[dep_id].remove(msg_id)

    # raise/catch so wrap_exception captures a real traceback
    try:
        raise why()
    except:
        content = error.wrap_exception()
    self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])

    self.all_done.add(msg_id)
    self.all_failed.add(msg_id)

    # send the error reply to the client, and a copy to the Hub monitor
    msg = self.session.send(self.client_stream, 'apply_reply', content,
                                                parent=job.header, ident=job.idents)
    self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)

    # propagate the failure through the dependency graph
    self.update_graph(msg_id, success=False)
513
502
def available_engines(self):
    """Return the indices of engines with capacity, per the HWM."""
    if not self.hwm:
        # no high-water mark: every engine is available
        return list(range(len(self.targets)))
    return [idx for idx in range(len(self.targets))
            if self.loads[idx] < self.hwm]
523
512
def maybe_run(self, job):
    """Check location dependencies and submit the job if they are met.

    Returns True if the job was submitted, False otherwise (the job may
    also have been failed as unreachable).
    """
    msg_id = job.msg_id
    self.log.debug("Attempting to assign task %s", msg_id)
    available = self.available_engines()
    if not available:
        # no engines with capacity: definitely can't run now
        return False

    if job.follow or job.targets or job.blacklist or self.hwm:
        # constrained job: filter engines by eligibility
        def can_run(idx):
            # engine at capacity?
            if self.hwm and self.loads[idx] == self.hwm:
                return False
            target = self.targets[idx]
            # engine already failed this job?
            if target in job.blacklist:
                return False
            # restricted to an explicit target set?
            if job.targets and target not in job.targets:
                return False
            # location dependencies satisfied on this engine?
            return job.follow.check(self.completed[target], self.failed[target])

        usable = [idx for idx in available if can_run(idx)]

        if not usable:
            # nowhere to run right now; decide whether it ever can run
            if job.follow.all:
                # follow=all(...): all relevant deps must have run on ONE engine
                ran_on = set()
                relevant = set()
                if job.follow.success:
                    relevant = self.all_completed
                if job.follow.failure:
                    relevant = relevant.union(self.all_failed)
                for dep_id in job.follow.intersection(relevant):
                    ran_on.add(self.destinations[dep_id])
                if len(ran_on) > 1:
                    # deps finished on different engines: impossible
                    self.queue_map[msg_id] = job
                    self.fail_unreachable(msg_id)
                    return False
            if job.targets:
                # check blacklist+targets for impossibility
                job.targets.difference_update(job.blacklist)
                if not job.targets or not job.targets.intersection(self.targets):
                    self.queue_map[msg_id] = job
                    self.fail_unreachable(msg_id)
                    return False
            return False
    else:
        # unconstrained job: any engine will do
        usable = None

    self.submit_task(job, usable)
    return True
580
569
def save_unmet(self, job):
    """Queue a job for later submission, once its dependencies are met."""
    msg_id = job.msg_id
    self.log.debug("Adding task %s to the queue", msg_id)
    self.queue_map[msg_id] = job
    self.queue.append(job)

    # register this job under each dependency that hasn't finished yet
    for dep_id in job.after.union(job.follow).difference(self.all_done):
        self.graph.setdefault(dep_id, set()).add(msg_id)

    # schedule the timeout callback, tagged so stale callbacks are ignored
    if job.timeout:
        job.timeout_id += 1
        timeout_id = job.timeout_id
        self.loop.add_timeout(time.time() + job.timeout,
            lambda: self.job_timeout(job, timeout_id)
        )
599
588
600
589
def submit_task(self, job, indices=None):
    """Submit a task to one engine, chosen by the scheme from `indices`
    (or from all engines when `indices` is empty/None)."""
    if indices:
        # restrict the scheme to the given subset of engines
        subset_loads = [self.loads[i] for i in indices]
        pick = self.scheme(subset_loads)
        idx = indices[pick]
    else:
        idx = self.scheme(self.loads)
    target = self.targets[idx]
    # hand the task to the engine
    self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
    self.engine_stream.send_multipart(job.raw_msg, copy=False)
    # account for the new load and remember where the job went
    self.add_job(idx)
    self.pending[target][job.msg_id] = job
    # let the Hub know where the task was sent
    content = {'msg_id': job.msg_id, 'engine_id': target.decode('ascii')}
    self.session.send(self.mon_stream, 'task_destination', content=content,
                    ident=[b'tracktask', self.ident])
622
611
623
612
624 #-----------------------------------------------------------------------
613 #-----------------------------------------------------------------------
625 # Result Handling
614 # Result Handling
626 #-----------------------------------------------------------------------
615 #-----------------------------------------------------------------------
627
616
628
617
@util.log_errors
def dispatch_result(self, raw_msg):
    """Dispatch a result reply: relay it to the client, retry it, or
    hand it to the unmet-dependency path."""
    try:
        routing, msg = self.session.feed_identities(raw_msg, copy=False)
        msg = self.session.unserialize(msg, content=False, copy=False)
        engine = routing[0]
        try:
            idx = self.targets.index(engine)
        except ValueError:
            pass  # skip load-update for dead engines
        else:
            self.finish_job(idx)
    except Exception:
        self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
        return

    md = msg['metadata']
    parent = msg['parent_header']
    if not md.get('dependencies_met', True):
        self.handle_unmet_dependency(routing, parent)
        return

    success = (md['status'] == 'ok')
    msg_id = parent['msg_id']
    retries = self.retries[msg_id]
    if not success and retries > 0:
        # failed, but retries remain: decrement and resubmit
        self.retries[msg_id] = retries - 1
        self.handle_unmet_dependency(routing, parent)
    else:
        del self.retries[msg_id]
        # relay to client and update graph
        self.handle_result(routing, parent, raw_msg, success)
        # send to Hub monitor
        self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
664
653
def handle_result(self, idents, parent, raw_msg, success=True):
    """handle a real task result, either success or failure

    Relays the raw reply to the originating client, records the task
    as completed/failed for its engine, and triggers a dependency
    graph update for the finished msg_id.
    """
    # first, relay result to client
    engine = idents[0]
    client = idents[1]
    # swap_ids for ROUTER-ROUTER mirror
    raw_msg[:2] = [client,engine]
    # print (map(str, raw_msg[:4]))
    self.client_stream.send_multipart(raw_msg, copy=False)
    # now, update our data structures
    msg_id = parent['msg_id']
    self.pending[engine].pop(msg_id)
    if success:
        self.completed[engine].add(msg_id)
        self.all_completed.add(msg_id)
    else:
        self.failed[engine].add(msg_id)
        self.all_failed.add(msg_id)
    self.all_done.add(msg_id)
    # remember which engine ran this task (used by `follow` dependencies)
    # NOTE(review): purpose of `destinations` inferred from name — confirm
    self.destinations[msg_id] = engine

    self.update_graph(msg_id, success)
def handle_unmet_dependency(self, idents, parent):
    """handle an unmet dependency

    The engine could not run the task: blacklist it for this job and
    either fail the job (no engines left), resubmit it, or park it
    back in the dependency tree.
    """
    engine = idents[0]
    msg_id = parent['msg_id']

    job = self.pending[engine].pop(msg_id)
    job.blacklist.add(engine)

    if job.blacklist == job.targets:
        # every allowed engine has rejected this job: it is unreachable
        self.queue_map[msg_id] = job
        self.fail_unreachable(msg_id)
    elif not self.maybe_run(job):
        # resubmit failed
        if msg_id not in self.all_failed:
            # put it back in our dependency tree
            self.save_unmet(job)

    if self.hwm:
        try:
            idx = self.targets.index(engine)
        except ValueError:
            pass # skip load-update for dead engines
        else:
            if self.loads[idx] == self.hwm-1:
                # engine just dropped below the high-water mark:
                # recheck the whole queue for newly runnable jobs
                self.update_graph(None)
def update_graph(self, dep_id=None, success=True):
    """dep_id just finished. Update our dependency
    graph and submit any jobs that just became runnable.

    Called with dep_id=None to update entire graph for hwm, but without finishing a task.
    """
    # print ("\n\n***********")
    # pprint (dep_id)
    # pprint (self.graph)
    # pprint (self.queue_map)
    # pprint (self.all_completed)
    # pprint (self.all_failed)
    # print ("\n\n***********\n\n")
    # update any jobs that depended on the dependency
    msg_ids = self.graph.pop(dep_id, [])

    # recheck *all* jobs if
    # a) we have HWM and an engine just become no longer full
    # or b) dep_id was given as None

    if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
        # walk the whole queue (jobs aliases self.queue; popped entries
        # that don't run are restored below)
        jobs = self.queue
        using_queue = True
    else:
        using_queue = False
        # only the jobs that were waiting on dep_id
        jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))

    to_restore = []
    while jobs:
        job = jobs.popleft()
        if job.removed:
            continue
        msg_id = job.msg_id

        put_it_back = True

        if job.after.unreachable(self.all_completed, self.all_failed)\
                or job.follow.unreachable(self.all_completed, self.all_failed):
            self.fail_unreachable(msg_id)
            put_it_back = False

        elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
            if self.maybe_run(job):
                put_it_back = False
                self.queue_map.pop(msg_id)
                # this job no longer blocks its dependents
                for mid in job.dependents:
                    if mid in self.graph:
                        self.graph[mid].remove(msg_id)

                # abort the loop if we just filled up all of our engines.
                # avoids an O(N) operation in situation of full queue,
                # where graph update is triggered as soon as an engine becomes
                # non-full, and all tasks after the first are checked,
                # even though they can't run.
                if not self.available_engines():
                    break

        if using_queue and put_it_back:
            # popped a job from the queue but it neither ran nor failed,
            # so we need to put it back when we are done
            # make sure to_restore preserves the same ordering
            to_restore.append(job)

    # put back any tasks we popped but didn't run
    if using_queue:
        self.queue.extendleft(to_restore)
781 #----------------------------------------------------------------------
770 #----------------------------------------------------------------------
782 # methods to be overridden by subclasses
771 # methods to be overridden by subclasses
783 #----------------------------------------------------------------------
772 #----------------------------------------------------------------------
784
773
def add_job(self, idx):
    """Record that self.targets[idx] was just handed a job.

    Override in subclasses for other policies.  The default ordering
    is simple LRU; the default load metric is the number of
    outstanding jobs.
    """
    self.loads[idx] += 1
    # rotate the chosen engine (and its load entry) to the back of
    # both parallel lists, so least-recently-used stays at the front
    self.targets.append(self.targets.pop(idx))
    self.loads.append(self.loads.pop(idx))
793
782
def finish_job(self, idx):
    """Record that self.targets[idx] just finished a job.

    Override in subclasses; the default load metric is the count of
    outstanding jobs, so finishing one simply decrements it.
    """
    self.loads[idx] = self.loads[idx] - 1
789
def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
                    logname='root', log_url=None, loglevel=logging.DEBUG,
                    identity=b'task', in_thread=False):
    """Construct and start a TaskScheduler wired to the given zmq addresses.

    Binds the client/engine ROUTER streams, connects the monitor PUB,
    notification SUB, and registration DEALER sockets, sets up logging,
    and (unless in_thread) runs the IOLoop until interrupted.
    """
    ZMQStream = zmqstream.ZMQStream

    if config:
        # unwrap dict back into Config
        config = Config(config)

    if in_thread:
        # use instance() to get the same Context/Loop as our parent
        ctx = zmq.Context.instance()
        loop = ioloop.IOLoop.instance()
    else:
        # in a process, don't use instance()
        # for safety with multiprocessing
        ctx = zmq.Context()
        loop = ioloop.IOLoop()
    ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
    util.set_hwm(ins, 0)
    ins.setsockopt(zmq.IDENTITY, identity + b'_in')
    ins.bind(in_addr)

    outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
    util.set_hwm(outs, 0)
    outs.setsockopt(zmq.IDENTITY, identity + b'_out')
    outs.bind(out_addr)
    mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
    util.set_hwm(mons, 0)
    mons.connect(mon_addr)
    nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
    # subscribe to everything on the notification channel
    nots.setsockopt(zmq.SUBSCRIBE, b'')
    nots.connect(not_addr)

    querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
    querys.connect(reg_addr)

    # setup logging.
    if in_thread:
        log = Application.instance().log
    else:
        if log_url:
            log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
        else:
            log = local_logger(logname, loglevel)

    scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
                            mon_stream=mons, notifier_stream=nots,
                            query_stream=querys,
                            loop=loop, log=log,
                            config=config)
    scheduler.start()
    if not in_thread:
        try:
            loop.start()
        except KeyboardInterrupt:
            scheduler.log.critical("Interrupted, exiting...")
@@ -1,388 +1,389 b''
1 """Some generic utilities for dealing with classes, urls, and serialization."""
1 """Some generic utilities for dealing with classes, urls, and serialization."""
2
2
3 # Copyright (c) IPython Development Team.
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
4 # Distributed under the terms of the Modified BSD License.
5
5
6 import logging
6 import logging
7 import os
7 import os
8 import re
8 import re
9 import stat
9 import stat
10 import socket
10 import socket
11 import sys
11 import sys
12 import warnings
12 import warnings
13 from signal import signal, SIGINT, SIGABRT, SIGTERM
13 from signal import signal, SIGINT, SIGABRT, SIGTERM
14 try:
14 try:
15 from signal import SIGKILL
15 from signal import SIGKILL
16 except ImportError:
16 except ImportError:
17 SIGKILL=None
17 SIGKILL=None
18 from types import FunctionType
18 from types import FunctionType
19
19
20 try:
20 try:
21 import cPickle
21 import cPickle
22 pickle = cPickle
22 pickle = cPickle
23 except:
23 except:
24 cPickle = None
24 cPickle = None
25 import pickle
25 import pickle
26
26
27 import zmq
27 import zmq
28 from zmq.log import handlers
28 from zmq.log import handlers
29
29
30 from IPython.utils.log import get_logger
30 from IPython.external.decorator import decorator
31 from IPython.external.decorator import decorator
31
32
32 from IPython.config.application import Application
33 from IPython.config.application import Application
33 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
34 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
34 from IPython.utils.py3compat import string_types, iteritems, itervalues
35 from IPython.utils.py3compat import string_types, iteritems, itervalues
35 from IPython.kernel.zmq.log import EnginePUBHandler
36 from IPython.kernel.zmq.log import EnginePUBHandler
36
37
37
38
38 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
39 # Classes
40 # Classes
40 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
41
42
class Namespace(dict):
    """Subclass of dict for attribute access to keys."""

    def __getattr__(self, key):
        """Attribute read falls back to item lookup."""
        try:
            return self[key]
        except KeyError:
            raise NameError(key)

    def __setattr__(self, key, value):
        """Attribute write stores an item, refusing names dict already owns."""
        if hasattr(dict, key):
            raise KeyError("Cannot override dict keys %r"%key)
        self[key] = value
57
58
58
59
class ReverseDict(dict):
    """simple double-keyed subset of dict methods.

    Items can be looked up by key or by value; a value -> key mirror
    is maintained alongside the forward mapping.
    """

    def __init__(self, *args, **kwargs):
        dict.__init__(self, *args, **kwargs)
        # build the value -> key mirror for any initial contents
        self._reverse = dict((value, key) for key, value in self.items())

    def __getitem__(self, key):
        # forward lookup first, then fall back to the reverse side
        try:
            return dict.__getitem__(self, key)
        except KeyError:
            return self._reverse[key]

    def __setitem__(self, key, value):
        if key in self._reverse:
            raise KeyError("Can't have key %r on both sides!"%key)
        dict.__setitem__(self, key, value)
        self._reverse[value] = key

    def pop(self, key):
        # remove from both sides, return the forward value
        value = dict.pop(self, key)
        self._reverse.pop(value)
        return value

    def get(self, key, default=None):
        try:
            return self[key]
        except KeyError:
            return default
90
91
91 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
92 # Functions
93 # Functions
93 #-----------------------------------------------------------------------------
94 #-----------------------------------------------------------------------------
94
95
@decorator
def log_errors(f, self, *args, **kwargs):
    """decorator to log unhandled exceptions raised in a method.

    For use wrapping on_recv callbacks, so that exceptions
    do not cause the stream to be closed.
    """
    try:
        return f(self, *args, **kwargs)
    except Exception:
        # swallow deliberately after logging: the stream must stay alive
        self.log.error("Uncaught exception in %r" % f, exc_info=True)
106
107
107
108
def is_url(url):
    """boolean check for whether a string is a zmq url"""
    if '://' not in url:
        return False
    # only the scheme matters; the address part is not validated here
    proto = url.split('://', 1)[0]
    return proto.lower() in ('tcp', 'pgm', 'epgm', 'ipc', 'inproc')
116
117
def validate_url(url):
    """validate a url for zeromq

    Returns True for a well-formed url; raises TypeError for
    non-strings and AssertionError for malformed urls.  Only tcp
    urls are fully validated; other supported transports are
    accepted as-is.
    """
    if not isinstance(url, string_types):
        raise TypeError("url must be a string, not %r"%type(url))
    url = url.lower()

    proto_addr = url.split('://')
    assert len(proto_addr) == 2, 'Invalid url: %r'%url
    proto, addr = proto_addr
    assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto

    # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
    # author: Remi Sabourin
    pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')

    if proto == 'tcp':
        lis = addr.split(':')
        assert len(lis) == 2, 'Invalid url: %r'%url
        addr,s_port = lis
        try:
            # parse only to validate; the numeric value is unused
            int(s_port)
        except ValueError:
            # BUG FIX: previously formatted with the unbound name `port`,
            # which raised NameError instead of the intended AssertionError
            raise AssertionError("Invalid port %r in url: %r"%(s_port, url))

        assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url

    else:
        # only validate tcp urls currently
        pass

    return True
148
149
149
150
def validate_url_container(container):
    """validate a potentially nested collection of urls.

    A bare string is validated directly; dicts are validated by
    value; any other iterable is recursed into element by element.
    """
    if isinstance(container, string_types):
        return validate_url(container)
    if isinstance(container, dict):
        container = itervalues(container)
    for element in container:
        validate_url_container(element)
160
161
161
162
def split_url(url):
    """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
    pieces = url.split('://')
    assert len(pieces) == 2, 'Invalid url: %r'%url
    proto, rest = pieces
    host_port = rest.split(':')
    assert len(host_port) == 2, 'Invalid url: %r'%url
    host, port = host_port
    return proto, host, port
171
172
172
173
def disambiguate_ip_address(ip, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address

    Explicit IP addresses are returned unmodified.

    Parameters
    ----------

    ip : IP address
        An IP address, or the special values 0.0.0.0, or *
    location: IP address, optional
        A public IP of the target machine.
        If location is an IP of the current machine,
        localhost will be returned,
        otherwise location will be returned.
    """
    if ip not in {'0.0.0.0', '*'}:
        # explicit address: nothing to disambiguate
        return ip
    if not location:
        # unspecified location, localhost is the only choice
        return localhost()
    if is_public_ip(location):
        # location is a public IP on this machine, use localhost
        return localhost()
    if not public_ips():
        # this machine's public IPs cannot be determined,
        # assume `location` is not this machine
        warnings.warn("IPython could not determine public IPs", RuntimeWarning)
        return location
    # location is not this machine, do not use loopback
    return location
205
206
206
207
def disambiguate_url(url, location=None):
    """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
    ones, based on the location (default interpretation is localhost).

    This is for zeromq urls, such as ``tcp://*:10101``.
    """
    try:
        proto, ip, port = split_url(url)
    except AssertionError:
        # probably not tcp url; could be ipc, etc.
        return url
    return "%s://%s:%s" % (proto, disambiguate_ip_address(ip, location), port)
222
223
223
224
224 #--------------------------------------------------------------------------
225 #--------------------------------------------------------------------------
225 # helpers for implementing old MEC API via view.apply
226 # helpers for implementing old MEC API via view.apply
226 #--------------------------------------------------------------------------
227 #--------------------------------------------------------------------------
227
228
def interactive(f):
    """decorator for making functions appear as interactively defined.
    This results in the function being linked to the user_ns as globals()
    instead of the module globals().
    """
    # build new FunctionType, so it can have the right globals
    # interactive functions never have closures, that's kind of the point
    if isinstance(f, FunctionType):
        main_dict = __import__('__main__').__dict__
        f = FunctionType(f.__code__, main_dict,
                         f.__name__, f.__defaults__,
                         )
        # associate with __main__ for uncanning
        f.__module__ = '__main__'
    return f
244
245
@interactive
def _push(**ns):
    """helper method for implementing `client.push` via `client.apply`"""
    user_ns = globals()
    # pick a temporary name guaranteed absent from the user namespace
    tmp = '_IP_PUSH_TMP_'
    while tmp in user_ns:
        tmp += '_'
    try:
        for name, value in ns.items():
            # stash the value, then bind it to its real name via exec so
            # arbitrary identifiers work as assignment targets
            user_ns[tmp] = value
            exec("%s = %s" % (name, tmp), user_ns)
    finally:
        user_ns.pop(tmp, None)
258
259
@interactive
def _pull(keys):
    """helper method for implementing `client.pull` via `client.apply`"""
    if isinstance(keys, (list, tuple, set)):
        # evaluate each requested name in the interactive namespace
        return [eval(key, globals()) for key in keys]
    return eval(keys, globals())
266
267
@interactive
def _execute(code):
    """helper method for implementing `client.execute` via `client.apply`"""
    # run the code string in the interactive (__main__) namespace
    exec(code, globals())
271
272
272 #--------------------------------------------------------------------------
273 #--------------------------------------------------------------------------
273 # extra process management utilities
274 # extra process management utilities
274 #--------------------------------------------------------------------------
275 #--------------------------------------------------------------------------
275
276
# ports previously handed out by this process, never reused
_random_ports = set()

def select_random_ports(n):
    """Selects and return n random ports that are available."""
    # keep every socket open while choosing, so the OS cannot hand the
    # same ephemeral port out twice within a single call
    socks = []
    for _ in range(n):
        sock = socket.socket()
        sock.bind(('', 0))
        # re-roll any port this process has already given out
        while sock.getsockname()[1] in _random_ports:
            sock.close()
            sock = socket.socket()
            sock.bind(('', 0))
        socks.append(sock)
    ports = []
    for sock in socks:
        port = sock.getsockname()[1]
        sock.close()
        ports.append(port)
        _random_ports.add(port)
    return ports
295
296
def signal_children(children):
    """Relay interupt/term signals to children, for more solid process cleanup.

    Installs a handler for SIGINT/SIGABRT/SIGTERM that terminates every
    process in `children` and then exits this process.
    """
    def terminate_children(sig, frame):
        log = get_logger()
        log.critical("Got signal %i, terminating children..."%sig)
        for child in children:
            child.terminate()

        # exit status 0 only for a clean Ctrl-C (SIGINT)
        sys.exit(sig != SIGINT)
        # sys.exit(sig)
    for sig in (SIGINT, SIGABRT, SIGTERM):
        signal(sig, terminate_children)
308
309
def generate_exec_key(keyfile):
    """Generate a fresh random execution key and write it to `keyfile`."""
    import uuid
    key = '%s\n' % uuid.uuid4()
    # f.write('ipython-key ')
    with open(keyfile, 'w') as f:
        f.write(key)
    # set user-only RW permissions (0600)
    # this will have no effect on Windows
    os.chmod(keyfile, stat.S_IRUSR | stat.S_IWUSR)
318
319
319
320
def integer_loglevel(loglevel):
    """Coerce a log level to its integer value.

    Accepts an int, a numeric string ('20'), or a level name ('DEBUG');
    level names are matched case-insensitively.
    """
    try:
        loglevel = int(loglevel)
    except ValueError:
        if isinstance(loglevel, str):
            # accept level names case-insensitively ('debug' == 'DEBUG')
            loglevel = getattr(logging, loglevel.upper())
    return loglevel
327
328
def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
    """Attach a zmq PUBHandler publishing on ``iface`` to the named logger."""
    logger = logging.getLogger(logname)
    # only ever attach one PUBHandler to a given logger
    if any(isinstance(h, handlers.PUBHandler) for h in logger.handlers):
        return
    level = integer_loglevel(loglevel)
    pub_socket = context.socket(zmq.PUB)
    pub_socket.connect(iface)
    pub_handler = handlers.PUBHandler(pub_socket)
    pub_handler.setLevel(level)
    pub_handler.root_topic = root
    logger.addHandler(pub_handler)
    logger.setLevel(level)
341
342
def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
    """Attach an EnginePUBHandler for ``engine`` to the root logger.

    Returns the root logger, or None if a PUBHandler was already attached.
    """
    logger = logging.getLogger()
    if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
        # don't add a second PUBHandler
        # NOTE(review): this early return yields None rather than the existing
        # logger -- confirm callers do not rely on the return value here.
        return
    loglevel = integer_loglevel(loglevel)
    lsock = context.socket(zmq.PUB)
    lsock.connect(iface)
    handler = EnginePUBHandler(engine, lsock)
    handler.setLevel(loglevel)
    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
355
356
def local_logger(logname, loglevel=logging.DEBUG):
    """Attach a StreamHandler with a timestamped format to logger ``logname``.

    Returns the configured logger.  Idempotent: a second call will not attach
    a duplicate StreamHandler.
    """
    loglevel = integer_loglevel(loglevel)
    logger = logging.getLogger(logname)
    if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
        # don't add a second StreamHandler; return the existing logger instead
        # of None so the return value is consistent on every code path
        return logger
    handler = logging.StreamHandler()
    handler.setLevel(loglevel)
    formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S")
    handler.setFormatter(formatter)

    logger.addHandler(handler)
    logger.setLevel(loglevel)
    return logger
371
372
def set_hwm(sock, hwm=0):
    """set zmq High Water Mark on a socket

    in a way that always works for various pyzmq / libzmq versions.
    """
    import zmq

    # try every HWM option name that has existed across libzmq versions;
    # skip names this zmq doesn't define, ignore options the socket rejects
    for name in ('HWM', 'SNDHWM', 'RCVHWM'):
        option = getattr(zmq, name, None)
        if option is None:
            continue
        try:
            sock.setsockopt(option, hwm)
        except zmq.ZMQError:
            pass
387
388
388
389
@@ -1,433 +1,420 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Pickle related utilities. Perhaps this should be called 'can'."""
2 """Pickle related utilities. Perhaps this should be called 'can'."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import copy
7 import copy
8 import logging
8 import logging
9 import sys
9 import sys
10 from types import FunctionType
10 from types import FunctionType
11
11
12 try:
12 try:
13 import cPickle as pickle
13 import cPickle as pickle
14 except ImportError:
14 except ImportError:
15 import pickle
15 import pickle
16
16
17 from . import codeutil # This registers a hook when it's imported
17 from . import codeutil # This registers a hook when it's imported
18 from . import py3compat
18 from . import py3compat
19 from .importstring import import_item
19 from .importstring import import_item
20 from .py3compat import string_types, iteritems
20 from .py3compat import string_types, iteritems
21
21
22 from IPython.config import Application
22 from IPython.config import Application
23 from IPython.utils.log import get_logger
23
24
24 if py3compat.PY3:
25 if py3compat.PY3:
25 buffer = memoryview
26 buffer = memoryview
26 class_type = type
27 class_type = type
27 else:
28 else:
28 from types import ClassType
29 from types import ClassType
29 class_type = (type, ClassType)
30 class_type = (type, ClassType)
30
31
def _get_cell_type(a=None):
    """the type of a closure cell doesn't seem to be importable,
    so just create one
    """
    # define a trivial closure over `a` and inspect its first cell
    def _closure():
        return a
    return type(py3compat.get_closure(_closure)[0])
38
39
39 cell_type = _get_cell_type()
40 cell_type = _get_cell_type()
40
41
41 #-------------------------------------------------------------------------------
42 #-------------------------------------------------------------------------------
42 # Functions
43 # Functions
43 #-------------------------------------------------------------------------------
44 #-------------------------------------------------------------------------------
44
45
45
46
def use_dill():
    """use dill to expand serialization support

    adds support for object methods and closures to serialization.
    """
    # import dill causes most of the magic
    import dill

    # dill doesn't work with cPickle,
    # tell the two relevant modules to use plain pickle

    global pickle
    pickle = dill

    try:
        from IPython.kernel.zmq import serialize
    except ImportError:
        # zmq serializer not available; nothing more to patch
        pass
    else:
        serialize.pickle = dill

    # disable special function handling, let dill take care of it
    can_map.pop(FunctionType, None)
69
70
def use_cloudpickle():
    """use cloudpickle to expand serialization support

    adds support for object methods and closures to serialization.
    """
    from cloud.serialization import cloudpickle

    # replace this module's pickle with cloudpickle
    global pickle
    pickle = cloudpickle

    try:
        from IPython.kernel.zmq import serialize
    except ImportError:
        # zmq serializer not available; nothing more to patch
        pass
    else:
        serialize.pickle = cloudpickle

    # disable special function handling, let cloudpickle take care of it
    can_map.pop(FunctionType, None)
89
90
90
91
91 #-------------------------------------------------------------------------------
92 #-------------------------------------------------------------------------------
92 # Classes
93 # Classes
93 #-------------------------------------------------------------------------------
94 #-------------------------------------------------------------------------------
94
95
95
96
class CannedObject(object):
    """Base wrapper that makes an object safe to pickle.

    Selected attributes are recursively canned, and large data may be
    offloaded into ``self.buffers`` for zero-copy transfers.
    """

    def __init__(self, obj, keys=None, hook=None):
        """can an object for safe pickling

        Parameters
        ==========

        obj:
            The object to be canned
        keys: list (optional)
            list of attribute names that will be explicitly canned / uncanned
        hook: callable (optional)
            An optional extra callable,
            which can do additional processing of the uncanned object.

        large data may be offloaded into the buffers list,
        used for zero-copy transfers.
        """
        # None instead of a mutable [] default, to avoid sharing one list
        # across all instances that omit `keys`
        self.keys = keys = [] if keys is None else keys
        # shallow-copy so canning attributes doesn't mutate the caller's object
        self.obj = copy.copy(obj)
        self.hook = can(hook)
        for key in keys:
            setattr(self.obj, key, can(getattr(obj, key)))

        self.buffers = []

    def get_object(self, g=None):
        """Reconstruct the original object, uncanning attributes into ``g``."""
        if g is None:
            g = {}
        obj = self.obj
        for key in self.keys:
            setattr(obj, key, uncan(getattr(obj, key), g))

        if self.hook:
            # run the (uncanned) post-processing hook once
            self.hook = uncan(self.hook, g)
            self.hook(obj, g)
        return self.obj
133
134
134
135
class Reference(CannedObject):
    """object for wrapping a remote reference by name."""

    def __init__(self, name):
        # only a name is stored; the object is looked up on the far side
        if not isinstance(name, string_types):
            raise TypeError("illegal name: %r"%name)
        self.name = name
        self.buffers = []

    def __repr__(self):
        return "<Reference: %r>"%self.name

    def get_object(self, g=None):
        """Resolve the name in namespace ``g`` (an empty dict by default)."""
        g = {} if g is None else g
        return eval(self.name, g)
151
152
152
153
class CannedCell(CannedObject):
    """Can a closure cell"""

    def __init__(self, cell):
        self.cell_contents = can(cell.cell_contents)

    def get_object(self, g=None):
        # rebuild a cell by closing over the uncanned contents
        contents = uncan(self.cell_contents, g)
        def inner():
            return contents
        return py3compat.get_closure(inner)[0]
163
164
164
165
class CannedFunction(CannedObject):
    """Can a plain function, capturing its code, defaults, and closure."""

    def __init__(self, f):
        self._check_type(f)
        # the code object itself is picklable via the codeutil hook
        self.code = f.__code__
        if f.__defaults__:
            self.defaults = [ can(fd) for fd in f.__defaults__ ]
        else:
            self.defaults = None

        closure = py3compat.get_closure(f)
        if closure:
            self.closure = tuple( can(cell) for cell in closure )
        else:
            self.closure = None

        self.module = f.__module__ or '__main__'
        self.__name__ = f.__name__
        self.buffers = []

    def _check_type(self, obj):
        # plain functions only; methods/builtins are canned elsewhere
        assert isinstance(obj, FunctionType), "Not a function type"

    def get_object(self, g=None):
        """Rebuild the function, preferring its original module's namespace."""
        # try to load function back into its module:
        if not self.module.startswith('__'):
            __import__(self.module)
            g = sys.modules[self.module].__dict__

        if g is None:
            g = {}
        if self.defaults:
            defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
        else:
            defaults = None
        if self.closure:
            closure = tuple(uncan(cell, g) for cell in self.closure)
        else:
            closure = None
        newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
        return newFunc
206
207
class CannedClass(CannedObject):
    """Can an interactively-defined class, capturing its dict and bases."""

    def __init__(self, cls):
        self._check_type(cls)
        self.name = cls.__name__
        # old-style (Python 2) classes are not instances of `type` and
        # have no mro()
        self.old_style = not isinstance(cls, type)
        self._canned_dict = {}
        for k,v in cls.__dict__.items():
            if k not in ('__weakref__', '__dict__'):
                self._canned_dict[k] = can(v)
        if self.old_style:
            mro = []
        else:
            mro = cls.mro()

        # can all bases except the class itself (mro[0])
        self.parents = [ can(c) for c in mro[1:] ]
        self.buffers = []

    def _check_type(self, obj):
        assert isinstance(obj, class_type), "Not a class type"

    def get_object(self, g=None):
        # rebuild the class from its canned bases and attribute dict
        parents = tuple(uncan(p, g) for p in self.parents)
        return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
231
232
class CannedArray(CannedObject):
    """Can a numpy array, using a zero-copy buffer when the dtype allows."""

    def __init__(self, obj):
        from numpy import ascontiguousarray
        self.shape = obj.shape
        self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
        self.pickled = False
        if sum(obj.shape) == 0:
            # empty array: cheap to pickle outright
            self.pickled = True
        elif obj.dtype == 'O':
            # can't handle object dtype with buffer approach
            self.pickled = True
        elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
            # structured dtypes containing object fields must also be pickled
            self.pickled = True
        if self.pickled:
            # just pickle it
            self.buffers = [pickle.dumps(obj, -1)]
        else:
            # ensure contiguous
            obj = ascontiguousarray(obj, dtype=None)
            self.buffers = [buffer(obj)]

    def get_object(self, g=None):
        from numpy import frombuffer
        data = self.buffers[0]
        if self.pickled:
            # no shape, we just pickled it
            return pickle.loads(data)
        else:
            # reinterpret the raw buffer with the stored dtype/shape
            return frombuffer(data, dtype=self.dtype).reshape(self.shape)
261
262
262
263
class CannedBytes(CannedObject):
    """Can a bytes object by stashing it directly in the buffers list."""
    # callable used to rebuild the object from the raw buffer
    wrap = bytes

    def __init__(self, obj):
        self.buffers = [obj]

    def get_object(self, g=None):
        return self.wrap(self.buffers[0])
271
272
class CannedBuffer(CannedBytes):
    """Can a buffer (memoryview on Python 3) object.

    Fix: this was previously declared with ``def`` instead of ``class``,
    which made CannedBuffer a no-op function rather than a CannedBytes
    subclass, so buffer objects mapped to it in can_map were never
    actually canned.
    """
    wrap = buffer
274
275
275 #-------------------------------------------------------------------------------
276 #-------------------------------------------------------------------------------
276 # Functions
277 # Functions
277 #-------------------------------------------------------------------------------
278 #-------------------------------------------------------------------------------
278
279
279 def _logger():
280 """get the logger for the current Application
281
282 the root logger will be used if no Application is running
283 """
284 if Application.initialized():
285 logger = Application.instance().log
286 else:
287 logger = logging.getLogger()
288 if not logger.handlers:
289 logging.basicConfig()
290
291 return logger
292
def _import_mapping(mapping, original=None):
    """import any string-keys in a type mapping

    """
    log = get_logger()
    log.debug("Importing canning map")
    # iterate over a snapshot, since we mutate the mapping as we go
    for key, value in list(mapping.items()):
        if not isinstance(key, string_types):
            continue
        try:
            cls = import_item(key)
        except Exception:
            if original and key not in original:
                # only message on user-added classes
                log.error("canning class not importable: %r", key, exc_info=True)
            mapping.pop(key)
        else:
            # replace the string key with the imported class
            mapping[cls] = mapping.pop(key)
310
297
def istype(obj, check):
    """like isinstance(obj, check), but strict

    This won't catch subclasses.
    """
    # exact type match only; `check` may be a single type or a tuple
    if isinstance(check, tuple):
        return any(type(obj) is cls for cls in check)
    return type(obj) is check
323
310
def can(obj):
    """prepare an object for pickling"""

    needs_import = False

    # string keys in can_map mean the map hasn't been fully imported yet
    for cls, canner in iteritems(can_map):
        if isinstance(cls, string_types):
            needs_import = True
            break
        elif istype(obj, cls):
            return canner(obj)

    if needs_import:
        # resolve string keys in can_map, then retry the lookup
        # (this will usually only happen once)
        _import_mapping(can_map, _original_can_map)
        return can(obj)

    return obj
343
330
def can_class(obj):
    """Can interactively-defined (__main__) classes; pass others through."""
    if isinstance(obj, class_type) and obj.__module__ == '__main__':
        return CannedClass(obj)
    return obj
349
336
def can_dict(obj):
    """can the *values* of a dict"""
    if not istype(obj, dict):
        return obj
    # keys pass through unchanged; only values are canned
    return dict((k, can(v)) for k, v in iteritems(obj))
359
346
360 sequence_types = (list, tuple, set)
347 sequence_types = (list, tuple, set)
361
348
def can_sequence(obj):
    """can the elements of a sequence"""
    if not istype(obj, sequence_types):
        return obj
    # rebuild the same concrete type (list/tuple/set) with canned elements
    cls = type(obj)
    return cls([can(item) for item in obj])
369
356
def uncan(obj, g=None):
    """invert canning"""

    needs_import = False
    # string keys in uncan_map mean the map hasn't been fully imported yet
    for cls, uncanner in iteritems(uncan_map):
        if isinstance(cls, string_types):
            needs_import = True
            break
        elif isinstance(obj, cls):
            return uncanner(obj, g)

    if needs_import:
        # resolve string keys in uncan_map, then retry the lookup
        # (this will usually only happen once)
        _import_mapping(uncan_map, _original_uncan_map)
        return uncan(obj, g)

    return obj
388
375
def uncan_dict(obj, g=None):
    """uncan the *values* of a dict"""
    if not istype(obj, dict):
        return obj
    return dict((k, uncan(v, g)) for k, v in iteritems(obj))
397
384
def uncan_sequence(obj, g=None):
    """uncan the elements of a sequence"""
    if not istype(obj, sequence_types):
        return obj
    # rebuild the same concrete type (list/tuple/set) with uncanned elements
    cls = type(obj)
    return cls([uncan(item, g) for item in obj])
404
391
def _uncan_dependent_hook(dep, g=None):
    # post-uncan hook: re-check the dependency's requirements on this side
    dep.check_dependency()
407
394
def can_dependent(obj):
    # can a parallel dependent, canning its 'f'/'df' attributes and running
    # the dependency check again when it is uncanned
    return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
410
397
411 #-------------------------------------------------------------------------------
398 #-------------------------------------------------------------------------------
412 # API dictionaries
399 # API dictionaries
413 #-------------------------------------------------------------------------------
400 #-------------------------------------------------------------------------------
414
401
415 # These dicts can be extended for custom serialization of new objects
402 # These dicts can be extended for custom serialization of new objects
416
403
417 can_map = {
404 can_map = {
418 'IPython.parallel.dependent' : can_dependent,
405 'IPython.parallel.dependent' : can_dependent,
419 'numpy.ndarray' : CannedArray,
406 'numpy.ndarray' : CannedArray,
420 FunctionType : CannedFunction,
407 FunctionType : CannedFunction,
421 bytes : CannedBytes,
408 bytes : CannedBytes,
422 buffer : CannedBuffer,
409 buffer : CannedBuffer,
423 cell_type : CannedCell,
410 cell_type : CannedCell,
424 class_type : can_class,
411 class_type : can_class,
425 }
412 }
426
413
427 uncan_map = {
414 uncan_map = {
428 CannedObject : lambda obj, g: obj.get_object(g),
415 CannedObject : lambda obj, g: obj.get_object(g),
429 }
416 }
430
417
431 # for use in _import_mapping:
418 # for use in _import_mapping:
432 _original_can_map = can_map.copy()
419 _original_can_map = can_map.copy()
433 _original_uncan_map = uncan_map.copy()
420 _original_uncan_map = uncan_map.copy()
General Comments 0
You need to be logged in to leave comments. Login now