Merge pull request #6044 from minrk/core.log...
Thomas Kluyver
r17065:0c7348f7 merge
@@ -0,0 +1,25 b''
1 """Grab the global logger instance."""
2
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
5
6 import logging
7
8 _logger = None
9
10 def get_logger():
11 """Grab the global logger instance.
12
13 If a global IPython Application is instantiated, grab its logger.
14 Otherwise, grab the root logger.
15 """
16 global _logger
17
18 if _logger is None:
19 from IPython.config import Application
20 if Application.initialized():
21 _logger = Application.instance().log
22 else:
23 logging.basicConfig()
24 _logger = logging.getLogger()
25 return _logger
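As a rough usage sketch (the `IPython.utils.log` module path is what this PR adds; the surrounding function is hypothetical):

```python
from IPython.utils.log import get_logger

def report_progress():
    # Inside a running IPython Application this returns Application.instance().log;
    # otherwise it falls back to the root logger after logging.basicConfig().
    log = get_logger()
    log.info("progress update")
```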
@@ -1,390 +1,366 b''
1 1 # encoding: utf-8
2 """
3 A base class for objects that are configurable.
4
5 Inheritance diagram:
6
7 .. inheritance-diagram:: IPython.config.configurable
8 :parts: 3
2 """A base class for objects that are configurable."""
9 3
10 Authors:
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
11 6
12 * Brian Granger
13 * Fernando Perez
14 * Min RK
15 """
16 7 from __future__ import print_function
17 8
18 #-----------------------------------------------------------------------------
19 # Copyright (C) 2008-2011 The IPython Development Team
20 #
21 # Distributed under the terms of the BSD License. The full license is in
22 # the file COPYING, distributed as part of this software.
23 #-----------------------------------------------------------------------------
24
25 #-----------------------------------------------------------------------------
26 # Imports
27 #-----------------------------------------------------------------------------
28
29 9 import logging
30 10 from copy import deepcopy
31 11
32 12 from .loader import Config, LazyConfigValue
33 13 from IPython.utils.traitlets import HasTraits, Instance
34 14 from IPython.utils.text import indent, wrap_paragraphs
35 15 from IPython.utils.py3compat import iteritems
36 16
37 17
38 18 #-----------------------------------------------------------------------------
39 19 # Helper classes for Configurables
40 20 #-----------------------------------------------------------------------------
41 21
42 22
43 23 class ConfigurableError(Exception):
44 24 pass
45 25
46 26
47 27 class MultipleInstanceError(ConfigurableError):
48 28 pass
49 29
50 30 #-----------------------------------------------------------------------------
51 31 # Configurable implementation
52 32 #-----------------------------------------------------------------------------
53 33
54 34 class Configurable(HasTraits):
55 35
56 36 config = Instance(Config, (), {})
57 37 parent = Instance('IPython.config.configurable.Configurable')
58 38
59 39 def __init__(self, **kwargs):
60 40 """Create a configurable given a config config.
61 41
62 42 Parameters
63 43 ----------
64 44 config : Config
65 45 If this is empty, default values are used. If config is a
66 46 :class:`Config` instance, it will be used to configure the
67 47 instance.
68 48 parent : Configurable instance, optional
69 49 The parent Configurable instance of this object.
70 50
71 51 Notes
72 52 -----
73 53 Subclasses of Configurable must call the :meth:`__init__` method of
74 54 :class:`Configurable` *before* doing anything else, using
75 55 :func:`super`::
76 56
77 57 class MyConfigurable(Configurable):
78 58 def __init__(self, config=None):
79 59 super(MyConfigurable, self).__init__(config=config)
80 60 # Then any other code you need to finish initialization.
81 61
82 62 This ensures that instances will be configured properly.
83 63 """
84 64 parent = kwargs.pop('parent', None)
85 65 if parent is not None:
86 66 # config is implied from parent
87 67 if kwargs.get('config', None) is None:
88 68 kwargs['config'] = parent.config
89 69 self.parent = parent
90 70
91 71 config = kwargs.pop('config', None)
92 72 if config is not None:
93 73 # We used to deepcopy, but for now we are trying to just save
94 74 # by reference. This *could* have side effects as all components
95 75 # will share config. In fact, I did find such a side effect in
96 76 # _config_changed below. If a config attribute value was a mutable type,
97 77 # all instances of a component were getting the same copy, effectively
98 78 # making that a class attribute.
99 79 # self.config = deepcopy(config)
100 80 self.config = config
101 81 # This should go second so individual keyword arguments override
102 82 # the values in config.
103 83 super(Configurable, self).__init__(**kwargs)
104 84
105 85 #-------------------------------------------------------------------------
106 86 # Static trait notifications
107 87 #-------------------------------------------------------------------------
108 88
109 89 @classmethod
110 90 def section_names(cls):
111 91 """return section names as a list"""
112 92 return [c.__name__ for c in reversed(cls.__mro__) if
113 93 issubclass(c, Configurable) and issubclass(cls, c)
114 94 ]
115 95
116 96 def _find_my_config(self, cfg):
117 97 """extract my config from a global Config object
118 98
119 99 will construct a Config object of only the config values that apply to me
120 100 based on my mro(), as well as those of my parent(s) if they exist.
121 101
122 102 If I am Bar and my parent is Foo, and their parent is Tim,
123 103 this will return the merge of the following config sections, in this order::
124 104
125 105 [Bar, Foo.Bar, Tim.Foo.Bar]
126 106
127 107 With the last item being the highest priority.
128 108 """
129 109 cfgs = [cfg]
130 110 if self.parent:
131 111 cfgs.append(self.parent._find_my_config(cfg))
132 112 my_config = Config()
133 113 for c in cfgs:
134 114 for sname in self.section_names():
135 115 # Don't do a blind getattr as that would cause the config to
136 116 # dynamically create the section with name Class.__name__.
137 117 if c._has_section(sname):
138 118 my_config.merge(c[sname])
139 119 return my_config
140 120
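A small illustration of the priority order described above, using hypothetical `Tim`/`Foo`/`Bar` configurables with a `timeout` trait:

```python
from IPython.config.loader import Config

cfg = Config()
cfg.Bar.timeout = 1          # lowest priority: plain Bar section
cfg.Foo.Bar.timeout = 2      # applies when the Bar instance's parent is a Foo
cfg.Tim.Foo.Bar.timeout = 3  # highest priority: Bar under Foo under Tim

# A Bar constructed as Bar(parent=foo), where foo = Foo(parent=tim),
# would resolve timeout to 3 via _find_my_config.
```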
141 121 def _load_config(self, cfg, section_names=None, traits=None):
142 122 """load traits from a Config object"""
143 123
144 124 if traits is None:
145 125 traits = self.traits(config=True)
146 126 if section_names is None:
147 127 section_names = self.section_names()
148 128
149 129 my_config = self._find_my_config(cfg)
150 130 for name, config_value in iteritems(my_config):
151 131 if name in traits:
152 132 if isinstance(config_value, LazyConfigValue):
153 133 # ConfigValue is a wrapper for using append / update on containers
154 134 # without having to copy the initial value
155 135 initial = getattr(self, name)
156 136 config_value = config_value.get_value(initial)
157 137 # We have to do a deepcopy here if we don't deepcopy the entire
158 138 # config object. If we don't, a mutable config_value will be
159 139 # shared by all instances, effectively making it a class attribute.
160 140 setattr(self, name, deepcopy(config_value))
161 141
162 142 def _config_changed(self, name, old, new):
163 143 """Update all the class traits having ``config=True`` as metadata.
164 144
165 145 For any class trait with a ``config`` metadata attribute that is
166 146 ``True``, we update the trait with the value of the corresponding
167 147 config entry.
168 148 """
169 149 # Get all traits with a config metadata entry that is True
170 150 traits = self.traits(config=True)
171 151
172 152 # We auto-load config section for this class as well as any parent
173 153 # classes that are Configurable subclasses. This starts with Configurable
174 154 # and works down the mro loading the config for each section.
175 155 section_names = self.section_names()
176 156 self._load_config(new, traits=traits, section_names=section_names)
177 157
178 158 def update_config(self, config):
179 159 """Fire the traits events when the config is updated."""
180 160 # Save a copy of the current config.
181 161 newconfig = deepcopy(self.config)
182 162 # Merge the new config into the current one.
183 163 newconfig.merge(config)
184 164 # Save the combined config as self.config, which triggers the traits
185 165 # events.
186 166 self.config = newconfig
187 167
188 168 @classmethod
189 169 def class_get_help(cls, inst=None):
190 170 """Get the help string for this class in ReST format.
191 171
192 172 If `inst` is given, its current trait values will be used in place of
193 173 class defaults.
194 174 """
195 175 assert inst is None or isinstance(inst, cls)
196 176 final_help = []
197 177 final_help.append(u'%s options' % cls.__name__)
198 178 final_help.append(len(final_help[0])*u'-')
199 179 for k, v in sorted(cls.class_traits(config=True).items()):
200 180 help = cls.class_get_trait_help(v, inst)
201 181 final_help.append(help)
202 182 return '\n'.join(final_help)
203 183
204 184 @classmethod
205 185 def class_get_trait_help(cls, trait, inst=None):
206 186 """Get the help string for a single trait.
207 187
208 188 If `inst` is given, its current trait values will be used in place of
209 189 the class default.
210 190 """
211 191 assert inst is None or isinstance(inst, cls)
212 192 lines = []
213 193 header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__)
214 194 lines.append(header)
215 195 if inst is not None:
216 196 lines.append(indent('Current: %r' % getattr(inst, trait.name), 4))
217 197 else:
218 198 try:
219 199 dvr = repr(trait.get_default_value())
220 200 except Exception:
221 201 dvr = None # ignore defaults we can't construct
222 202 if dvr is not None:
223 203 if len(dvr) > 64:
224 204 dvr = dvr[:61]+'...'
225 205 lines.append(indent('Default: %s' % dvr, 4))
226 206 if 'Enum' in trait.__class__.__name__:
227 207 # include Enum choices
228 208 lines.append(indent('Choices: %r' % (trait.values,)))
229 209
230 210 help = trait.get_metadata('help')
231 211 if help is not None:
232 212 help = '\n'.join(wrap_paragraphs(help, 76))
233 213 lines.append(indent(help, 4))
234 214 return '\n'.join(lines)
235 215
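For a hypothetical configurable, the generated help looks along these lines (a sketch; `MyTool` and its trait are made up):

```python
from IPython.config.configurable import Configurable
from IPython.utils.traitlets import Int

class MyTool(Configurable):
    timeout = Int(30, config=True, help="Seconds to wait before giving up.")

print(MyTool.class_get_help())
# MyTool options
# --------------
# --MyTool.timeout=<Int>
#     Default: 30
#     Seconds to wait before giving up.
```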
236 216 @classmethod
237 217 def class_print_help(cls, inst=None):
238 218 """Get the help string for a single trait and print it."""
239 219 print(cls.class_get_help(inst))
240 220
241 221 @classmethod
242 222 def class_config_section(cls):
243 223 """Get the config class config section"""
244 224 def c(s):
245 225 """return a commented, wrapped block."""
246 226 s = '\n\n'.join(wrap_paragraphs(s, 78))
247 227
248 228 return '# ' + s.replace('\n', '\n# ')
249 229
250 230 # section header
251 231 breaker = '#' + '-'*78
252 232 s = "# %s configuration" % cls.__name__
253 233 lines = [breaker, s, breaker, '']
254 234 # get the description trait
255 235 desc = cls.class_traits().get('description')
256 236 if desc:
257 237 desc = desc.default_value
258 238 else:
259 239 # no description trait, use __doc__
260 240 desc = getattr(cls, '__doc__', '')
261 241 if desc:
262 242 lines.append(c(desc))
263 243 lines.append('')
264 244
265 245 parents = []
266 246 for parent in cls.mro():
267 247 # only include parents that are not base classes
268 248 # and are not the class itself
269 249 # and have some configurable traits to inherit
270 250 if parent is not cls and issubclass(parent, Configurable) and \
271 251 parent.class_traits(config=True):
272 252 parents.append(parent)
273 253
274 254 if parents:
275 255 pstr = ', '.join([ p.__name__ for p in parents ])
276 256 lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr)))
277 257 lines.append('')
278 258
279 259 for name, trait in iteritems(cls.class_traits(config=True)):
280 260 help = trait.get_metadata('help') or ''
281 261 lines.append(c(help))
282 262 lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value()))
283 263 lines.append('')
284 264 return '\n'.join(lines)
285 265
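With the same hypothetical `MyTool` as above, `MyTool.class_config_section()` renders roughly as follows (breaker lines abbreviated here; in practice they are 79 characters wide):

```python
#----------------------------------------
# MyTool configuration
#----------------------------------------

# Seconds to wait before giving up.
# c.MyTool.timeout = 30
```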
286 266
287 267
288 268 class SingletonConfigurable(Configurable):
289 269 """A configurable that only allows one instance.
290 270
291 271 This class is for classes that should only have one instance of itself
292 272 or *any* subclass. To create and retrieve such a class use the
293 273 :meth:`SingletonConfigurable.instance` method.
294 274 """
295 275
296 276 _instance = None
297 277
298 278 @classmethod
299 279 def _walk_mro(cls):
300 280 """Walk the cls.mro() for parent classes that are also singletons
301 281
302 282 For use in instance()
303 283 """
304 284
305 285 for subclass in cls.mro():
306 286 if issubclass(cls, subclass) and \
307 287 issubclass(subclass, SingletonConfigurable) and \
308 288 subclass != SingletonConfigurable:
309 289 yield subclass
310 290
311 291 @classmethod
312 292 def clear_instance(cls):
313 293 """unset _instance for this class and singleton parents.
314 294 """
315 295 if not cls.initialized():
316 296 return
317 297 for subclass in cls._walk_mro():
318 298 if isinstance(subclass._instance, cls):
319 299 # only clear instances that are instances
320 300 # of the calling class
321 301 subclass._instance = None
322 302
323 303 @classmethod
324 304 def instance(cls, *args, **kwargs):
325 305 """Returns a global instance of this class.
326 306
327 307 This method creates a new instance if none has previously been created
328 308 and returns a previously created instance if one already exists.
329 309
330 310 The arguments and keyword arguments passed to this method are passed
331 311 on to the :meth:`__init__` method of the class upon instantiation.
332 312
333 313 Examples
334 314 --------
335 315
336 316 Create a singleton class using instance, and retrieve it::
337 317
338 318 >>> from IPython.config.configurable import SingletonConfigurable
339 319 >>> class Foo(SingletonConfigurable): pass
340 320 >>> foo = Foo.instance()
341 321 >>> foo == Foo.instance()
342 322 True
343 323
344 324 Create a subclass that is retrieved using the base class instance::
345 325
346 326 >>> class Bar(SingletonConfigurable): pass
347 327 >>> class Bam(Bar): pass
348 328 >>> bam = Bam.instance()
349 329 >>> bam == Bar.instance()
350 330 True
351 331 """
352 332 # Create and save the instance
353 333 if cls._instance is None:
354 334 inst = cls(*args, **kwargs)
355 335 # Now make sure that the instance will also be returned by
356 336 # parent classes' _instance attribute.
357 337 for subclass in cls._walk_mro():
358 338 subclass._instance = inst
359 339
360 340 if isinstance(cls._instance, cls):
361 341 return cls._instance
362 342 else:
363 343 raise MultipleInstanceError(
364 344 'Multiple incompatible subclass instances of '
365 345 '%s are being created.' % cls.__name__
366 346 )
367 347
368 348 @classmethod
369 349 def initialized(cls):
370 350 """Has an instance been created?"""
371 351 return hasattr(cls, "_instance") and cls._instance is not None
372 352
373 353
374 354 class LoggingConfigurable(Configurable):
375 355 """A parent class for Configurables that log.
376 356
377 357 Subclasses have a log trait, and the default behavior
378 is to get the logger from the currently running Application
379 via Application.instance().log.
358 is to get the logger from the currently running Application.
380 359 """
381 360
382 361 log = Instance('logging.Logger')
383 362 def _log_default(self):
384 from IPython.config.application import Application
385 if Application.initialized():
386 return Application.instance().log
387 else:
388 return logging.getLogger()
363 from IPython.utils import log
364 return log.get_logger()
389 365
390 366
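A minimal sketch of the new default (`Worker` is a hypothetical subclass):

```python
from IPython.config.configurable import LoggingConfigurable

class Worker(LoggingConfigurable):
    def run(self):
        # self.log is Application.instance().log when an Application is
        # running, else the root logger via IPython.utils.log.get_logger()
        self.log.info("worker started")
```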
@@ -1,846 +1,824 b''
1 """A simple configuration system.
1 # encoding: utf-8
2 """A simple configuration system."""
2 3
3 Inheritance diagram:
4
5 .. inheritance-diagram:: IPython.config.loader
6 :parts: 3
7
8 Authors
9 -------
10 * Brian Granger
11 * Fernando Perez
12 * Min RK
13 """
14
15 #-----------------------------------------------------------------------------
16 # Copyright (C) 2008-2011 The IPython Development Team
17 #
18 # Distributed under the terms of the BSD License. The full license is in
19 # the file COPYING, distributed as part of this software.
20 #-----------------------------------------------------------------------------
21
22 #-----------------------------------------------------------------------------
23 # Imports
24 #-----------------------------------------------------------------------------
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
25 6
26 7 import argparse
27 8 import copy
28 9 import logging
29 10 import os
30 11 import re
31 12 import sys
32 13 import json
33 14
34 15 from IPython.utils.path import filefind, get_ipython_dir
35 16 from IPython.utils import py3compat
36 17 from IPython.utils.encoding import DEFAULT_ENCODING
37 18 from IPython.utils.py3compat import unicode_type, iteritems
38 19 from IPython.utils.traitlets import HasTraits, List, Any
39 20
40 21 #-----------------------------------------------------------------------------
41 22 # Exceptions
42 23 #-----------------------------------------------------------------------------
43 24
44 25
45 26 class ConfigError(Exception):
46 27 pass
47 28
48 29 class ConfigLoaderError(ConfigError):
49 30 pass
50 31
51 32 class ConfigFileNotFound(ConfigError):
52 33 pass
53 34
54 35 class ArgumentError(ConfigLoaderError):
55 36 pass
56 37
57 38 #-----------------------------------------------------------------------------
58 39 # Argparse fix
59 40 #-----------------------------------------------------------------------------
60 41
61 42 # Unfortunately argparse by default prints help messages to stderr instead of
62 43 # stdout. This makes it annoying to capture long help screens at the command
63 44 # line, since one must know how to pipe stderr, which many users don't know how
64 45 # to do. So we override the print_help method with one that defaults to
65 46 # stdout and use our class instead.
66 47
67 48 class ArgumentParser(argparse.ArgumentParser):
68 49 """Simple argparse subclass that prints help to stdout by default."""
69 50
70 51 def print_help(self, file=None):
71 52 if file is None:
72 53 file = sys.stdout
73 54 return super(ArgumentParser, self).print_help(file)
74 55
75 56 print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
76 57
77 58 #-----------------------------------------------------------------------------
78 59 # Config class for holding config information
79 60 #-----------------------------------------------------------------------------
80 61
81 62 class LazyConfigValue(HasTraits):
82 63 """Proxy object for exposing methods on configurable containers
83 64
84 65 Exposes:
85 66
86 67 - append, extend, insert on lists
87 68 - update on dicts
88 69 - update, add on sets
89 70 """
90 71
91 72 _value = None
92 73
93 74 # list methods
94 75 _extend = List()
95 76 _prepend = List()
96 77
97 78 def append(self, obj):
98 79 self._extend.append(obj)
99 80
100 81 def extend(self, other):
101 82 self._extend.extend(other)
102 83
103 84 def prepend(self, other):
104 85 """like list.extend, but for the front"""
105 86 self._prepend[:0] = other
106 87
107 88 _inserts = List()
108 89 def insert(self, index, other):
109 90 if not isinstance(index, int):
110 91 raise TypeError("An integer is required")
111 92 self._inserts.append((index, other))
112 93
113 94 # dict methods
114 95 # update is used for both dict and set
115 96 _update = Any()
116 97 def update(self, other):
117 98 if self._update is None:
118 99 if isinstance(other, dict):
119 100 self._update = {}
120 101 else:
121 102 self._update = set()
122 103 self._update.update(other)
123 104
124 105 # set methods
125 106 def add(self, obj):
126 107 self.update({obj})
127 108
128 109 def get_value(self, initial):
129 110 """construct the value from the initial one
130 111
131 112 after applying any insert / extend / update changes
132 113 """
133 114 if self._value is not None:
134 115 return self._value
135 116 value = copy.deepcopy(initial)
136 117 if isinstance(value, list):
137 118 for idx, obj in self._inserts:
138 119 value.insert(idx, obj)
139 120 value[:0] = self._prepend
140 121 value.extend(self._extend)
141 122
142 123 elif isinstance(value, dict):
143 124 if self._update:
144 125 value.update(self._update)
145 126 elif isinstance(value, set):
146 127 if self._update:
147 128 value.update(self._update)
148 129 self._value = value
149 130 return value
150 131
151 132 def to_dict(self):
152 133 """return JSONable dict form of my data
153 134
154 135 Currently: update as a dict or set, extend and prepend as lists, and inserts as a list of tuples.
155 136 """
156 137 d = {}
157 138 if self._update:
158 139 d['update'] = self._update
159 140 if self._extend:
160 141 d['extend'] = self._extend
161 142 if self._prepend:
162 143 d['prepend'] = self._prepend
163 144 elif self._inserts:
164 145 d['inserts'] = self._inserts
165 146 return d
166 147
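For example, a LazyConfigValue records container operations and replays them onto the trait's initial value (a sketch):

```python
from IPython.config.loader import LazyConfigValue

lazy = LazyConfigValue()
lazy.append('x')            # recorded, not applied yet
lazy.extend(['y', 'z'])
lazy.get_value(['a'])       # -> ['a', 'x', 'y', 'z']
```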
167 148
168 149 def _is_section_key(key):
169 150 """Is a Config key a section name (does it start with a capital)?"""
170 151 if key and key[0].upper()==key[0] and not key.startswith('_'):
171 152 return True
172 153 else:
173 154 return False
174 155
175 156
176 157 class Config(dict):
177 158 """An attribute based dict that can do smart merges."""
178 159
179 160 def __init__(self, *args, **kwds):
180 161 dict.__init__(self, *args, **kwds)
181 162 self._ensure_subconfig()
182 163
183 164 def _ensure_subconfig(self):
184 165 """ensure that sub-dicts that should be Config objects are
185 166
186 167 casts dicts that are under section keys to Config objects,
187 168 which is necessary for constructing Config objects from dict literals.
188 169 """
189 170 for key in self:
190 171 obj = self[key]
191 172 if _is_section_key(key) \
192 173 and isinstance(obj, dict) \
193 174 and not isinstance(obj, Config):
194 175 setattr(self, key, Config(obj))
195 176
196 177 def _merge(self, other):
197 178 """deprecated alias, use Config.merge()"""
198 179 self.merge(other)
199 180
200 181 def merge(self, other):
201 182 """merge another config object into this one"""
202 183 to_update = {}
203 184 for k, v in iteritems(other):
204 185 if k not in self:
205 186 to_update[k] = copy.deepcopy(v)
206 187 else: # I have this key
207 188 if isinstance(v, Config) and isinstance(self[k], Config):
208 189 # Recursively merge common sub Configs
209 190 self[k].merge(v)
210 191 else:
211 192 # Plain updates for non-Configs
212 193 to_update[k] = copy.deepcopy(v)
213 194
214 195 self.update(to_update)
215 196
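For instance, merging recurses into shared sections instead of clobbering them (a sketch):

```python
from IPython.config.loader import Config

a = Config()
a.Foo.x = 1
a.Foo.y = 2
b = Config()
b.Foo.y = 20    # collides with a.Foo.y
b.Bar.z = 3     # new section

a.merge(b)
# a.Foo.x == 1, a.Foo.y == 20 (b wins), a.Bar.z == 3
```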
216 197 def __contains__(self, key):
217 198 # allow nested contains of the form `"Section.key" in config`
218 199 if '.' in key:
219 200 first, remainder = key.split('.', 1)
220 201 if first not in self:
221 202 return False
222 203 return remainder in self[first]
223 204
224 205 return super(Config, self).__contains__(key)
225 206
226 207 # .has_key is deprecated for dictionaries.
227 208 has_key = __contains__
228 209
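The dotted form works like this (a sketch):

```python
from IPython.config.loader import Config

c = Config()
c.Section.key = 5
'Section.key' in c      # True  (nested lookup)
'Section.other' in c    # False (section exists, key does not)
'Other.key' in c        # False (no such section; nothing is created)
```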
229 210 def _has_section(self, key):
230 211 return _is_section_key(key) and key in self
231 212
232 213 def copy(self):
233 214 return type(self)(dict.copy(self))
234 215
235 216 def __copy__(self):
236 217 return self.copy()
237 218
238 219 def __deepcopy__(self, memo):
239 220 import copy
240 221 return type(self)(copy.deepcopy(list(self.items())))
241 222
242 223 def __getitem__(self, key):
243 224 try:
244 225 return dict.__getitem__(self, key)
245 226 except KeyError:
246 227 if _is_section_key(key):
247 228 c = Config()
248 229 dict.__setitem__(self, key, c)
249 230 return c
250 231 elif not key.startswith('_'):
251 232 # undefined, create lazy value, used for container methods
252 233 v = LazyConfigValue()
253 234 dict.__setitem__(self, key, v)
254 235 return v
255 236 else:
256 237 raise KeyError
257 238
258 239 def __setitem__(self, key, value):
259 240 if _is_section_key(key):
260 241 if not isinstance(value, Config):
261 242 raise ValueError('values whose keys begin with an uppercase '
262 243 'char must be Config instances: %r, %r' % (key, value))
263 244 dict.__setitem__(self, key, value)
264 245
265 246 def __getattr__(self, key):
266 247 if key.startswith('__'):
267 248 return dict.__getattr__(self, key)
268 249 try:
269 250 return self.__getitem__(key)
270 251 except KeyError as e:
271 252 raise AttributeError(e)
272 253
273 254 def __setattr__(self, key, value):
274 255 if key.startswith('__'):
275 256 return dict.__setattr__(self, key, value)
276 257 try:
277 258 self.__setitem__(key, value)
278 259 except KeyError as e:
279 260 raise AttributeError(e)
280 261
281 262 def __delattr__(self, key):
282 263 if key.startswith('__'):
283 264 return dict.__delattr__(self, key)
284 265 try:
285 266 dict.__delitem__(self, key)
286 267 except KeyError as e:
287 268 raise AttributeError(e)
288 269
289 270
290 271 #-----------------------------------------------------------------------------
291 272 # Config loading classes
292 273 #-----------------------------------------------------------------------------
293 274
294 275
295 276 class ConfigLoader(object):
296 277 """A object for loading configurations from just about anywhere.
297 278
298 279 The resulting configuration is packaged as a :class:`Config`.
299 280
300 281 Notes
301 282 -----
302 283 A :class:`ConfigLoader` does one thing: load a config from a source
303 284 (file, command line arguments) and return the data as a :class:`Config` object.
304 285 There are lots of things that :class:`ConfigLoader` does not do. It does
305 286 not implement complex logic for finding config files. It does not handle
306 287 default values or merge multiple configs. These things need to be
307 288 handled elsewhere.
308 289 """
309 290
310 291 def _log_default(self):
311 from IPython.config.application import Application
312 if Application.initialized():
313 return Application.instance().log
314 else:
315 return logging.getLogger()
292 from IPython.utils.log import get_logger
293 return get_logger()
316 294
317 295 def __init__(self, log=None):
318 296 """A base class for config loaders.
319 297
320 298 log : instance of :class:`logging.Logger` to use.
321 299 By default, the logger of :meth:`IPython.config.application.Application.instance()`
322 300 will be used.
323 301
324 302 Examples
325 303 --------
326 304
327 305 >>> cl = ConfigLoader()
328 306 >>> config = cl.load_config()
329 307 >>> config
330 308 {}
331 309 """
332 310 self.clear()
333 311 if log is None:
334 312 self.log = self._log_default()
335 313 self.log.debug('Using default logger')
336 314 else:
337 315 self.log = log
338 316
339 317 def clear(self):
340 318 self.config = Config()
341 319
342 320 def load_config(self):
343 321 """Load a config from somewhere, return a :class:`Config` instance.
344 322
345 323 Usually, this will cause self.config to be set and then returned.
346 324 However, in most cases, :meth:`ConfigLoader.clear` should be called
347 325 to erase any previous state.
348 326 """
349 327 self.clear()
350 328 return self.config
351 329
352 330
353 331 class FileConfigLoader(ConfigLoader):
354 332 """A base class for file based configurations.
355 333
356 334 As we add more file based config loaders, the common logic should go
357 335 here.
358 336 """
359 337
360 338 def __init__(self, filename, path=None, **kw):
361 339 """Build a config loader for a filename and path.
362 340
363 341 Parameters
364 342 ----------
365 343 filename : str
366 344 The file name of the config file.
367 345 path : str, list, tuple
368 346 The path to search for the config file on, or a sequence of
369 347 paths to try in order.
370 348 """
371 349 super(FileConfigLoader, self).__init__(**kw)
372 350 self.filename = filename
373 351 self.path = path
374 352 self.full_filename = ''
375 353
376 354 def _find_file(self):
377 355 """Try to find the file by searching the paths."""
378 356 self.full_filename = filefind(self.filename, self.path)
379 357
380 358 class JSONFileConfigLoader(FileConfigLoader):
381 359 """A Json file loader for config"""
382 360
383 361 def load_config(self):
384 362 """Load the config from a file and return it as a Config object."""
385 363 self.clear()
386 364 try:
387 365 self._find_file()
388 366 except IOError as e:
389 367 raise ConfigFileNotFound(str(e))
390 368 dct = self._read_file_as_dict()
391 369 self.config = self._convert_to_config(dct)
392 370 return self.config
393 371
394 372 def _read_file_as_dict(self):
395 373 with open(self.full_filename) as f:
396 374 return json.load(f)
397 375
398 376 def _convert_to_config(self, dictionary):
399 377 if 'version' in dictionary:
400 378 version = dictionary.pop('version')
401 379 else:
402 380 version = 1
403 381 self.log.warn("Unrecognized JSON config file version, assuming version {}".format(version))
404 382
405 383 if version == 1:
406 384 return Config(dictionary)
407 385 else:
408 386 raise ValueError('Unknown version of JSON config file: {version}'.format(version=version))
409 387
410 388
411 389 class PyFileConfigLoader(FileConfigLoader):
412 390 """A config loader for pure python files.
413 391
414 392 This is responsible for locating a Python config file by filename and
415 393 path, then executing it to construct a Config object.
416 394 """
417 395
418 396 def load_config(self):
419 397 """Load the config from a file and return it as a Config object."""
420 398 self.clear()
421 399 try:
422 400 self._find_file()
423 401 except IOError as e:
424 402 raise ConfigFileNotFound(str(e))
425 403 self._read_file_as_dict()
426 404 return self.config
427 405
428 406
429 407 def _read_file_as_dict(self):
430 408 """Load the config file into self.config, with recursive loading."""
431 409 # This closure is made available in the namespace that is used
432 410 # to exec the config file. It allows users to call
433 411 # load_subconfig('myconfig.py') to load config files recursively.
434 412 # It needs to be a closure because it has references to self.path
435 413 # and self.config. The sub-config is loaded with the same path
436 414 # as the parent, but it uses an empty config which is then merged
437 415 # with the parents.
438 416
439 417 # If a profile is specified, the config file will be loaded
440 418 # from that profile
441 419
442 420 def load_subconfig(fname, profile=None):
443 421 # import here to prevent circular imports
444 422 from IPython.core.profiledir import ProfileDir, ProfileDirError
445 423 if profile is not None:
446 424 try:
447 425 profile_dir = ProfileDir.find_profile_dir_by_name(
448 426 get_ipython_dir(),
449 427 profile,
450 428 )
451 429 except ProfileDirError:
452 430 return
453 431 path = profile_dir.location
454 432 else:
455 433 path = self.path
456 434 loader = PyFileConfigLoader(fname, path)
457 435 try:
458 436 sub_config = loader.load_config()
459 437 except ConfigFileNotFound:
460 438 # Pass silently if the sub config is not there. This happens
461 439 # when a user is using a profile, but not the default config.
462 440 pass
463 441 else:
464 442 self.config.merge(sub_config)
465 443
466 444 # Again, this needs to be a closure and should be used in config
467 445 # files to get the config being loaded.
468 446 def get_config():
469 447 return self.config
470 448
471 449 namespace = dict(
472 450 load_subconfig=load_subconfig,
473 451 get_config=get_config,
474 452 __file__=self.full_filename,
475 453 )
476 454 fs_encoding = sys.getfilesystemencoding() or 'ascii'
477 455 conf_filename = self.full_filename.encode(fs_encoding)
478 456 py3compat.execfile(conf_filename, namespace)
479 457
480 458
481 459 class CommandLineConfigLoader(ConfigLoader):
482 460 """A config loader for command line arguments.
483 461
484 462 As we add more command line based loaders, the common logic should go
485 463 here.
486 464 """
487 465
488 466 def _exec_config_str(self, lhs, rhs):
489 467 """execute self.config.<lhs> = <rhs>
490 468
491 469 * expands ~ with expanduser
492 470 * tries to assign with raw eval, otherwise assigns with just the string,
493 471 allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not*
494 472 equivalent are `--C.a=4` and `--C.a='4'`.
495 473 """
496 474 rhs = os.path.expanduser(rhs)
497 475 try:
498 476 # Try to see if regular Python syntax will work. This
499 477 # won't handle strings as the quote marks are removed
500 478 # by the system shell.
501 479 value = eval(rhs)
502 480 except (NameError, SyntaxError):
503 481 # This case happens if the rhs is a string.
504 482 value = rhs
505 483
506 484 exec(u'self.config.%s = value' % lhs)
507 485
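The coercion can be sketched standalone; `coerce_rhs` below is a hypothetical helper mirroring the try/eval logic above:

```python
import os

def coerce_rhs(rhs):
    """Hypothetical mirror of _exec_config_str's value coercion."""
    rhs = os.path.expanduser(rhs)
    try:
        return eval(rhs)    # numbers, lists, dicts, quoted strings
    except (NameError, SyntaxError):
        return rhs          # bare words fall back to plain strings

coerce_rhs("4")        # -> 4 (int)
coerce_rhs("'4'")      # -> '4' (str, quotes preserved)
coerce_rhs("foobar")   # -> 'foobar' (NameError path)
```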
508 486 def _load_flag(self, cfg):
509 487 """update self.config from a flag, which can be a dict or Config"""
510 488 if isinstance(cfg, (dict, Config)):
511 489 # don't clobber whole config sections, update
512 490 # each section from config:
513 491 for sec,c in iteritems(cfg):
514 492 self.config[sec].update(c)
515 493 else:
516 494 raise TypeError("Invalid flag: %r" % cfg)
517 495
518 496 # raw --identifier=value pattern
519 497 # but *also* accept '-' as wordsep, for aliases
520 498 # accepts: --foo=a
521 499 # --Class.trait=value
522 500 # --alias-name=value
523 501 # rejects: -foo=value
524 502 # --foo
525 503 # --Class.trait
526 504 kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*')
527 505
528 506 # just flags, no assignments, with two *or one* leading '-'
529 507 # accepts: --foo
530 508 # -foo-bar-again
531 509 # rejects: --anything=anything
532 510 # --two.word
533 511
534 512 flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$')
535 513
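A few spot checks of what the two patterns accept (assuming they are imported from this module):

```python
from IPython.config.loader import kv_pattern, flag_pattern

bool(kv_pattern.match('--InteractiveShell.autocall=2'))  # True: Class.trait=value
bool(kv_pattern.match('--profile=foo'))                  # True: alias=value
bool(kv_pattern.match('-profile=foo'))                   # False: single dash
bool(flag_pattern.match('--no-banner'))                  # True: bare flag
bool(flag_pattern.match('--a=b'))                        # False: assignment
```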
536 514 class KeyValueConfigLoader(CommandLineConfigLoader):
537 515 """A config loader that loads key value pairs from the command line.
538 516
539 517 This allows command line options to be given in the following form::
540 518
541 519 ipython --profile="foo" --InteractiveShell.autocall=False
542 520 """
543 521
544 522 def __init__(self, argv=None, aliases=None, flags=None, **kw):
545 523 """Create a key value pair config loader.
546 524
547 525 Parameters
548 526 ----------
549 527 argv : list
550 528 A list that has the form of sys.argv[1:] which has unicode
551 529 elements of the form u"key=value". If this is None (default),
552 530 then sys.argv[1:] will be used.
553 531 aliases : dict
554 532 A dict of aliases for configurable traits.
555 533 Keys are the short aliases, Values are the resolved trait.
556 534 Of the form: `{'alias' : 'Configurable.trait'}`
557 535 flags : dict
558 536 A dict of flags, keyed by str name. Values can be Config objects,
559 537 dicts, or "key=value" strings. If Config or dict, when the flag
560 538 is triggered, the flag is loaded as `self.config.update(m)`.
561 539
562 540 Returns
563 541 -------
564 542 config : Config
565 543 The resulting Config object.
566 544
567 545 Examples
568 546 --------
569 547
570 548 >>> from IPython.config.loader import KeyValueConfigLoader
571 549 >>> cl = KeyValueConfigLoader()
572 550 >>> d = cl.load_config(["--A.name='brian'","--B.number=0"])
573 551 >>> sorted(d.items())
574 552 [('A', {'name': 'brian'}), ('B', {'number': 0})]
575 553 """
576 554 super(KeyValueConfigLoader, self).__init__(**kw)
577 555 if argv is None:
578 556 argv = sys.argv[1:]
579 557 self.argv = argv
580 558 self.aliases = aliases or {}
581 559 self.flags = flags or {}
582 560
583 561
584 562 def clear(self):
585 563 super(KeyValueConfigLoader, self).clear()
586 564 self.extra_args = []
587 565
588 566
589 567 def _decode_argv(self, argv, enc=None):
590 568 """decode argv if bytes, using stin.encoding, falling back on default enc"""
591 569 uargv = []
592 570 if enc is None:
593 571 enc = DEFAULT_ENCODING
594 572 for arg in argv:
595 573 if not isinstance(arg, unicode_type):
596 574 # only decode if not already decoded
597 575 arg = arg.decode(enc)
598 576 uargv.append(arg)
599 577 return uargv
600 578
601 579
602 580 def load_config(self, argv=None, aliases=None, flags=None):
603 581 """Parse the configuration and generate the Config object.
604 582
605 583 After loading, any arguments that are not key-value or
606 584 flags will be stored in self.extra_args - a list of
607 585 unparsed command-line arguments. This is used for
608 586 arguments such as input files or subcommands.
609 587
610 588 Parameters
611 589 ----------
612 590 argv : list, optional
613 591 A list that has the form of sys.argv[1:] which has unicode
614 592 elements of the form u"key=value". If this is None (default),
615 593 then self.argv will be used.
616 594 aliases : dict
617 595 A dict of aliases for configurable traits.
618 596 Keys are the short aliases, Values are the resolved trait.
619 597 Of the form: `{'alias' : 'Configurable.trait'}`
620 598 flags : dict
621 599 A dict of flags, keyed by str name. Values can be Config objects
622 600 or dicts. When the flag is triggered, the config is loaded as
623 601 `self.config.update(cfg)`.
624 602 """
625 603 self.clear()
626 604 if argv is None:
627 605 argv = self.argv
628 606 if aliases is None:
629 607 aliases = self.aliases
630 608 if flags is None:
631 609 flags = self.flags
632 610
633 611 # ensure argv is a list of unicode strings:
634 612 uargv = self._decode_argv(argv)
635 613 for idx,raw in enumerate(uargv):
636 614 # strip leading '-'
637 615 item = raw.lstrip('-')
638 616
639 617 if raw == '--':
640 618 # don't parse arguments after '--'
641 619 # this is useful for relaying arguments to scripts, e.g.
642 620 # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py
643 621 self.extra_args.extend(uargv[idx+1:])
644 622 break
645 623
646 624 if kv_pattern.match(raw):
647 625 lhs,rhs = item.split('=',1)
648 626 # Substitute longnames for aliases.
649 627 if lhs in aliases:
650 628 lhs = aliases[lhs]
651 629 if '.' not in lhs:
652 630 # probably a mistyped alias, but not technically illegal
653 631 self.log.warn("Unrecognized alias: '%s', it will probably have no effect.", raw)
654 632 try:
655 633 self._exec_config_str(lhs, rhs)
656 634 except Exception:
657 635 raise ArgumentError("Invalid argument: '%s'" % raw)
658 636
659 637 elif flag_pattern.match(raw):
660 638 if item in flags:
661 639 cfg,help = flags[item]
662 640 self._load_flag(cfg)
663 641 else:
664 642 raise ArgumentError("Unrecognized flag: '%s'"%raw)
665 643 elif raw.startswith('-'):
666 644 kv = '--'+item
667 645 if kv_pattern.match(kv):
668 646 raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv))
669 647 else:
670 648 raise ArgumentError("Invalid argument: '%s'"%raw)
671 649 else:
672 650 # keep all args that aren't valid in a list,
673 651 # in case our parent knows what to do with them.
674 652 self.extra_args.append(item)
675 653 return self.config
676 654
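The extra_args behavior can be seen in a short session (hypothetical arguments):

```python
from IPython.config.loader import KeyValueConfigLoader

cl = KeyValueConfigLoader()
cfg = cl.load_config(['--A.x=1', 'notebook.ipynb', '--', '--raw'])
cfg.A.x          # 1
cl.extra_args    # ['notebook.ipynb', '--raw'] (everything after '--' is unparsed)
```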
677 655 class ArgParseConfigLoader(CommandLineConfigLoader):
678 656 """A loader that uses the argparse module to load from the command line."""
679 657
680 658 def __init__(self, argv=None, aliases=None, flags=None, log=None, *parser_args, **parser_kw):
681 659 """Create a config loader for use with argparse.
682 660
683 661 Parameters
684 662 ----------
685 663
686 664 argv : optional, list
687 665 If given, used to read command-line arguments from, otherwise
688 666 sys.argv[1:] is used.
689 667
690 668 parser_args : tuple
691 669 A tuple of positional arguments that will be passed to the
692 670 constructor of :class:`argparse.ArgumentParser`.
693 671
694 672 parser_kw : dict
695 673 A tuple of keyword arguments that will be passed to the
696 674 constructor of :class:`argparse.ArgumentParser`.
697 675
698 676 Returns
699 677 -------
700 678 config : Config
701 679 The resulting Config object.
702 680 """
703 681 super(CommandLineConfigLoader, self).__init__(log=log)
704 682 self.clear()
705 683 if argv is None:
706 684 argv = sys.argv[1:]
707 685 self.argv = argv
708 686 self.aliases = aliases or {}
709 687 self.flags = flags or {}
710 688
711 689 self.parser_args = parser_args
712 690 self.version = parser_kw.pop("version", None)
713 691 kwargs = dict(argument_default=argparse.SUPPRESS)
714 692 kwargs.update(parser_kw)
715 693 self.parser_kw = kwargs
716 694
717 695 def load_config(self, argv=None, aliases=None, flags=None):
718 696 """Parse command line arguments and return as a Config object.
719 697
720 698 Parameters
721 699 ----------
722 700
723 701 args : optional, list
724 702 If given, a list with the structure of sys.argv[1:] to parse
725 703 arguments from. If not given, the instance's self.argv attribute
726 704 (given at construction time) is used."""
727 705 self.clear()
728 706 if argv is None:
729 707 argv = self.argv
730 708 if aliases is None:
731 709 aliases = self.aliases
732 710 if flags is None:
733 711 flags = self.flags
734 712 self._create_parser(aliases, flags)
735 713 self._parse_args(argv)
736 714 self._convert_to_config()
737 715 return self.config
738 716
739 717 def get_extra_args(self):
740 718 if hasattr(self, 'extra_args'):
741 719 return self.extra_args
742 720 else:
743 721 return []
744 722
745 723 def _create_parser(self, aliases=None, flags=None):
746 724 self.parser = ArgumentParser(*self.parser_args, **self.parser_kw)
747 725 self._add_arguments(aliases, flags)
748 726
749 727 def _add_arguments(self, aliases=None, flags=None):
750 728 raise NotImplementedError("subclasses must implement _add_arguments")
751 729
752 730 def _parse_args(self, args):
753 731 """self.parser->self.parsed_data"""
754 732 # decode sys.argv to support unicode command-line options
755 733 enc = DEFAULT_ENCODING
756 734 uargs = [py3compat.cast_unicode(a, enc) for a in args]
757 735 self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs)
758 736
759 737 def _convert_to_config(self):
760 738 """self.parsed_data->self.config"""
761 739 for k, v in iteritems(vars(self.parsed_data)):
762 740 exec("self.config.%s = v"%k, locals(), globals())
763 741
764 742 class KVArgParseConfigLoader(ArgParseConfigLoader):
765 743 """A config loader that loads aliases and flags with argparse,
766 744 but will use KVLoader for the rest. This allows better parsing
767 745 of common args, such as `ipython -c 'print 5'`, but still gets
768 746 arbitrary config with `ipython --InteractiveShell.use_readline=False`"""
769 747
770 748 def _add_arguments(self, aliases=None, flags=None):
771 749 self.alias_flags = {}
772 750 # print aliases, flags
773 751 if aliases is None:
774 752 aliases = self.aliases
775 753 if flags is None:
776 754 flags = self.flags
777 755 paa = self.parser.add_argument
778 756 for key,value in iteritems(aliases):
779 757 if key in flags:
780 758 # flags
781 759 nargs = '?'
782 760 else:
783 761 nargs = None
784 762 if len(key) == 1:
785 763 paa('-'+key, '--'+key, type=unicode_type, dest=value, nargs=nargs)
786 764 else:
787 765 paa('--'+key, type=unicode_type, dest=value, nargs=nargs)
788 766 for key, (value, help) in iteritems(flags):
789 767 if key in self.aliases:
790 768 # flag shares a name with an alias; load it in _convert_to_config
791 769 self.alias_flags[self.aliases[key]] = value
792 770 continue
793 771 if len(key) == 1:
794 772 paa('-'+key, '--'+key, action='append_const', dest='_flags', const=value)
795 773 else:
796 774 paa('--'+key, action='append_const', dest='_flags', const=value)
797 775
798 776 def _convert_to_config(self):
799 777 """self.parsed_data->self.config, parse unrecognized extra args via KVLoader."""
800 778 # remove subconfigs list from namespace before transforming the Namespace
801 779 if '_flags' in self.parsed_data:
802 780 subcs = self.parsed_data._flags
803 781 del self.parsed_data._flags
804 782 else:
805 783 subcs = []
806 784
807 785 for k, v in iteritems(vars(self.parsed_data)):
808 786 if v is None:
809 787 # it was a flag that shares the name of an alias
810 788 subcs.append(self.alias_flags[k])
811 789 else:
812 790 # eval the KV assignment
813 791 self._exec_config_str(k, v)
814 792
815 793 for subc in subcs:
816 794 self._load_flag(subc)
817 795
818 796 if self.extra_args:
819 797 sub_parser = KeyValueConfigLoader(log=self.log)
820 798 sub_parser.load_config(self.extra_args)
821 799 self.config.merge(sub_parser.config)
822 800 self.extra_args = sub_parser.extra_args
823 801
824 802
825 803 def load_pyconfig_files(config_files, path):
826 804 """Load multiple Python config files, merging each of them in turn.
827 805
828 806 Parameters
829 807 ----------
830 808 config_files : list of str
831 809 List of config files names to load and merge into the config.
832 810 path : unicode
833 811 The full path to the location of the config files.
834 812 """
835 813 config = Config()
836 814 for cf in config_files:
837 815 loader = PyFileConfigLoader(cf, path=path)
838 816 try:
839 817 next_config = loader.load_config()
840 818 except ConfigFileNotFound:
841 819 pass
842 820 except:
843 821 raise
844 822 else:
845 823 config.merge(next_config)
846 824 return config
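For example, with a hypothetical profile directory containing an `ipython_config.py` that uses the `get_config()` closure:

```python
# ~/.ipython/profile_demo/ipython_config.py (hypothetical):
#     c = get_config()
#     c.InteractiveShell.autocall = 2
#     load_subconfig('extra_config.py')   # merged if present, skipped if not

from IPython.config.loader import load_pyconfig_files

config = load_pyconfig_files(['ipython_config.py'],
                             '~/.ipython/profile_demo')
config.InteractiveShell.autocall   # 2
```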
@@ -1,390 +1,376 b''
1 """Base Tornado handlers for the notebook.
2
3 Authors:
4
5 * Brian Granger
6 """
7
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
10 #
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
14
15 #-----------------------------------------------------------------------------
16 # Imports
17 #-----------------------------------------------------------------------------
1 """Base Tornado handlers for the notebook."""
18 2
3 # Copyright (c) IPython Development Team.
4 # Distributed under the terms of the Modified BSD License.
19 5
20 6 import functools
21 7 import json
22 8 import logging
23 9 import os
24 10 import re
25 11 import sys
26 12 import traceback
27 13 try:
28 14 # py3
29 15 from http.client import responses
30 16 except ImportError:
31 17 from httplib import responses
32 18
33 19 from jinja2 import TemplateNotFound
34 20 from tornado import web
35 21
36 22 try:
37 23 from tornado.log import app_log
38 24 except ImportError:
39 25 app_log = logging.getLogger()
40 26
41 27 from IPython.config import Application
42 28 from IPython.utils.path import filefind
43 29 from IPython.utils.py3compat import string_types
44 30 from IPython.html.utils import is_hidden
45 31
46 32 #-----------------------------------------------------------------------------
47 33 # Top-level handlers
48 34 #-----------------------------------------------------------------------------
49 35 non_alphanum = re.compile(r'[^A-Za-z0-9]')
50 36
51 37 class AuthenticatedHandler(web.RequestHandler):
52 38 """A RequestHandler with an authenticated user."""
53 39
54 40 def set_default_headers(self):
55 41 headers = self.settings.get('headers', {})
56 42 for header_name,value in headers.items() :
57 43 try:
58 44 self.set_header(header_name, value)
59 45 except Exception:
60 46 # tornado raises a plain Exception (not a subclass)
61 47 # if the header is unsupported (e.g. Access-Control-Allow-Origin
62 48 # on a websocket), so just ignore it
63 49 pass
64 50
65 51 def clear_login_cookie(self):
66 52 self.clear_cookie(self.cookie_name)
67 53
68 54 def get_current_user(self):
69 55 user_id = self.get_secure_cookie(self.cookie_name)
70 56 # For now user_id should not be empty, but it could be eventually
71 57 if user_id == '':
72 58 user_id = 'anonymous'
73 59 if user_id is None:
74 60 # prevent extra Invalid cookie sig warnings:
75 61 self.clear_login_cookie()
76 62 if not self.login_available:
77 63 user_id = 'anonymous'
78 64 return user_id
79 65
80 66 @property
81 67 def cookie_name(self):
82 68 default_cookie_name = non_alphanum.sub('-', 'username-{}'.format(
83 69 self.request.host
84 70 ))
85 71 return self.settings.get('cookie_name', default_cookie_name)
86 72
87 73 @property
88 74 def password(self):
89 75 """our password"""
90 76 return self.settings.get('password', '')
91 77
92 78 @property
93 79 def logged_in(self):
94 80 """Is a user currently logged in?
95 81
96 82 """
97 83 user = self.get_current_user()
98 84 return (user and not user == 'anonymous')
99 85
100 86 @property
101 87 def login_available(self):
102 88 """May a user proceed to log in?
103 89
104 90 This returns True if login capability is available, irrespective of
105 91 whether the user is already logged in or not.
106 92
107 93 """
108 94 return bool(self.settings.get('password', ''))
109 95
110 96
111 97 class IPythonHandler(AuthenticatedHandler):
112 98 """IPython-specific extensions to authenticated handling
113 99
114 100 Mostly property shortcuts to IPython-specific settings.
115 101 """
116 102
117 103 @property
118 104 def config(self):
119 105 return self.settings.get('config', None)
120 106
121 107 @property
122 108 def log(self):
123 109 """use the IPython log by default, falling back on tornado's logger"""
124 110 if Application.initialized():
125 111 return Application.instance().log
126 112 else:
127 113 return app_log
128 114
129 115 #---------------------------------------------------------------
130 116 # URLs
131 117 #---------------------------------------------------------------
132 118
133 119 @property
134 120 def mathjax_url(self):
135 121 return self.settings.get('mathjax_url', '')
136 122
137 123 @property
138 124 def base_url(self):
139 125 return self.settings.get('base_url', '/')
140 126
141 127 #---------------------------------------------------------------
142 128 # Manager objects
143 129 #---------------------------------------------------------------
144 130
145 131 @property
146 132 def kernel_manager(self):
147 133 return self.settings['kernel_manager']
148 134
149 135 @property
150 136 def notebook_manager(self):
151 137 return self.settings['notebook_manager']
152 138
153 139 @property
154 140 def cluster_manager(self):
155 141 return self.settings['cluster_manager']
156 142
157 143 @property
158 144 def session_manager(self):
159 145 return self.settings['session_manager']
160 146
161 147 @property
162 148 def kernel_spec_manager(self):
163 149 return self.settings['kernel_spec_manager']
164 150
165 151 @property
166 152 def project_dir(self):
167 153 return self.notebook_manager.notebook_dir
168 154
169 155 #---------------------------------------------------------------
170 156 # template rendering
171 157 #---------------------------------------------------------------
172 158
173 159 def get_template(self, name):
174 160 """Return the jinja template object for a given name"""
175 161 return self.settings['jinja2_env'].get_template(name)
176 162
177 163 def render_template(self, name, **ns):
178 164 ns.update(self.template_namespace)
179 165 template = self.get_template(name)
180 166 return template.render(**ns)
181 167
182 168 @property
183 169 def template_namespace(self):
184 170 return dict(
185 171 base_url=self.base_url,
186 172 logged_in=self.logged_in,
187 173 login_available=self.login_available,
188 174 static_url=self.static_url,
189 175 )
190 176
191 177 def get_json_body(self):
192 178 """Return the body of the request as JSON data."""
193 179 if not self.request.body:
194 180 return None
195 181 # Do we need to call body.decode('utf-8') here?
196 182 body = self.request.body.strip().decode(u'utf-8')
197 183 try:
198 184 model = json.loads(body)
199 185 except Exception:
200 186 self.log.debug("Bad JSON: %r", body)
201 187 self.log.error("Couldn't parse JSON", exc_info=True)
202 188 raise web.HTTPError(400, u'Invalid JSON in body of request')
203 189 return model
204 190
205 191 def get_error_html(self, status_code, **kwargs):
206 192 """render custom error pages"""
207 193 exception = kwargs.get('exception')
208 194 message = ''
209 195 status_message = responses.get(status_code, 'Unknown HTTP Error')
210 196 if exception:
211 197 # get the custom message, if defined
212 198 try:
213 199 message = exception.log_message % exception.args
214 200 except Exception:
215 201 pass
216 202
217 203 # construct the custom reason, if defined
218 204 reason = getattr(exception, 'reason', '')
219 205 if reason:
220 206 status_message = reason
221 207
222 208 # build template namespace
223 209 ns = dict(
224 210 status_code=status_code,
225 211 status_message=status_message,
226 212 message=message,
227 213 exception=exception,
228 214 )
229 215
230 216 # render the template
231 217 try:
232 218 html = self.render_template('%s.html' % status_code, **ns)
233 219 except TemplateNotFound:
234 220 self.log.debug("No template for %d", status_code)
235 221 html = self.render_template('error.html', **ns)
236 222 return html
237 223
238 224
239 225 class Template404(IPythonHandler):
240 226 """Render our 404 template"""
241 227 def prepare(self):
242 228 raise web.HTTPError(404)
243 229
244 230
245 231 class AuthenticatedFileHandler(IPythonHandler, web.StaticFileHandler):
246 232 """static files should only be accessible when logged in"""
247 233
248 234 @web.authenticated
249 235 def get(self, path):
250 236 if os.path.splitext(path)[1] == '.ipynb':
251 237 name = os.path.basename(path)
252 238 self.set_header('Content-Type', 'application/json')
253 239 self.set_header('Content-Disposition','attachment; filename="%s"' % name)
254 240
255 241 return web.StaticFileHandler.get(self, path)
256 242
257 243 def compute_etag(self):
258 244 return None
259 245
260 246 def validate_absolute_path(self, root, absolute_path):
261 247 """Validate and return the absolute path.
262 248
263 249 Requires tornado 3.1
264 250
265 251 Adding to tornado's own handling, forbids the serving of hidden files.
266 252 """
267 253 abs_path = super(AuthenticatedFileHandler, self).validate_absolute_path(root, absolute_path)
268 254 abs_root = os.path.abspath(root)
269 255 if is_hidden(abs_path, abs_root):
270 256 self.log.info("Refusing to serve hidden file, via 404 Error")
271 257 raise web.HTTPError(404)
272 258 return abs_path
273 259
274 260
275 261 def json_errors(method):
276 262 """Decorate methods with this to return GitHub style JSON errors.
277 263
278 264 This should be used on any JSON API on any handler method that can raise HTTPErrors.
279 265
280 266 This will grab the latest HTTPError exception using sys.exc_info
281 267 and then:
282 268
283 269 1. Set the HTTP status code based on the HTTPError
284 270 2. Create and return a JSON body with a message field describing
285 271 the error in a human readable form.
286 272 """
287 273 @functools.wraps(method)
288 274 def wrapper(self, *args, **kwargs):
289 275 try:
290 276 result = method(self, *args, **kwargs)
291 277 except web.HTTPError as e:
292 278 status = e.status_code
293 279 message = e.log_message
294 280 self.log.warn(message)
295 281 self.set_status(e.status_code)
296 282 self.finish(json.dumps(dict(message=message)))
297 283 except Exception:
298 284 self.log.error("Unhandled error in API request", exc_info=True)
299 285 status = 500
300 286 message = "Unknown server error"
301 287 t, value, tb = sys.exc_info()
302 288 self.set_status(status)
303 289 tb_text = ''.join(traceback.format_exception(t, value, tb))
304 290 reply = dict(message=message, traceback=tb_text)
305 291 self.finish(json.dumps(reply))
306 292 else:
307 293 return result
308 294 return wrapper
309 295
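Typical use is on API handler methods; `EchoHandler` here is hypothetical:

```python
class EchoHandler(IPythonHandler):
    @web.authenticated
    @json_errors
    def get(self, name):
        if not name:
            # sent to the client as {"message": "name is required"} with status 400
            raise web.HTTPError(400, u'name is required')
        self.finish(json.dumps({'name': name}))
```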
310 296
311 297
312 298 #-----------------------------------------------------------------------------
313 299 # File handler
314 300 #-----------------------------------------------------------------------------
315 301
316 302 # to minimize subclass changes:
317 303 HTTPError = web.HTTPError
318 304
319 305 class FileFindHandler(web.StaticFileHandler):
320 306 """subclass of StaticFileHandler for serving files from a search path"""
321 307
322 308 # cache search results, don't search for files more than once
323 309 _static_paths = {}
324 310
325 311 def initialize(self, path, default_filename=None):
326 312 if isinstance(path, string_types):
327 313 path = [path]
328 314
329 315 self.root = tuple(
330 316 os.path.abspath(os.path.expanduser(p)) + os.sep for p in path
331 317 )
332 318 self.default_filename = default_filename
333 319
334 320 def compute_etag(self):
335 321 return None
336 322
337 323 @classmethod
338 324 def get_absolute_path(cls, roots, path):
339 325 """locate a file to serve on our static file search path"""
340 326 with cls._lock:
341 327 if path in cls._static_paths:
342 328 return cls._static_paths[path]
343 329 try:
344 330 abspath = os.path.abspath(filefind(path, roots))
345 331 except IOError:
346 332 # IOError means not found
347 333 return ''
348 334
349 335 cls._static_paths[path] = abspath
350 336 return abspath
351 337
352 338 def validate_absolute_path(self, root, absolute_path):
353 339 """check if the file should be served (raises 404, 403, etc.)"""
354 340 if absolute_path == '':
355 341 raise web.HTTPError(404)
356 342
357 343 for root in self.root:
358 344 if (absolute_path + os.sep).startswith(root):
359 345 break
360 346
361 347 return super(FileFindHandler, self).validate_absolute_path(root, absolute_path)
362 348
363 349
364 350 class TrailingSlashHandler(web.RequestHandler):
365 351 """Simple redirect handler that strips trailing slashes
366 352
367 353 This should be the first, highest priority handler.
368 354 """
369 355
370 356 SUPPORTED_METHODS = ['GET']
371 357
372 358 def get(self):
373 359 self.redirect(self.request.uri.rstrip('/'))
374 360
375 361 #-----------------------------------------------------------------------------
376 362 # URL pattern fragments for re-use
377 363 #-----------------------------------------------------------------------------
378 364
379 365 path_regex = r"(?P<path>(?:/.*)*)"
380 366 notebook_name_regex = r"(?P<name>[^/]+\.ipynb)"
381 367 notebook_path_regex = "%s/%s" % (path_regex, notebook_name_regex)
382 368
383 369 #-----------------------------------------------------------------------------
384 370 # URL to handler mappings
385 371 #-----------------------------------------------------------------------------
386 372
387 373
388 374 default_handlers = [
389 375 (r".*/", TrailingSlashHandler)
390 376 ]
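As a quick sketch of how the fragments above compose, notebook_path_regex splits a request path into its directory and file-name groups:

    import re

    m = re.match(notebook_path_regex + '$', '/sub/dir/Untitled0.ipynb')
    assert m.group('path') == '/sub/dir'
    assert m.group('name') == 'Untitled0.ipynb'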
@@ -1,238 +1,217 b''
1 """The official API for working with notebooks in the current format version.
2
3 Authors:
4
5 * Brian Granger
6 * Jonathan Frederic
7 """
8
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2008-2011 The IPython Development Team
11 #
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
15
16 #-----------------------------------------------------------------------------
17 # Imports
18 #-----------------------------------------------------------------------------
1 """The official API for working with notebooks in the current format version."""
19 2
20 3 from __future__ import print_function
21 4
22 5 from xml.etree import ElementTree as ET
23 6 import re
24 7
25 8 from IPython.utils.py3compat import unicode_type
26 9
27 10 from IPython.nbformat.v3 import (
28 11 NotebookNode,
29 12 new_code_cell, new_text_cell, new_notebook, new_output, new_worksheet,
30 13 parse_filename, new_metadata, new_author, new_heading_cell, nbformat,
31 14 nbformat_minor, nbformat_schema, to_notebook_json
32 15 )
33 16 from IPython.nbformat import v3 as _v_latest
34 17
35 18 from .reader import reads as reader_reads
36 19 from .reader import versions
37 20 from .convert import convert
38 21 from .validator import validate
39 22
40 import logging
41 logger = logging.getLogger('NotebookApp')
23 from IPython.utils.log import get_logger
42 24
43 #-----------------------------------------------------------------------------
44 # Code
45 #-----------------------------------------------------------------------------
46 25
47 26 current_nbformat = nbformat
48 27 current_nbformat_minor = nbformat_minor
49 28 current_nbformat_module = _v_latest.__name__
50 29
51 30
52 31 def docstring_nbformat_mod(func):
53 32 """Decorator for docstrings referring to classes/functions accessed through
54 33 nbformat.current.
55 34
56 35 Put {nbformat_mod} in the docstring in place of 'IPython.nbformat.v3'.
57 36 """
58 37 func.__doc__ = func.__doc__.format(nbformat_mod=current_nbformat_module)
59 38 return func
60 39
61 40
62 41 class NBFormatError(ValueError):
63 42 pass
64 43
65 44
66 45 def parse_py(s, **kwargs):
67 46 """Parse a string into a (nbformat, string) tuple."""
68 47 nbf = current_nbformat
69 48 nbm = current_nbformat_minor
70 49
71 50 pattern = r'# <nbformat>(?P<nbformat>\d+[\.\d+]*)</nbformat>'
72 51 m = re.search(pattern,s)
73 52 if m is not None:
74 53 digits = m.group('nbformat').split('.')
75 54 nbf = int(digits[0])
76 55 if len(digits) > 1:
77 56 nbm = int(digits[1])
78 57
79 58 return nbf, nbm, s
80 59
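A short sketch of the version detection: an embedded <nbformat> comment overrides the current defaults, otherwise they are returned unchanged:

    src = "# <nbformat>3.0</nbformat>\nprint('hi')\n"
    nbf, nbm, s = parse_py(src)
    assert (nbf, nbm) == (3, 0)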
81 60
82 61 def reads_json(nbjson, **kwargs):
83 62 """Read a JSON notebook from a string and return the NotebookNode
84 63 object. Report if any JSON format errors are detected.
85 64
86 65 """
87 66 nb = reader_reads(nbjson, **kwargs)
88 67 nb_current = convert(nb, current_nbformat)
89 68 errors = validate(nb_current)
90 69 if errors:
91 logger.error(
70 get_logger().error(
92 71 "Notebook JSON is invalid (%d errors detected during read)",
93 72 len(errors))
94 73 return nb_current
95 74
96 75
97 76 def writes_json(nb, **kwargs):
98 77 """Take a NotebookNode object and write out a JSON string. Report if
99 78 any JSON format errors are detected.
100 79
101 80 """
102 81 errors = validate(nb)
103 82 if errors:
104 logger.error(
83 get_logger().error(
105 84 "Notebook JSON is invalid (%d errors detected during write)",
106 85 len(errors))
107 86 nbjson = versions[current_nbformat].writes_json(nb, **kwargs)
108 87 return nbjson
109 88
110 89
111 90 def reads_py(s, **kwargs):
112 91 """Read a .py notebook from a string and return the NotebookNode object."""
113 92 nbf, nbm, s = parse_py(s, **kwargs)
114 93 if nbf in (2, 3):
115 94 nb = versions[nbf].to_notebook_py(s, **kwargs)
116 95 else:
117 96 raise NBFormatError('Unsupported PY nbformat version: %i' % nbf)
118 97 return nb
119 98
120 99
121 100 def writes_py(nb, **kwargs):
122 101 # nbformat 3 is the latest format that supports py
123 102 return versions[3].writes_py(nb, **kwargs)
124 103
125 104
126 105 # High level API
127 106
128 107
129 108 def reads(s, format, **kwargs):
130 109 """Read a notebook from a string and return the NotebookNode object.
131 110
132 111 This function properly handles notebooks of any version. The notebook
133 112 returned will always be in the current version's format.
134 113
135 114 Parameters
136 115 ----------
137 116 s : unicode
138 117 The raw unicode string to read the notebook from.
139 118 format : (u'json', u'ipynb', u'py')
140 119 The format that the string is in.
141 120
142 121 Returns
143 122 -------
144 123 nb : NotebookNode
145 124 The notebook that was read.
146 125 """
147 126 format = unicode_type(format)
148 127 if format == u'json' or format == u'ipynb':
149 128 return reads_json(s, **kwargs)
150 129 elif format == u'py':
151 130 return reads_py(s, **kwargs)
152 131 else:
153 132 raise NBFormatError('Unsupported format: %s' % format)
154 133
155 134
156 135 def writes(nb, format, **kwargs):
157 136 """Write a notebook to a string in a given format in the current nbformat version.
158 137
159 138 This function always writes the notebook in the current nbformat version.
160 139
161 140 Parameters
162 141 ----------
163 142 nb : NotebookNode
164 143 The notebook to write.
165 144 format : (u'json', u'ipynb', u'py')
166 145 The format to write the notebook in.
167 146
168 147 Returns
169 148 -------
170 149 s : unicode
171 150 The notebook string.
172 151 """
173 152 format = unicode_type(format)
174 153 if format == u'json' or format == u'ipynb':
175 154 return writes_json(nb, **kwargs)
176 155 elif format == u'py':
177 156 return writes_py(nb, **kwargs)
178 157 else:
179 158 raise NBFormatError('Unsupported format: %s' % format)
180 159
181 160
182 161 def read(fp, format, **kwargs):
183 162 """Read a notebook from a file and return the NotebookNode object.
184 163
185 164 This function properly handles notebooks of any version. The notebook
186 165 returned will always be in the current version's format.
187 166
188 167 Parameters
189 168 ----------
190 169 fp : file
191 170 Any file-like object with a read method.
192 171 format : (u'json', u'ipynb', u'py')
193 172 The format that the string is in.
194 173
195 174 Returns
196 175 -------
197 176 nb : NotebookNode
198 177 The notebook that was read.
199 178 """
200 179 return reads(fp.read(), format, **kwargs)
201 180
202 181
203 182 def write(nb, fp, format, **kwargs):
204 183 """Write a notebook to a file in a given format in the current nbformat version.
205 184
206 185 This function always writes the notebook in the current nbformat version.
207 186
208 187 Parameters
209 188 ----------
210 189 nb : NotebookNode
211 190 The notebook to write.
212 191 fp : file
213 192 Any file-like object with a write method.
214 193 format : (u'json', u'ipynb', u'py')
215 194 The format to write the notebook in.
216 195
217 196 Returns
218 197 -------
219 198 out
220 199 The return value of ``fp.write``.
221 200 """
222 201 return fp.write(writes(nb, format, **kwargs))
223 202
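A minimal round-trip sketch with the high-level API (the file name is hypothetical):

    with open('example.ipynb') as f:
        nb = read(f, u'json')
    # ... modify nb here ...
    with open('example.ipynb', 'w') as f:
        write(nb, f, u'json')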
224 203 def _convert_to_metadata():
225 204 """Convert to a notebook having notebook metadata."""
226 205 import glob
227 206 for fname in glob.glob('*.ipynb'):
228 207 print('Converting file:',fname)
229 208 with open(fname,'r') as f:
230 209 nb = read(f,u'json')
231 210 md = new_metadata()
232 211 if u'name' in nb:
233 212 md.name = nb.name
234 213 del nb[u'name']
235 214 nb.metadata = md
236 215 with open(fname,'w') as f:
237 216 write(nb, f, u'json')
238 217
@@ -1,859 +1,848 b''
1 1 """The Python scheduler for rich scheduling.
2 2
3 3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 5 Python Scheduler exists.
6
7 Authors:
8
9 * Min RK
10 6 """
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
17 7
18 #----------------------------------------------------------------------
19 # Imports
20 #----------------------------------------------------------------------
8 # Copyright (c) IPython Development Team.
9 # Distributed under the terms of the Modified BSD License.
21 10
22 11 import logging
23 12 import sys
24 13 import time
25 14
26 15 from collections import deque
27 16 from datetime import datetime
28 17 from random import randint, random
29 18 from types import FunctionType
30 19
31 20 try:
32 21 import numpy
33 22 except ImportError:
34 23 numpy = None
35 24
36 25 import zmq
37 26 from zmq.eventloop import ioloop, zmqstream
38 27
39 28 # local imports
40 29 from IPython.external.decorator import decorator
41 30 from IPython.config.application import Application
42 31 from IPython.config.loader import Config
43 32 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 33 from IPython.utils.py3compat import cast_bytes
45 34
46 35 from IPython.parallel import error, util
47 36 from IPython.parallel.factory import SessionFactory
48 37 from IPython.parallel.util import connect_logger, local_logger
49 38
50 39 from .dependency import Dependency
51 40
52 41 @decorator
53 42 def logged(f,self,*args,**kwargs):
54 43 # print ("#--------------------")
55 44 self.log.debug("scheduler::%s(*%s,**%s)", f.__name__, args, kwargs)
56 45 # print ("#--")
57 46 return f(self,*args, **kwargs)
58 47
59 48 #----------------------------------------------------------------------
60 49 # Chooser functions
61 50 #----------------------------------------------------------------------
62 51
63 52 def plainrandom(loads):
64 53 """Plain random pick."""
65 54 n = len(loads)
66 55 return randint(0,n-1)
67 56
68 57 def lru(loads):
69 58 """Always pick the front of the line.
70 59
71 60 The content of `loads` is ignored.
72 61
73 62 Assumes LRU ordering of loads, with oldest first.
74 63 """
75 64 return 0
76 65
77 66 def twobin(loads):
78 67 """Pick two at random, use the LRU of the two.
79 68
80 69 The content of loads is ignored.
81 70
82 71 Assumes LRU ordering of loads, with oldest first.
83 72 """
84 73 n = len(loads)
85 74 a = randint(0,n-1)
86 75 b = randint(0,n-1)
87 76 return min(a,b)
88 77
89 78 def weighted(loads):
90 79 """Pick two at random using inverse load as weight.
91 80
92 81 Return the less loaded of the two.
93 82 """
94 83 # weight 0 a million times more than 1:
95 84 weights = 1./(1e-6+numpy.array(loads))
96 85 sums = weights.cumsum()
97 86 t = sums[-1]
98 87 x = random()*t
99 88 y = random()*t
100 89 idx = 0
101 90 idy = 0
102 91 while sums[idx] < x:
103 92 idx += 1
104 93 while sums[idy] < y:
105 94 idy += 1
106 95 if weights[idy] > weights[idx]:
107 96 return idy
108 97 else:
109 98 return idx
110 99
111 100 def leastload(loads):
112 101 """Always choose the lowest load.
113 102
114 103 If the lowest load occurs more than once, the first
115 104 occurrence will be used. If loads has LRU ordering, this means
116 105 the LRU of those with the lowest load is chosen.
117 106 """
118 107 return loads.index(min(loads))
119 108
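A short worked example of the choosers above, where loads holds per-engine outstanding-task counts in LRU order (oldest first):

    loads = [2, 0, 1]
    assert lru(loads) == 0          # head of the line; loads are ignored
    assert leastload(loads) == 1    # engine with the fewest outstanding tasks
    idx = twobin(loads)             # the LRU (lower index) of two random picks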
120 109 #---------------------------------------------------------------------
121 110 # Classes
122 111 #---------------------------------------------------------------------
123 112
124 113
125 114 # store empty default dependency:
126 115 MET = Dependency([])
127 116
128 117
129 118 class Job(object):
130 119 """Simple container for a job"""
131 120 def __init__(self, msg_id, raw_msg, idents, msg, header, metadata,
132 121 targets, after, follow, timeout):
133 122 self.msg_id = msg_id
134 123 self.raw_msg = raw_msg
135 124 self.idents = idents
136 125 self.msg = msg
137 126 self.header = header
138 127 self.metadata = metadata
139 128 self.targets = targets
140 129 self.after = after
141 130 self.follow = follow
142 131 self.timeout = timeout
143 132
144 133 self.removed = False # used for lazy-delete from sorted queue
145 134 self.timestamp = time.time()
146 135 self.timeout_id = 0
147 136 self.blacklist = set()
148 137
149 138 def __lt__(self, other):
150 139 return self.timestamp < other.timestamp
151 140
152 141 def __cmp__(self, other):
153 142 return cmp(self.timestamp, other.timestamp)
154 143
155 144 @property
156 145 def dependents(self):
157 146 return self.follow.union(self.after)
158 147
159 148
160 149 class TaskScheduler(SessionFactory):
161 150 """Python TaskScheduler object.
162 151
163 152 This is the simplest object that supports msg_id based
164 153 DAG dependencies. *Only* task msg_ids are checked, not
165 154 msg_ids of jobs submitted via the MUX queue.
166 155
167 156 """
168 157
169 158 hwm = Integer(1, config=True,
170 159 help="""specify the High Water Mark (HWM) for the downstream
171 160 socket in the Task scheduler. This is the maximum number
172 161 of allowed outstanding tasks on each engine.
173 162
174 163 The default (1) means that only one task can be outstanding on each
175 164 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
176 165 engines continue to be assigned tasks while they are working,
177 166 effectively hiding network latency behind computation, but can result
178 167 in an imbalance of work when submitting many heterogeneous tasks all at
179 168 once. Any positive value greater than one is a compromise between the
180 169 two.
181 170
182 171 """
183 172 )
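A minimal configuration sketch, e.g. in a profile's ipcontroller_config.py (the value here is illustrative):

    c = get_config()
    c.TaskScheduler.hwm = 0    # no limit: keep assigning work to busy engines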
184 173 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
185 174 'leastload', config=True, allow_none=False,
186 175 help="""select the task scheduler scheme [default: Python LRU]
187 176 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
188 177 )
189 178 def _scheme_name_changed(self, old, new):
190 179 self.log.debug("Using scheme %r"%new)
191 180 self.scheme = globals()[new]
192 181
193 182 # input arguments:
194 183 scheme = Instance(FunctionType) # function for determining the destination
195 184 def _scheme_default(self):
196 185 return leastload
197 186 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
198 187 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
199 188 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
200 189 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
201 190 query_stream = Instance(zmqstream.ZMQStream) # hub-facing DEALER stream
202 191
203 192 # internals:
204 193 queue = Instance(deque) # sorted list of Jobs
205 194 def _queue_default(self):
206 195 return deque()
207 196 queue_map = Dict() # dict by msg_id of Jobs (for O(1) access to the Queue)
208 197 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
209 198 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
210 199 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
211 200 pending = Dict() # dict by engine_uuid of submitted tasks
212 201 completed = Dict() # dict by engine_uuid of completed tasks
213 202 failed = Dict() # dict by engine_uuid of failed tasks
214 203 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
215 204 clients = Dict() # dict by msg_id for who submitted the task
216 205 targets = List() # list of target IDENTs
217 206 loads = List() # list of engine loads
218 207 # full = Set() # set of IDENTs that have HWM outstanding tasks
219 208 all_completed = Set() # set of all completed tasks
220 209 all_failed = Set() # set of all failed tasks
221 210 all_done = Set() # set of all finished tasks=union(completed,failed)
222 211 all_ids = Set() # set of all submitted task IDs
223 212
224 213 ident = CBytes() # ZMQ identity. This should just be self.session.session
225 214 # but ensure Bytes
226 215 def _ident_default(self):
227 216 return self.session.bsession
228 217
229 218 def start(self):
230 219 self.query_stream.on_recv(self.dispatch_query_reply)
231 220 self.session.send(self.query_stream, "connection_request", {})
232 221
233 222 self.engine_stream.on_recv(self.dispatch_result, copy=False)
234 223 self.client_stream.on_recv(self.dispatch_submission, copy=False)
235 224
236 225 self._notification_handlers = dict(
237 226 registration_notification = self._register_engine,
238 227 unregistration_notification = self._unregister_engine
239 228 )
240 229 self.notifier_stream.on_recv(self.dispatch_notification)
241 230 self.log.info("Scheduler started [%s]" % self.scheme_name)
242 231
243 232 def resume_receiving(self):
244 233 """Resume accepting jobs."""
245 234 self.client_stream.on_recv(self.dispatch_submission, copy=False)
246 235
247 236 def stop_receiving(self):
248 237 """Stop accepting jobs while there are no engines.
249 238 Leave them in the ZMQ queue."""
250 239 self.client_stream.on_recv(None)
251 240
252 241 #-----------------------------------------------------------------------
253 242 # [Un]Registration Handling
254 243 #-----------------------------------------------------------------------
255 244
256 245
257 246 def dispatch_query_reply(self, msg):
258 247 """handle reply to our initial connection request"""
259 248 try:
260 249 idents,msg = self.session.feed_identities(msg)
261 250 except ValueError:
262 251 self.log.warn("task::Invalid Message: %r",msg)
263 252 return
264 253 try:
265 254 msg = self.session.unserialize(msg)
266 255 except ValueError:
267 256 self.log.warn("task::Unauthorized message from: %r"%idents)
268 257 return
269 258
270 259 content = msg['content']
271 260 for uuid in content.get('engines', {}).values():
272 261 self._register_engine(cast_bytes(uuid))
273 262
274 263
275 264 @util.log_errors
276 265 def dispatch_notification(self, msg):
277 266 """dispatch register/unregister events."""
278 267 try:
279 268 idents,msg = self.session.feed_identities(msg)
280 269 except ValueError:
281 270 self.log.warn("task::Invalid Message: %r",msg)
282 271 return
283 272 try:
284 273 msg = self.session.unserialize(msg)
285 274 except ValueError:
286 275 self.log.warn("task::Unauthorized message from: %r"%idents)
287 276 return
288 277
289 278 msg_type = msg['header']['msg_type']
290 279
291 280 handler = self._notification_handlers.get(msg_type, None)
292 281 if handler is None:
293 282 self.log.error("Unhandled message type: %r"%msg_type)
294 283 else:
295 284 try:
296 285 handler(cast_bytes(msg['content']['uuid']))
297 286 except Exception:
298 287 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
299 288
300 289 def _register_engine(self, uid):
301 290 """New engine with ident `uid` became available."""
302 291 # head of the line:
303 292 self.targets.insert(0,uid)
304 293 self.loads.insert(0,0)
305 294
306 295 # initialize sets
307 296 self.completed[uid] = set()
308 297 self.failed[uid] = set()
309 298 self.pending[uid] = {}
310 299
311 300 # rescan the graph:
312 301 self.update_graph(None)
313 302
314 303 def _unregister_engine(self, uid):
315 304 """Existing engine with ident `uid` became unavailable."""
316 305 if len(self.targets) == 1:
317 306 # this was our only engine
318 307 pass
319 308
320 309 # handle any potentially finished tasks:
321 310 self.engine_stream.flush()
322 311
323 312 # don't pop destinations, because they might be used later
324 313 # map(self.destinations.pop, self.completed.pop(uid))
325 314 # map(self.destinations.pop, self.failed.pop(uid))
326 315
327 316 # prevent this engine from receiving work
328 317 idx = self.targets.index(uid)
329 318 self.targets.pop(idx)
330 319 self.loads.pop(idx)
331 320
332 321 # wait 5 seconds before cleaning up pending jobs, since the results might
333 322 # still be incoming
334 323 if self.pending[uid]:
335 324 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
336 325 dc.start()
337 326 else:
338 327 self.completed.pop(uid)
339 328 self.failed.pop(uid)
340 329
341 330
342 331 def handle_stranded_tasks(self, engine):
343 332 """Deal with jobs resident in an engine that died."""
344 333 lost = self.pending[engine]
345 334 for msg_id in lost.keys():
346 335 if msg_id not in self.pending[engine]:
347 336 # prevent double-handling of messages
348 337 continue
349 338
350 339 raw_msg = lost[msg_id].raw_msg
351 340 idents,msg = self.session.feed_identities(raw_msg, copy=False)
352 341 parent = self.session.unpack(msg[1].bytes)
353 342 idents = [engine, idents[0]]
354 343
355 344 # build fake error reply
356 345 try:
357 346 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
358 347 except:
359 348 content = error.wrap_exception()
360 349 # build fake metadata
361 350 md = dict(
362 351 status=u'error',
363 352 engine=engine.decode('ascii'),
364 353 date=datetime.now(),
365 354 )
366 355 msg = self.session.msg('apply_reply', content, parent=parent, metadata=md)
367 356 raw_reply = list(map(zmq.Message, self.session.serialize(msg, ident=idents)))
368 357 # and dispatch it
369 358 self.dispatch_result(raw_reply)
370 359
371 360 # finally scrub completed/failed lists
372 361 self.completed.pop(engine)
373 362 self.failed.pop(engine)
374 363
375 364
376 365 #-----------------------------------------------------------------------
377 366 # Job Submission
378 367 #-----------------------------------------------------------------------
379 368
380 369
381 370 @util.log_errors
382 371 def dispatch_submission(self, raw_msg):
383 372 """Dispatch job submission to appropriate handlers."""
384 373 # ensure targets up to date:
385 374 self.notifier_stream.flush()
386 375 try:
387 376 idents, msg = self.session.feed_identities(raw_msg, copy=False)
388 377 msg = self.session.unserialize(msg, content=False, copy=False)
389 378 except Exception:
390 379 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
391 380 return
392 381
393 382
394 383 # send to monitor
395 384 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
396 385
397 386 header = msg['header']
398 387 md = msg['metadata']
399 388 msg_id = header['msg_id']
400 389 self.all_ids.add(msg_id)
401 390
402 391 # get targets as a set of bytes objects
403 392 # from a list of unicode objects
404 393 targets = md.get('targets', [])
405 394 targets = set(map(cast_bytes, targets))
406 395
407 396 retries = md.get('retries', 0)
408 397 self.retries[msg_id] = retries
409 398
410 399 # time dependencies
411 400 after = md.get('after', None)
412 401 if after:
413 402 after = Dependency(after)
414 403 if after.all:
415 404 if after.success:
416 405 after = Dependency(after.difference(self.all_completed),
417 406 success=after.success,
418 407 failure=after.failure,
419 408 all=after.all,
420 409 )
421 410 if after.failure:
422 411 after = Dependency(after.difference(self.all_failed),
423 412 success=after.success,
424 413 failure=after.failure,
425 414 all=after.all,
426 415 )
427 416 if after.check(self.all_completed, self.all_failed):
428 417 # recast as empty set, if `after` already met,
429 418 # to prevent unnecessary set comparisons
430 419 after = MET
431 420 else:
432 421 after = MET
433 422
434 423 # location dependencies
435 424 follow = Dependency(md.get('follow', []))
436 425
437 426 timeout = md.get('timeout', None)
438 427 if timeout:
439 428 timeout = float(timeout)
440 429
441 430 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
442 431 header=header, targets=targets, after=after, follow=follow,
443 432 timeout=timeout, metadata=md,
444 433 )
445 434 # validate and reduce dependencies:
446 435 for dep in after,follow:
447 436 if not dep: # empty dependency
448 437 continue
449 438 # check valid:
450 439 if msg_id in dep or dep.difference(self.all_ids):
451 440 self.queue_map[msg_id] = job
452 441 return self.fail_unreachable(msg_id, error.InvalidDependency)
453 442 # check if unreachable:
454 443 if dep.unreachable(self.all_completed, self.all_failed):
455 444 self.queue_map[msg_id] = job
456 445 return self.fail_unreachable(msg_id)
457 446
458 447 if after.check(self.all_completed, self.all_failed):
459 448 # time deps already met, try to run
460 449 if not self.maybe_run(job):
461 450 # can't run yet
462 451 if msg_id not in self.all_failed:
463 452 # could have failed as unreachable
464 453 self.save_unmet(job)
465 454 else:
466 455 self.save_unmet(job)
467 456
468 457 def job_timeout(self, job, timeout_id):
469 458 """callback for a job's timeout.
470 459
471 460 The job may or may not have been run at this point.
472 461 """
473 462 if job.timeout_id != timeout_id:
474 463 # not the most recent call
475 464 return
476 465 now = time.time()
477 466 if job.timeout >= (now + 1):
478 467 self.log.warn("task %s timeout fired prematurely: %s > %s",
479 468 job.msg_id, job.timeout, now
480 469 )
481 470 if job.msg_id in self.queue_map:
482 471 # still waiting, but ran out of time
483 472 self.log.info("task %r timed out", job.msg_id)
484 473 self.fail_unreachable(job.msg_id, error.TaskTimeout)
485 474
486 475 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
487 476 """a task has become unreachable, send a reply with an ImpossibleDependency
488 477 error."""
489 478 if msg_id not in self.queue_map:
490 479 self.log.error("task %r already failed!", msg_id)
491 480 return
492 481 job = self.queue_map.pop(msg_id)
493 482 # lazy-delete from the queue
494 483 job.removed = True
495 484 for mid in job.dependents:
496 485 if mid in self.graph:
497 486 self.graph[mid].remove(msg_id)
498 487
499 488 try:
500 489 raise why()
501 490 except:
502 491 content = error.wrap_exception()
503 492 self.log.debug("task %r failing as unreachable with: %s", msg_id, content['ename'])
504 493
505 494 self.all_done.add(msg_id)
506 495 self.all_failed.add(msg_id)
507 496
508 497 msg = self.session.send(self.client_stream, 'apply_reply', content,
509 498 parent=job.header, ident=job.idents)
510 499 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
511 500
512 501 self.update_graph(msg_id, success=False)
513 502
514 503 def available_engines(self):
515 504 """return a list of available engine indices based on HWM"""
516 505 if not self.hwm:
517 506 return list(range(len(self.targets)))
518 507 available = []
519 508 for idx in range(len(self.targets)):
520 509 if self.loads[idx] < self.hwm:
521 510 available.append(idx)
522 511 return available
523 512
524 513 def maybe_run(self, job):
525 514 """check location dependencies, and run if they are met."""
526 515 msg_id = job.msg_id
527 516 self.log.debug("Attempting to assign task %s", msg_id)
528 517 available = self.available_engines()
529 518 if not available:
530 519 # no engines, definitely can't run
531 520 return False
532 521
533 522 if job.follow or job.targets or job.blacklist or self.hwm:
534 523 # we need a can_run filter
535 524 def can_run(idx):
536 525 # check hwm
537 526 if self.hwm and self.loads[idx] == self.hwm:
538 527 return False
539 528 target = self.targets[idx]
540 529 # check blacklist
541 530 if target in job.blacklist:
542 531 return False
543 532 # check targets
544 533 if job.targets and target not in job.targets:
545 534 return False
546 535 # check follow
547 536 return job.follow.check(self.completed[target], self.failed[target])
548 537
549 538 indices = list(filter(can_run, available))
550 539
551 540 if not indices:
552 541 # couldn't run
553 542 if job.follow.all:
554 543 # check follow for impossibility
555 544 dests = set()
556 545 relevant = set()
557 546 if job.follow.success:
558 547 relevant = self.all_completed
559 548 if job.follow.failure:
560 549 relevant = relevant.union(self.all_failed)
561 550 for m in job.follow.intersection(relevant):
562 551 dests.add(self.destinations[m])
563 552 if len(dests) > 1:
564 553 self.queue_map[msg_id] = job
565 554 self.fail_unreachable(msg_id)
566 555 return False
567 556 if job.targets:
568 557 # check blacklist+targets for impossibility
569 558 job.targets.difference_update(job.blacklist)
570 559 if not job.targets or not job.targets.intersection(self.targets):
571 560 self.queue_map[msg_id] = job
572 561 self.fail_unreachable(msg_id)
573 562 return False
574 563 return False
575 564 else:
576 565 indices = None
577 566
578 567 self.submit_task(job, indices)
579 568 return True
580 569
581 570 def save_unmet(self, job):
582 571 """Save a message for later submission when its dependencies are met."""
583 572 msg_id = job.msg_id
584 573 self.log.debug("Adding task %s to the queue", msg_id)
585 574 self.queue_map[msg_id] = job
586 575 self.queue.append(job)
587 576 # track the ids in follow or after, but not those already finished
588 577 for dep_id in job.after.union(job.follow).difference(self.all_done):
589 578 if dep_id not in self.graph:
590 579 self.graph[dep_id] = set()
591 580 self.graph[dep_id].add(msg_id)
592 581
593 582 # schedule timeout callback
594 583 if job.timeout:
595 584 timeout_id = job.timeout_id = job.timeout_id + 1
596 585 self.loop.add_timeout(time.time() + job.timeout,
597 586 lambda : self.job_timeout(job, timeout_id)
598 587 )
599 588
600 589
601 590 def submit_task(self, job, indices=None):
602 591 """Submit a task to any of a subset of our targets."""
603 592 if indices:
604 593 loads = [self.loads[i] for i in indices]
605 594 else:
606 595 loads = self.loads
607 596 idx = self.scheme(loads)
608 597 if indices:
609 598 idx = indices[idx]
610 599 target = self.targets[idx]
611 600 # print (target, map(str, msg[:3]))
612 601 # send job to the engine
613 602 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
614 603 self.engine_stream.send_multipart(job.raw_msg, copy=False)
615 604 # update load
616 605 self.add_job(idx)
617 606 self.pending[target][job.msg_id] = job
618 607 # notify Hub
619 608 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
620 609 self.session.send(self.mon_stream, 'task_destination', content=content,
621 610 ident=[b'tracktask',self.ident])
622 611
623 612
624 613 #-----------------------------------------------------------------------
625 614 # Result Handling
626 615 #-----------------------------------------------------------------------
627 616
628 617
629 618 @util.log_errors
630 619 def dispatch_result(self, raw_msg):
631 620 """dispatch method for result replies"""
632 621 try:
633 622 idents,msg = self.session.feed_identities(raw_msg, copy=False)
634 623 msg = self.session.unserialize(msg, content=False, copy=False)
635 624 engine = idents[0]
636 625 try:
637 626 idx = self.targets.index(engine)
638 627 except ValueError:
639 628 pass # skip load-update for dead engines
640 629 else:
641 630 self.finish_job(idx)
642 631 except Exception:
643 632 self.log.error("task::Invalid result: %r", raw_msg, exc_info=True)
644 633 return
645 634
646 635 md = msg['metadata']
647 636 parent = msg['parent_header']
648 637 if md.get('dependencies_met', True):
649 638 success = (md['status'] == 'ok')
650 639 msg_id = parent['msg_id']
651 640 retries = self.retries[msg_id]
652 641 if not success and retries > 0:
653 642 # failed
654 643 self.retries[msg_id] = retries - 1
655 644 self.handle_unmet_dependency(idents, parent)
656 645 else:
657 646 del self.retries[msg_id]
658 647 # relay to client and update graph
659 648 self.handle_result(idents, parent, raw_msg, success)
660 649 # send to Hub monitor
661 650 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
662 651 else:
663 652 self.handle_unmet_dependency(idents, parent)
664 653
665 654 def handle_result(self, idents, parent, raw_msg, success=True):
666 655 """handle a real task result, either success or failure"""
667 656 # first, relay result to client
668 657 engine = idents[0]
669 658 client = idents[1]
670 659 # swap_ids for ROUTER-ROUTER mirror
671 660 raw_msg[:2] = [client,engine]
672 661 # print (map(str, raw_msg[:4]))
673 662 self.client_stream.send_multipart(raw_msg, copy=False)
674 663 # now, update our data structures
675 664 msg_id = parent['msg_id']
676 665 self.pending[engine].pop(msg_id)
677 666 if success:
678 667 self.completed[engine].add(msg_id)
679 668 self.all_completed.add(msg_id)
680 669 else:
681 670 self.failed[engine].add(msg_id)
682 671 self.all_failed.add(msg_id)
683 672 self.all_done.add(msg_id)
684 673 self.destinations[msg_id] = engine
685 674
686 675 self.update_graph(msg_id, success)
687 676
688 677 def handle_unmet_dependency(self, idents, parent):
689 678 """handle an unmet dependency"""
690 679 engine = idents[0]
691 680 msg_id = parent['msg_id']
692 681
693 682 job = self.pending[engine].pop(msg_id)
694 683 job.blacklist.add(engine)
695 684
696 685 if job.blacklist == job.targets:
697 686 self.queue_map[msg_id] = job
698 687 self.fail_unreachable(msg_id)
699 688 elif not self.maybe_run(job):
700 689 # resubmit failed
701 690 if msg_id not in self.all_failed:
702 691 # put it back in our dependency tree
703 692 self.save_unmet(job)
704 693
705 694 if self.hwm:
706 695 try:
707 696 idx = self.targets.index(engine)
708 697 except ValueError:
709 698 pass # skip load-update for dead engines
710 699 else:
711 700 if self.loads[idx] == self.hwm-1:
712 701 self.update_graph(None)
713 702
714 703 def update_graph(self, dep_id=None, success=True):
715 704 """dep_id just finished. Update our dependency
716 705 graph and submit any jobs that just became runnable.
717 706
718 707 Called with dep_id=None to update entire graph for hwm, but without finishing a task.
719 708 """
720 709 # print ("\n\n***********")
721 710 # pprint (dep_id)
722 711 # pprint (self.graph)
723 712 # pprint (self.queue_map)
724 713 # pprint (self.all_completed)
725 714 # pprint (self.all_failed)
726 715 # print ("\n\n***********\n\n")
727 716 # update any jobs that depended on the dependency
728 717 msg_ids = self.graph.pop(dep_id, [])
729 718
730 719 # recheck *all* jobs if
731 720 # a) we have HWM and an engine just become no longer full
732 721 # or b) dep_id was given as None
733 722
734 723 if dep_id is None or (self.hwm and any(load == self.hwm - 1 for load in self.loads)):
735 724 jobs = self.queue
736 725 using_queue = True
737 726 else:
738 727 using_queue = False
739 728 jobs = deque(sorted( self.queue_map[msg_id] for msg_id in msg_ids ))
740 729
741 730 to_restore = []
742 731 while jobs:
743 732 job = jobs.popleft()
744 733 if job.removed:
745 734 continue
746 735 msg_id = job.msg_id
747 736
748 737 put_it_back = True
749 738
750 739 if job.after.unreachable(self.all_completed, self.all_failed)\
751 740 or job.follow.unreachable(self.all_completed, self.all_failed):
752 741 self.fail_unreachable(msg_id)
753 742 put_it_back = False
754 743
755 744 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
756 745 if self.maybe_run(job):
757 746 put_it_back = False
758 747 self.queue_map.pop(msg_id)
759 748 for mid in job.dependents:
760 749 if mid in self.graph:
761 750 self.graph[mid].remove(msg_id)
762 751
763 752 # abort the loop if we just filled up all of our engines.
764 753 # avoids an O(N) operation in situation of full queue,
765 754 # where graph update is triggered as soon as an engine becomes
766 755 # non-full, and all tasks after the first are checked,
767 756 # even though they can't run.
768 757 if not self.available_engines():
769 758 break
770 759
771 760 if using_queue and put_it_back:
772 761 # popped a job from the queue but it neither ran nor failed,
773 762 # so we need to put it back when we are done
774 763 # make sure to_restore preserves the same ordering
775 764 to_restore.append(job)
776 765
777 766 # put back any tasks we popped but didn't run
778 767 if using_queue:
779 768 self.queue.extendleft(to_restore)
780 769
781 770 #----------------------------------------------------------------------
782 771 # methods to be overridden by subclasses
783 772 #----------------------------------------------------------------------
784 773
785 774 def add_job(self, idx):
786 775 """Called after self.targets[idx] just got the job with header.
787 776 Override with subclasses. The default ordering is simple LRU.
788 777 The default loads are the number of outstanding jobs."""
789 778 self.loads[idx] += 1
790 779 for lis in (self.targets, self.loads):
791 780 lis.append(lis.pop(idx))
792 781
793 782
794 783 def finish_job(self, idx):
795 784 """Called after self.targets[idx] just finished a job.
796 785 Override in subclasses."""
797 786 self.loads[idx] -= 1
798 787
799 788
800 789
801 790 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, reg_addr, config=None,
802 791 logname='root', log_url=None, loglevel=logging.DEBUG,
803 792 identity=b'task', in_thread=False):
804 793
805 794 ZMQStream = zmqstream.ZMQStream
806 795
807 796 if config:
808 797 # unwrap dict back into Config
809 798 config = Config(config)
810 799
811 800 if in_thread:
812 801 # use instance() to get the same Context/Loop as our parent
813 802 ctx = zmq.Context.instance()
814 803 loop = ioloop.IOLoop.instance()
815 804 else:
816 805 # in a process, don't use instance()
817 806 # for safety with multiprocessing
818 807 ctx = zmq.Context()
819 808 loop = ioloop.IOLoop()
820 809 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
821 810 util.set_hwm(ins, 0)
822 811 ins.setsockopt(zmq.IDENTITY, identity + b'_in')
823 812 ins.bind(in_addr)
824 813
825 814 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
826 815 util.set_hwm(outs, 0)
827 816 outs.setsockopt(zmq.IDENTITY, identity + b'_out')
828 817 outs.bind(out_addr)
829 818 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
830 819 util.set_hwm(mons, 0)
831 820 mons.connect(mon_addr)
832 821 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
833 822 nots.setsockopt(zmq.SUBSCRIBE, b'')
834 823 nots.connect(not_addr)
835 824
836 825 querys = ZMQStream(ctx.socket(zmq.DEALER),loop)
837 826 querys.connect(reg_addr)
838 827
839 828 # setup logging.
840 829 if in_thread:
841 830 log = Application.instance().log
842 831 else:
843 832 if log_url:
844 833 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
845 834 else:
846 835 log = local_logger(logname, loglevel)
847 836
848 837 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
849 838 mon_stream=mons, notifier_stream=nots,
850 839 query_stream=querys,
851 840 loop=loop, log=log,
852 841 config=config)
853 842 scheduler.start()
854 843 if not in_thread:
855 844 try:
856 845 loop.start()
857 846 except KeyboardInterrupt:
858 847 scheduler.log.critical("Interrupted, exiting...")
859 848
@@ -1,388 +1,389 b''
1 1 """Some generic utilities for dealing with classes, urls, and serialization."""
2 2
3 3 # Copyright (c) IPython Development Team.
4 4 # Distributed under the terms of the Modified BSD License.
5 5
6 6 import logging
7 7 import os
8 8 import re
9 9 import stat
10 10 import socket
11 11 import sys
12 12 import warnings
13 13 from signal import signal, SIGINT, SIGABRT, SIGTERM
14 14 try:
15 15 from signal import SIGKILL
16 16 except ImportError:
17 17 SIGKILL=None
18 18 from types import FunctionType
19 19
20 20 try:
21 21 import cPickle
22 22 pickle = cPickle
23 23 except:
24 24 cPickle = None
25 25 import pickle
26 26
27 27 import zmq
28 28 from zmq.log import handlers
29 29
30 from IPython.utils.log import get_logger
30 31 from IPython.external.decorator import decorator
31 32
32 33 from IPython.config.application import Application
33 34 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
34 35 from IPython.utils.py3compat import string_types, iteritems, itervalues
35 36 from IPython.kernel.zmq.log import EnginePUBHandler
36 37
37 38
38 39 #-----------------------------------------------------------------------------
39 40 # Classes
40 41 #-----------------------------------------------------------------------------
41 42
42 43 class Namespace(dict):
43 44 """Subclass of dict for attribute access to keys."""
44 45
45 46 def __getattr__(self, key):
46 47 """getattr aliased to getitem"""
47 48 if key in self:
48 49 return self[key]
49 50 else:
50 51 raise NameError(key)
51 52
52 53 def __setattr__(self, key, value):
53 54 """setattr aliased to setitem, with strict"""
54 55 if hasattr(dict, key):
55 56 raise KeyError("Cannot override dict keys %r"%key)
56 57 self[key] = value
57 58
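A short usage sketch:

    ns = Namespace(a=1)
    assert ns.a == 1    # attribute access falls through to the dict
    ns.b = 2            # item assignment via attribute syntax
    # ns.keys = 3 would raise KeyError: names shadowing dict attributes are rejected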
58 59
59 60 class ReverseDict(dict):
60 61 """simple double-keyed subset of dict methods."""
61 62
62 63 def __init__(self, *args, **kwargs):
63 64 dict.__init__(self, *args, **kwargs)
64 65 self._reverse = dict()
65 66 for key, value in iteritems(self):
66 67 self._reverse[value] = key
67 68
68 69 def __getitem__(self, key):
69 70 try:
70 71 return dict.__getitem__(self, key)
71 72 except KeyError:
72 73 return self._reverse[key]
73 74
74 75 def __setitem__(self, key, value):
75 76 if key in self._reverse:
76 77 raise KeyError("Can't have key %r on both sides!"%key)
77 78 dict.__setitem__(self, key, value)
78 79 self._reverse[value] = key
79 80
80 81 def pop(self, key):
81 82 value = dict.pop(self, key)
82 83 self._reverse.pop(value)
83 84 return value
84 85
85 86 def get(self, key, default=None):
86 87 try:
87 88 return self[key]
88 89 except KeyError:
89 90 return default
90 91
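A short usage sketch of the two-way lookup:

    rd = ReverseDict()
    rd['engine0'] = 'abc123'
    assert rd['engine0'] == 'abc123'
    assert rd['abc123'] == 'engine0'    # misses fall back to the reverse map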
91 92 #-----------------------------------------------------------------------------
92 93 # Functions
93 94 #-----------------------------------------------------------------------------
94 95
95 96 @decorator
96 97 def log_errors(f, self, *args, **kwargs):
97 98 """decorator to log unhandled exceptions raised in a method.
98 99
99 100 For use wrapping on_recv callbacks, so that exceptions
100 101 do not cause the stream to be closed.
101 102 """
102 103 try:
103 104 return f(self, *args, **kwargs)
104 105 except Exception:
105 106 self.log.error("Uncaught exception in %r" % f, exc_info=True)
106 107
107 108
108 109 def is_url(url):
109 110 """boolean check for whether a string is a zmq url"""
110 111 if '://' not in url:
111 112 return False
112 113 proto, addr = url.split('://', 1)
113 114 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
114 115 return False
115 116 return True
116 117
117 118 def validate_url(url):
118 119 """validate a url for zeromq"""
119 120 if not isinstance(url, string_types):
120 121 raise TypeError("url must be a string, not %r"%type(url))
121 122 url = url.lower()
122 123
123 124 proto_addr = url.split('://')
124 125 assert len(proto_addr) == 2, 'Invalid url: %r'%url
125 126 proto, addr = proto_addr
126 127 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
127 128
128 129 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
129 130 # author: Remi Sabourin
130 131 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
131 132
132 133 if proto == 'tcp':
133 134 lis = addr.split(':')
134 135 assert len(lis) == 2, 'Invalid url: %r'%url
135 136 addr,s_port = lis
136 137 try:
137 138 port = int(s_port)
138 139 except ValueError:
139 140 raise AssertionError("Invalid port %r in url: %r"%(port, url))
140 141
141 142 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
142 143
143 144 else:
144 145 # only validate tcp urls currently
145 146 pass
146 147
147 148 return True
148 149
149 150
150 151 def validate_url_container(container):
151 152 """validate a potentially nested collection of urls."""
152 153 if isinstance(container, string_types):
153 154 url = container
154 155 return validate_url(url)
155 156 elif isinstance(container, dict):
156 157 container = itervalues(container)
157 158
158 159 for element in container:
159 160 validate_url_container(element)
160 161
161 162
162 163 def split_url(url):
163 164 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
164 165 proto_addr = url.split('://')
165 166 assert len(proto_addr) == 2, 'Invalid url: %r'%url
166 167 proto, addr = proto_addr
167 168 lis = addr.split(':')
168 169 assert len(lis) == 2, 'Invalid url: %r'%url
169 170 addr,s_port = lis
170 171 return proto,addr,s_port
171 172
172 173
173 174 def disambiguate_ip_address(ip, location=None):
174 175 """turn multi-ip interfaces '0.0.0.0' and '*' into a connectable address
175 176
176 177 Explicit IP addresses are returned unmodified.
177 178
178 179 Parameters
179 180 ----------
180 181
181 182 ip : IP address
182 183 An IP address, or the special values 0.0.0.0, or *
183 184 location: IP address, optional
184 185 A public IP of the target machine.
185 186 If location is an IP of the current machine,
186 187 localhost will be returned,
187 188 otherwise location will be returned.
188 189 """
189 190 if ip in {'0.0.0.0', '*'}:
190 191 if not location:
191 192 # unspecified location, localhost is the only choice
192 193 ip = localhost()
193 194 elif is_public_ip(location):
194 195 # location is a public IP on this machine, use localhost
195 196 ip = localhost()
196 197 elif not public_ips():
197 198 # this machine's public IPs cannot be determined,
198 199 # assume `location` is not this machine
199 200 warnings.warn("IPython could not determine public IPs", RuntimeWarning)
200 201 ip = location
201 202 else:
202 203 # location is not this machine, do not use loopback
203 204 ip = location
204 205 return ip
205 206
206 207
207 208 def disambiguate_url(url, location=None):
208 209 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
209 210 ones, based on the location (default interpretation is localhost).
210 211
211 212 This is for zeromq urls, such as ``tcp://*:10101``.
212 213 """
213 214 try:
214 215 proto,ip,port = split_url(url)
215 216 except AssertionError:
216 217 # probably not tcp url; could be ipc, etc.
217 218 return url
218 219
219 220 ip = disambiguate_ip_address(ip,location)
220 221
221 222 return "%s://%s:%s"%(proto,ip,port)
222 223
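A short sketch: wildcard bind addresses become connectable ones, defaulting to the local loopback when no location is given:

    url = disambiguate_url('tcp://*:10101')
    # e.g. 'tcp://127.0.0.1:10101' on a typical machine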
223 224
224 225 #--------------------------------------------------------------------------
225 226 # helpers for implementing old MEC API via view.apply
226 227 #--------------------------------------------------------------------------
227 228
228 229 def interactive(f):
229 230 """decorator for making functions appear as interactively defined.
230 231 This results in the function being linked to the user_ns as globals()
231 232 instead of the module globals().
232 233 """
233 234
234 235 # build new FunctionType, so it can have the right globals
235 236 # interactive functions never have closures, that's kind of the point
236 237 if isinstance(f, FunctionType):
237 238 mainmod = __import__('__main__')
238 239 f = FunctionType(f.__code__, mainmod.__dict__,
239 240 f.__name__, f.__defaults__,
240 241 )
241 242 # associate with __main__ for uncanning
242 243 f.__module__ = '__main__'
243 244 return f
244 245
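A minimal sketch: the rebuilt function resolves its globals in __main__, so names defined in the engine's user namespace are visible to it:

    @interactive
    def get_x():
        return x    # 'x' is looked up in __main__, not in this module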
245 246 @interactive
246 247 def _push(**ns):
247 248 """helper method for implementing `client.push` via `client.apply`"""
248 249 user_ns = globals()
249 250 tmp = '_IP_PUSH_TMP_'
250 251 while tmp in user_ns:
251 252 tmp = tmp + '_'
252 253 try:
253 254 for name, value in ns.items():
254 255 user_ns[tmp] = value
255 256 exec("%s = %s" % (name, tmp), user_ns)
256 257 finally:
257 258 user_ns.pop(tmp, None)
258 259
259 260 @interactive
260 261 def _pull(keys):
261 262 """helper method for implementing `client.pull` via `client.apply`"""
262 263 if isinstance(keys, (list,tuple, set)):
263 264 return [eval(key, globals()) for key in keys]
264 265 else:
265 266 return eval(keys, globals())
266 267
267 268 @interactive
268 269 def _execute(code):
269 270 """helper method for implementing `client.execute` via `client.apply`"""
270 271 exec(code, globals())
271 272
272 273 #--------------------------------------------------------------------------
273 274 # extra process management utilities
274 275 #--------------------------------------------------------------------------
275 276
276 277 _random_ports = set()
277 278
278 279 def select_random_ports(n):
279 280 """Selects and return n random ports that are available."""
280 281 ports = []
281 282 for i in range(n):
282 283 sock = socket.socket()
283 284 sock.bind(('', 0))
284 285 while sock.getsockname()[1] in _random_ports:
285 286 sock.close()
286 287 sock = socket.socket()
287 288 sock.bind(('', 0))
288 289 ports.append(sock)
289 290 for i, sock in enumerate(ports):
290 291 port = sock.getsockname()[1]
291 292 sock.close()
292 293 ports[i] = port
293 294 _random_ports.add(port)
294 295 return ports
295 296
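A short usage sketch; the returned ports are distinct and were bindable at selection time (though nothing stops another process from taking one before use):

    ports = select_random_ports(3)
    assert len(set(ports)) == 3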
296 297 def signal_children(children):
297 298 """Relay interupt/term signals to children, for more solid process cleanup."""
298 299 def terminate_children(sig, frame):
299 log = Application.instance().log
300 log = get_logger()
300 301 log.critical("Got signal %i, terminating children..."%sig)
301 302 for child in children:
302 303 child.terminate()
303 304
304 305 sys.exit(sig != SIGINT)
305 306 # sys.exit(sig)
306 307 for sig in (SIGINT, SIGABRT, SIGTERM):
307 308 signal(sig, terminate_children)
308 309
309 310 def generate_exec_key(keyfile):
310 311 import uuid
311 312 newkey = str(uuid.uuid4())
312 313 with open(keyfile, 'w') as f:
313 314 # f.write('ipython-key ')
314 315 f.write(newkey+'\n')
315 316 # set user-only RW permissions (0600)
316 317 # this will have no effect on Windows
317 318 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
318 319
319 320
320 321 def integer_loglevel(loglevel):
321 322 try:
322 323 loglevel = int(loglevel)
323 324 except ValueError:
324 325 if isinstance(loglevel, str):
325 326 loglevel = getattr(logging, loglevel)
326 327 return loglevel
327 328
328 329 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
329 330 logger = logging.getLogger(logname)
330 331 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
331 332 # don't add a second PUBHandler
332 333 return
333 334 loglevel = integer_loglevel(loglevel)
334 335 lsock = context.socket(zmq.PUB)
335 336 lsock.connect(iface)
336 337 handler = handlers.PUBHandler(lsock)
337 338 handler.setLevel(loglevel)
338 339 handler.root_topic = root
339 340 logger.addHandler(handler)
340 341 logger.setLevel(loglevel)
341 342
342 343 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
343 344 logger = logging.getLogger()
344 345 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
345 346 # don't add a second PUBHandler
346 347 return
347 348 loglevel = integer_loglevel(loglevel)
348 349 lsock = context.socket(zmq.PUB)
349 350 lsock.connect(iface)
350 351 handler = EnginePUBHandler(engine, lsock)
351 352 handler.setLevel(loglevel)
352 353 logger.addHandler(handler)
353 354 logger.setLevel(loglevel)
354 355 return logger
355 356
356 357 def local_logger(logname, loglevel=logging.DEBUG):
357 358 loglevel = integer_loglevel(loglevel)
358 359 logger = logging.getLogger(logname)
359 360 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
360 361 # don't add a second StreamHandler
361 362 return
362 363 handler = logging.StreamHandler()
363 364 handler.setLevel(loglevel)
364 365 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
365 366 datefmt="%Y-%m-%d %H:%M:%S")
366 367 handler.setFormatter(formatter)
367 368
368 369 logger.addHandler(handler)
369 370 logger.setLevel(loglevel)
370 371 return logger
371 372
372 373 def set_hwm(sock, hwm=0):
373 374 """set zmq High Water Mark on a socket
374 375
375 376 in a way that always works for various pyzmq / libzmq versions.
376 377 """
377 378 import zmq
378 379
379 380 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
380 381 opt = getattr(zmq, key, None)
381 382 if opt is None:
382 383 continue
383 384 try:
384 385 sock.setsockopt(opt, hwm)
385 386 except zmq.ZMQError:
386 387 pass
387 388
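A minimal usage sketch; the loop above simply sets whichever of HWM, SNDHWM, and RCVHWM the installed zmq exposes:

    import zmq

    sock = zmq.Context.instance().socket(zmq.PUSH)
    set_hwm(sock, 0)    # 0 means unlimited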
388 389
@@ -1,433 +1,420 b''
1 1 # encoding: utf-8
2 2 """Pickle related utilities. Perhaps this should be called 'can'."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 7 import copy
8 8 import logging
9 9 import sys
10 10 from types import FunctionType
11 11
12 12 try:
13 13 import cPickle as pickle
14 14 except ImportError:
15 15 import pickle
16 16
17 17 from . import codeutil # This registers a hook when it's imported
18 18 from . import py3compat
19 19 from .importstring import import_item
20 20 from .py3compat import string_types, iteritems
21 21
22 22 from IPython.config import Application
23 from IPython.utils.log import get_logger
23 24
24 25 if py3compat.PY3:
25 26 buffer = memoryview
26 27 class_type = type
27 28 else:
28 29 from types import ClassType
29 30 class_type = (type, ClassType)
30 31
31 32 def _get_cell_type(a=None):
32 33 """the type of a closure cell doesn't seem to be importable,
33 34 so just create one
34 35 """
35 36 def inner():
36 37 return a
37 38 return type(py3compat.get_closure(inner)[0])
38 39
39 40 cell_type = _get_cell_type()
40 41
41 42 #-------------------------------------------------------------------------------
42 43 # Functions
43 44 #-------------------------------------------------------------------------------
44 45
45 46
46 47 def use_dill():
47 48 """use dill to expand serialization support
48 49
49 50 adds support for object methods and closures to serialization.
50 51 """
51 52 # import dill causes most of the magic
52 53 import dill
53 54
54 55 # dill doesn't work with cPickle,
55 56 # tell the two relevant modules to use plain pickle
56 57
57 58 global pickle
58 59 pickle = dill
59 60
60 61 try:
61 62 from IPython.kernel.zmq import serialize
62 63 except ImportError:
63 64 pass
64 65 else:
65 66 serialize.pickle = dill
66 67
67 68 # disable special function handling, let dill take care of it
68 69 can_map.pop(FunctionType, None)
69 70
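A minimal opt-in sketch (dill must be installed; call this before shipping closures or bound methods to engines):

    from IPython.utils.pickleutil import use_dill
    use_dill()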
70 71 def use_cloudpickle():
71 72 """use cloudpickle to expand serialization support
72 73
73 74 adds support for object methods and closures to serialization.
74 75 """
75 76 from cloud.serialization import cloudpickle
76 77
77 78 global pickle
78 79 pickle = cloudpickle
79 80
80 81 try:
81 82 from IPython.kernel.zmq import serialize
82 83 except ImportError:
83 84 pass
84 85 else:
85 86 serialize.pickle = cloudpickle
86 87
87 88 # disable special function handling, let cloudpickle take care of it
88 89 can_map.pop(FunctionType, None)
89 90
90 91
91 92 #-------------------------------------------------------------------------------
92 93 # Classes
93 94 #-------------------------------------------------------------------------------
94 95
95 96
96 97 class CannedObject(object):
97 98 def __init__(self, obj, keys=[], hook=None):
98 99 """can an object for safe pickling
99 100
100 101 Parameters
101 102 ----------
102 103
103 104 obj:
104 105 The object to be canned
105 106 keys: list (optional)
106 107 list of attribute names that will be explicitly canned / uncanned
107 108 hook: callable (optional)
108 109 An optional extra callable,
109 110 which can do additional processing of the uncanned object.
110 111
111 112 large data may be offloaded into the buffers list,
112 113 used for zero-copy transfers.
113 114 """
114 115 self.keys = keys
115 116 self.obj = copy.copy(obj)
116 117 self.hook = can(hook)
117 118 for key in keys:
118 119 setattr(self.obj, key, can(getattr(obj, key)))
119 120
120 121 self.buffers = []
121 122
122 123 def get_object(self, g=None):
123 124 if g is None:
124 125 g = {}
125 126 obj = self.obj
126 127 for key in self.keys:
127 128 setattr(obj, key, uncan(getattr(obj, key), g))
128 129
129 130 if self.hook:
130 131 self.hook = uncan(self.hook, g)
131 132 self.hook(obj, g)
132 133 return self.obj
133 134
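A short sketch using the module-level can/uncan helpers (defined later in this file), assuming it runs in an interactive/__main__ context:

    def adder(x, y=1):
        return x + y

    canned = can(adder)             # -> a picklable CannedFunction
    restored = uncan(canned, {})    # rebuild the function from the can
    assert restored(2) == 3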
134 135
135 136 class Reference(CannedObject):
136 137 """object for wrapping a remote reference by name."""
137 138 def __init__(self, name):
138 139 if not isinstance(name, string_types):
139 140 raise TypeError("illegal name: %r"%name)
140 141 self.name = name
141 142 self.buffers = []
142 143
143 144 def __repr__(self):
144 145 return "<Reference: %r>"%self.name
145 146
146 147 def get_object(self, g=None):
147 148 if g is None:
148 149 g = {}
149 150
150 151 return eval(self.name, g)
151 152
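A Reference carries only a name; resolution happens at uncan time in whatever namespace is supplied, which is how remote engines look names up in their own globals. A local sketch:

    ref = Reference('data')
    ns = {'data': [1, 2, 3]}
    assert ref.get_object(g=ns) == [1, 2, 3]   # resolved by eval in ns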
152 153
153 154 class CannedCell(CannedObject):
154 155 """Can a closure cell"""
155 156 def __init__(self, cell):
156 157 self.cell_contents = can(cell.cell_contents)
157 158
158 159 def get_object(self, g=None):
159 160 cell_contents = uncan(self.cell_contents, g)
160 161 def inner():
161 162 return cell_contents
162 163 return py3compat.get_closure(inner)[0]
163 164
164 165
165 166 class CannedFunction(CannedObject):
166 167
167 168 def __init__(self, f):
168 169 self._check_type(f)
169 170 self.code = f.__code__
170 171 if f.__defaults__:
171 172 self.defaults = [ can(fd) for fd in f.__defaults__ ]
172 173 else:
173 174 self.defaults = None
174 175
175 176 closure = py3compat.get_closure(f)
176 177 if closure:
177 178 self.closure = tuple( can(cell) for cell in closure )
178 179 else:
179 180 self.closure = None
180 181
181 182 self.module = f.__module__ or '__main__'
182 183 self.__name__ = f.__name__
183 184 self.buffers = []
184 185
185 186 def _check_type(self, obj):
186 187 assert isinstance(obj, FunctionType), "Not a function type"
187 188
188 189 def get_object(self, g=None):
189 190 # try to load function back into its module:
190 191 if not self.module.startswith('__'):
191 192 __import__(self.module)
192 193 g = sys.modules[self.module].__dict__
193 194
194 195 if g is None:
195 196 g = {}
196 197 if self.defaults:
197 198 defaults = tuple(uncan(cfd, g) for cfd in self.defaults)
198 199 else:
199 200 defaults = None
200 201 if self.closure:
201 202 closure = tuple(uncan(cell, g) for cell in self.closure)
202 203 else:
203 204 closure = None
204 205 newFunc = FunctionType(self.code, g, self.__name__, defaults, closure)
205 206 return newFunc
206 207
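Code object, defaults, and closure cells are captured separately, so a function defined interactively survives the round trip. A sketch, assuming it runs in __main__:

    def make_adder(n):
        def add(x, offset=10):
            return x + n + offset              # n lives in a closure cell
        return add

    canned = can(make_adder(5))                # -> CannedFunction
    restored = uncan(canned, g={})
    assert restored(1) == 16                   # 1 + 5 + 10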
207 208 class CannedClass(CannedObject):
208 209
209 210 def __init__(self, cls):
210 211 self._check_type(cls)
211 212 self.name = cls.__name__
212 213 self.old_style = not isinstance(cls, type)
213 214 self._canned_dict = {}
214 215 for k,v in cls.__dict__.items():
215 216 if k not in ('__weakref__', '__dict__'):
216 217 self._canned_dict[k] = can(v)
217 218 if self.old_style:
218 219 mro = []
219 220 else:
220 221 mro = cls.mro()
221 222
222 223 self.parents = [ can(c) for c in mro[1:] ]
223 224 self.buffers = []
224 225
225 226 def _check_type(self, obj):
226 227 assert isinstance(obj, class_type), "Not a class type"
227 228
228 229 def get_object(self, g=None):
229 230 parents = tuple(uncan(p, g) for p in self.parents)
230 231 return type(self.name, parents, uncan_dict(self._canned_dict, g=g))
231 232
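Only classes from __main__ are canned by value (see can_class below); anything importable travels by name instead. A sketch, assuming it runs in __main__:

    class Point(object):
        def __init__(self, x):
            self.x = x

    canned = can_class(Point)                  # -> CannedClass
    Restored = canned.get_object(g=globals())  # rebuilt with type()
    assert Restored(3).x == 3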
232 233 class CannedArray(CannedObject):
233 234 def __init__(self, obj):
234 235 from numpy import ascontiguousarray
235 236 self.shape = obj.shape
236 237 self.dtype = obj.dtype.descr if obj.dtype.fields else obj.dtype.str
237 238 self.pickled = False
238 239 if sum(obj.shape) == 0:
239 240 self.pickled = True
240 241 elif obj.dtype == 'O':
241 242 # can't handle object dtype with buffer approach
242 243 self.pickled = True
243 244 elif obj.dtype.fields and any(dt == 'O' for dt,sz in obj.dtype.fields.values()):
244 245 self.pickled = True
245 246 if self.pickled:
246 247 # just pickle it
247 248 self.buffers = [pickle.dumps(obj, -1)]
248 249 else:
249 250 # ensure contiguous
250 251 obj = ascontiguousarray(obj, dtype=None)
251 252 self.buffers = [buffer(obj)]
252 253
253 254 def get_object(self, g=None):
254 255 from numpy import frombuffer
255 256 data = self.buffers[0]
256 257 if self.pickled:
257 258 # no shape, we just pickled it
258 259 return pickle.loads(data)
259 260 else:
260 261 return frombuffer(data, dtype=self.dtype).reshape(self.shape)
261 262
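Contiguous numeric arrays travel as dtype/shape metadata plus one raw buffer, so no pickling (and no copy) is needed; object dtypes and empty shapes fall back to a pickle. A sketch, assuming numpy is available:

    import numpy as np

    a = np.arange(6, dtype='float64').reshape(2, 3)
    canned = CannedArray(a)
    assert not canned.pickled                  # metadata + zero-copy buffer
    b = canned.get_object()
    assert (a == b).all()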
262 263
263 264 class CannedBytes(CannedObject):
264 265 wrap = bytes
265 266 def __init__(self, obj):
266 267 self.buffers = [obj]
267 268
268 269 def get_object(self, g=None):
269 270 data = self.buffers[0]
270 271 return self.wrap(data)
271 272
272 273 class CannedBuffer(CannedBytes):
273 274 wrap = buffer
274 275
275 276 #-------------------------------------------------------------------------------
276 277 # Functions
277 278 #-------------------------------------------------------------------------------
278 279
279 def _logger():
280 """get the logger for the current Application
281
282 the root logger will be used if no Application is running
283 """
284 if Application.initialized():
285 logger = Application.instance().log
286 else:
287 logger = logging.getLogger()
288 if not logger.handlers:
289 logging.basicConfig()
290
291 return logger
292
293 280 def _import_mapping(mapping, original=None):
294 281 """import any string-keys in a type mapping
295 282
296 283 """
297 log = _logger()
284 log = get_logger()
298 285 log.debug("Importing canning map")
299 286 for key,value in list(mapping.items()):
300 287 if isinstance(key, string_types):
301 288 try:
302 289 cls = import_item(key)
303 290 except Exception:
304 291 if original and key not in original:
305 292 # only message on user-added classes
306 293 log.error("canning class not importable: %r", key, exc_info=True)
307 294 mapping.pop(key)
308 295 else:
309 296 mapping[cls] = mapping.pop(key)
310 297
311 298 def istype(obj, check):
312 299 """like isinstance(obj, check), but strict
313 300
314 301 This won't catch subclasses.
315 302 """
316 303 if isinstance(check, tuple):
317 304 for cls in check:
318 305 if type(obj) is cls:
319 306 return True
320 307 return False
321 308 else:
322 309 return type(obj) is check
323 310
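The strictness matters for dispatch: a dict subclass, say, should not be flattened to a plain dict by can_dict. For example:

    class MyDict(dict):
        pass

    assert istype({}, dict)
    assert not istype(MyDict(), dict)          # isinstance() would say True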
324 311 def can(obj):
325 312 """prepare an object for pickling"""
326 313
327 314 import_needed = False
328 315
329 316 for cls,canner in iteritems(can_map):
330 317 if isinstance(cls, string_types):
331 318 import_needed = True
332 319 break
333 320 elif istype(obj, cls):
334 321 return canner(obj)
335 322
336 323 if import_needed:
337 324 # perform can_map imports, then try again
338 325 # this will usually only happen once
339 326 _import_mapping(can_map, _original_can_map)
340 327 return can(obj)
341 328
342 329 return obj
343 330
344 331 def can_class(obj):
345 332 if isinstance(obj, class_type) and obj.__module__ == '__main__':
346 333 return CannedClass(obj)
347 334 else:
348 335 return obj
349 336
350 337 def can_dict(obj):
351 338 """can the *values* of a dict"""
352 339 if istype(obj, dict):
353 340 newobj = {}
354 341 for k, v in iteritems(obj):
355 342 newobj[k] = can(v)
356 343 return newobj
357 344 else:
358 345 return obj
359 346
360 347 sequence_types = (list, tuple, set)
361 348
362 349 def can_sequence(obj):
363 350 """can the elements of a sequence"""
364 351 if istype(obj, sequence_types):
365 352 t = type(obj)
366 353 return t([can(i) for i in obj])
367 354 else:
368 355 return obj
369 356
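Both container helpers preserve the exact input type and can only the contents; dict keys are left alone. For example:

    def f():
        return 1

    d = can_dict({'callback': f})              # value canned, key untouched
    s = can_sequence((1, f))                   # tuple in, tuple out
    assert isinstance(d['callback'], CannedFunction)
    assert type(s) is tuple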
370 357 def uncan(obj, g=None):
371 358 """invert canning"""
372 359
373 360 import_needed = False
374 361 for cls,uncanner in iteritems(uncan_map):
375 362 if isinstance(cls, string_types):
376 363 import_needed = True
377 364 break
378 365 elif isinstance(obj, cls):
379 366 return uncanner(obj, g)
380 367
381 368 if import_needed:
382 369 # perform uncan_map imports, then try again
383 370 # this will usually only happen once
384 371 _import_mapping(uncan_map, _original_uncan_map)
385 372 return uncan(obj, g)
386 373
387 374 return obj
388 375
389 376 def uncan_dict(obj, g=None):
390 377 if istype(obj, dict):
391 378 newobj = {}
392 379 for k, v in iteritems(obj):
393 380 newobj[k] = uncan(v,g)
394 381 return newobj
395 382 else:
396 383 return obj
397 384
398 385 def uncan_sequence(obj, g=None):
399 386 if istype(obj, sequence_types):
400 387 t = type(obj)
401 388 return t([uncan(i,g) for i in obj])
402 389 else:
403 390 return obj
404 391
405 392 def _uncan_dependent_hook(dep, g=None):
406 393 dep.check_dependency()
407 394
408 395 def can_dependent(obj):
409 396 return CannedObject(obj, keys=('f', 'df'), hook=_uncan_dependent_hook)
410 397
411 398 #-------------------------------------------------------------------------------
412 399 # API dictionaries
413 400 #-------------------------------------------------------------------------------
414 401
415 402 # These dicts can be extended for custom serialization of new objects
416 403
417 404 can_map = {
418 405 'IPython.parallel.dependent' : can_dependent,
419 406 'numpy.ndarray' : CannedArray,
420 407 FunctionType : CannedFunction,
421 408 bytes : CannedBytes,
422 409 buffer : CannedBuffer,
423 410 cell_type : CannedCell,
424 411 class_type : can_class,
425 412 }
426 413
427 414 uncan_map = {
428 415 CannedObject : lambda obj, g: obj.get_object(g),
429 416 }
430 417
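Registering a new type means adding an entry to can_map; keys may be classes or import strings, which _import_mapping resolves lazily on first use. Since uncan dispatches with isinstance, any CannedObject subclass is already covered by uncan_map. A hypothetical sketch (Secret is not part of this module):

    class Secret(object):
        def __init__(self, value):
            self.value = value

    class CannedSecret(CannedObject):
        def __init__(self, obj):
            self.value = can(obj.value)
            self.buffers = []
        def get_object(self, g=None):
            return Secret(uncan(self.value, g))

    can_map[Secret] = CannedSecret             # or an import-string key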
431 418 # for use in _import_mapping:
432 419 _original_can_map = can_map.copy()
433 420 _original_uncan_map = uncan_map.copy()