##// END OF EJS Templates
Merge pull request #4276 from minrk/config-extend...
Thomas Kluyver -
r12934:61716e73 merge
parent child Browse files
Show More
@@ -0,0 +1,15 b''
1 Extending Configurable Containers
2 ---------------------------------
3
4 Some configurable traits are containers (list, dict, set).
5 Config objects now support calling ``extend``, ``update``, ``insert``, etc.
6 on traits in config files, which will ultimately result in calling
7 those methods on the original object.
8
9 The effect is that you can now add to containers without having to copy/paste
10 the initial value::
11
12 c = get_config()
13 c.InlineBackend.rc.update({ 'figure.figsize' : (6, 4) })
14
15
@@ -1,382 +1,387 b''
1 1 # encoding: utf-8
2 2 """
3 3 A base class for objects that are configurable.
4 4
5 5 Inheritance diagram:
6 6
7 7 .. inheritance-diagram:: IPython.config.configurable
8 8 :parts: 3
9 9
10 10 Authors:
11 11
12 12 * Brian Granger
13 13 * Fernando Perez
14 14 * Min RK
15 15 """
16 16
17 17 #-----------------------------------------------------------------------------
18 18 # Copyright (C) 2008-2011 The IPython Development Team
19 19 #
20 20 # Distributed under the terms of the BSD License. The full license is in
21 21 # the file COPYING, distributed as part of this software.
22 22 #-----------------------------------------------------------------------------
23 23
24 24 #-----------------------------------------------------------------------------
25 25 # Imports
26 26 #-----------------------------------------------------------------------------
27 27
28 28 import datetime
29 29 from copy import deepcopy
30 30
31 from loader import Config
31 from .loader import Config, LazyConfigValue
32 32 from IPython.utils.traitlets import HasTraits, Instance
33 33 from IPython.utils.text import indent, wrap_paragraphs
34 34
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Helper classes for Configurables
38 38 #-----------------------------------------------------------------------------
39 39
40 40
class ConfigurableError(Exception):
    """Base class for errors raised by Configurable objects."""
    pass
43 43
44 44
class MultipleInstanceError(ConfigurableError):
    """Raised when incompatible subclass singletons would coexist."""
    pass
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Configurable implementation
50 50 #-----------------------------------------------------------------------------
51 51
class Configurable(HasTraits):
    """Base class for objects whose ``config=True`` traits can be set from a Config."""

    # The Config this object was configured from; shared by reference, not copied.
    config = Instance(Config, (), {})
    # Optional parent Configurable; supplies config when none is passed explicitly.
    parent = Instance('IPython.config.configurable.Configurable')
    # Timestamp of instantiation, set at the end of __init__.
    created = None

    def __init__(self, **kwargs):
        """Create a configurable given a config.

        Parameters
        ----------
        config : Config
            If this is empty, default values are used. If config is a
            :class:`Config` instance, it will be used to configure the
            instance.
        parent : Configurable instance, optional
            The parent Configurable instance of this object.

        Notes
        -----
        Subclasses of Configurable must call the :meth:`__init__` method of
        :class:`Configurable` *before* doing anything else and using
        :func:`super`::

            class MyConfigurable(Configurable):
                def __init__(self, config=None):
                    super(MyConfigurable, self).__init__(config=config)
                    # Then any other code you need to finish initialization.

        This ensures that instances will be configured properly.
        """
        parent = kwargs.pop('parent', None)
        if parent is not None:
            # config is implied from parent
            if kwargs.get('config', None) is None:
                kwargs['config'] = parent.config
            self.parent = parent

        config = kwargs.pop('config', None)
        if config is not None:
            # We used to deepcopy, but for now we are trying to just save
            # by reference. This *could* have side effects as all components
            # will share config. In fact, I did find such a side effect in
            # _config_changed below. If a config attribute value was a mutable type
            # all instances of a component were getting the same copy, effectively
            # making that a class attribute.
            # self.config = deepcopy(config)
            self.config = config
        # This should go second so individual keyword arguments override
        # the values in config.
        super(Configurable, self).__init__(**kwargs)
        self.created = datetime.datetime.now()

    #-------------------------------------------------------------------------
    # Static trait notifications
    #-------------------------------------------------------------------------

    @classmethod
    def section_names(cls):
        """return section names as a list"""
        # Walk the MRO base-first so subclass sections are applied last
        # (highest priority) by _find_my_config / _load_config.
        return [c.__name__ for c in reversed(cls.__mro__) if
                issubclass(c, Configurable) and issubclass(cls, c)
        ]

    def _find_my_config(self, cfg):
        """extract my config from a global Config object

        will construct a Config object of only the config values that apply to me
        based on my mro(), as well as those of my parent(s) if they exist.

        If I am Bar and my parent is Foo, and their parent is Tim,
        this will return merge following config sections, in this order::

            [Bar, Foo.Bar, Tim.Foo.Bar]

        With the last item being the highest priority.
        """
        cfgs = [cfg]
        if self.parent:
            # Parent-scoped (nested) sections are appended after the flat
            # ones, so they win on merge.
            cfgs.append(self.parent._find_my_config(cfg))
        my_config = Config()
        for c in cfgs:
            for sname in self.section_names():
                # Don't do a blind getattr as that would cause the config to
                # dynamically create the section with name Class.__name__.
                if c._has_section(sname):
                    my_config.merge(c[sname])
        return my_config

    def _load_config(self, cfg, section_names=None, traits=None):
        """load traits from a Config object"""

        if traits is None:
            traits = self.traits(config=True)
        if section_names is None:
            section_names = self.section_names()

        my_config = self._find_my_config(cfg)
        for name, config_value in my_config.iteritems():
            if name in traits:
                if isinstance(config_value, LazyConfigValue):
                    # LazyConfigValue is a wrapper for using append / update
                    # on containers without having to copy the initial value;
                    # resolve it against the trait's current value.
                    initial = getattr(self, name)
                    config_value = config_value.get_value(initial)
                # We have to do a deepcopy here if we don't deepcopy the entire
                # config object. If we don't, a mutable config_value will be
                # shared by all instances, effectively making it a class attribute.
                setattr(self, name, deepcopy(config_value))

    def _config_changed(self, name, old, new):
        """Update all the class traits having ``config=True`` as metadata.

        For any class trait with a ``config`` metadata attribute that is
        ``True``, we update the trait with the value of the corresponding
        config entry.
        """
        # Get all traits with a config metadata entry that is True
        traits = self.traits(config=True)

        # We auto-load config section for this class as well as any parent
        # classes that are Configurable subclasses. This starts with Configurable
        # and works down the mro loading the config for each section.
        section_names = self.section_names()
        self._load_config(new, traits=traits, section_names=section_names)

    def update_config(self, config):
        """Fire the traits events when the config is updated."""
        # Save a copy of the current config.
        newconfig = deepcopy(self.config)
        # Merge the new config into the current one.
        newconfig.merge(config)
        # Save the combined config as self.config, which triggers the traits
        # events.
        self.config = newconfig

    @classmethod
    def class_get_help(cls, inst=None):
        """Get the help string for this class in ReST format.

        If `inst` is given, its current trait values will be used in place of
        class defaults.
        """
        assert inst is None or isinstance(inst, cls)
        final_help = []
        final_help.append(u'%s options' % cls.__name__)
        final_help.append(len(final_help[0])*u'-')
        for k, v in sorted(cls.class_traits(config=True).iteritems()):
            help = cls.class_get_trait_help(v, inst)
            final_help.append(help)
        return '\n'.join(final_help)

    @classmethod
    def class_get_trait_help(cls, trait, inst=None):
        """Get the help string for a single trait.

        If `inst` is given, its current trait values will be used in place of
        the class default.
        """
        assert inst is None or isinstance(inst, cls)
        lines = []
        header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__)
        lines.append(header)
        if inst is not None:
            lines.append(indent('Current: %r' % getattr(inst, trait.name), 4))
        else:
            try:
                dvr = repr(trait.get_default_value())
            except Exception:
                dvr = None  # ignore defaults we can't construct
            if dvr is not None:
                if len(dvr) > 64:
                    # truncate long reprs to keep help readable
                    dvr = dvr[:61]+'...'
                lines.append(indent('Default: %s' % dvr, 4))
            if 'Enum' in trait.__class__.__name__:
                # include Enum choices
                lines.append(indent('Choices: %r' % (trait.values,)))

        help = trait.get_metadata('help')
        if help is not None:
            help = '\n'.join(wrap_paragraphs(help, 76))
            lines.append(indent(help, 4))
        return '\n'.join(lines)

    @classmethod
    def class_print_help(cls, inst=None):
        """Get the help string for a single trait and print it."""
        print cls.class_get_help(inst)

    @classmethod
    def class_config_section(cls):
        """Get the config class config section"""
        def c(s):
            """return a commented, wrapped block."""
            s = '\n\n'.join(wrap_paragraphs(s, 78))

            return '# ' + s.replace('\n', '\n# ')

        # section header
        breaker = '#' + '-'*78
        s = "# %s configuration" % cls.__name__
        lines = [breaker, s, breaker, '']
        # get the description trait
        desc = cls.class_traits().get('description')
        if desc:
            desc = desc.default_value
        else:
            # no description trait, use __doc__
            desc = getattr(cls, '__doc__', '')
        if desc:
            lines.append(c(desc))
            lines.append('')

        parents = []
        for parent in cls.mro():
            # only include parents that are not base classes
            # and are not the class itself
            # and have some configurable traits to inherit
            if parent is not cls and issubclass(parent, Configurable) and \
                    parent.class_traits(config=True):
                parents.append(parent)

        if parents:
            pstr = ', '.join([ p.__name__ for p in parents ])
            lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr)))
            lines.append('')

        for name, trait in cls.class_traits(config=True).iteritems():
            help = trait.get_metadata('help') or ''
            lines.append(c(help))
            lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value()))
            lines.append('')
        return '\n'.join(lines)
280 285
281 286
282 287
class SingletonConfigurable(Configurable):
    """A configurable that only allows one instance.

    This class is for classes that should only have one instance of itself
    or *any* subclass. To create and retrieve such a class use the
    :meth:`SingletonConfigurable.instance` method.
    """

    # The shared instance; also stored on every singleton class in the MRO
    # by instance() so lookups through parent classes succeed.
    _instance = None

    @classmethod
    def _walk_mro(cls):
        """Walk the cls.mro() for parent classes that are also singletons

        For use in instance()
        """

        for subclass in cls.mro():
            if issubclass(cls, subclass) and \
                    issubclass(subclass, SingletonConfigurable) and \
                    subclass != SingletonConfigurable:
                yield subclass

    @classmethod
    def clear_instance(cls):
        """unset _instance for this class and singleton parents.
        """
        if not cls.initialized():
            return
        for subclass in cls._walk_mro():
            if isinstance(subclass._instance, cls):
                # only clear instances that are instances
                # of the calling class
                subclass._instance = None

    @classmethod
    def instance(cls, *args, **kwargs):
        """Returns a global instance of this class.

        This method creates a new instance if none have previously been created
        and returns a previously created instance if one already exists.

        The arguments and keyword arguments passed to this method are passed
        on to the :meth:`__init__` method of the class upon instantiation.

        Examples
        --------

        Create a singleton class using instance, and retrieve it::

            >>> from IPython.config.configurable import SingletonConfigurable
            >>> class Foo(SingletonConfigurable): pass
            >>> foo = Foo.instance()
            >>> foo == Foo.instance()
            True

        Create a subclass that is retrieved using the base class instance::

            >>> class Bar(SingletonConfigurable): pass
            >>> class Bam(Bar): pass
            >>> bam = Bam.instance()
            >>> bam == Bar.instance()
            True
        """
        # Create and save the instance
        if cls._instance is None:
            inst = cls(*args, **kwargs)
            # Now make sure that the instance will also be returned by
            # parent classes' _instance attribute.
            for subclass in cls._walk_mro():
                subclass._instance = inst

        if isinstance(cls._instance, cls):
            return cls._instance
        else:
            # An incompatible sibling/parent singleton already exists.
            raise MultipleInstanceError(
                'Multiple incompatible subclass instances of '
                '%s are being created.' % cls.__name__
            )

    @classmethod
    def initialized(cls):
        """Has an instance been created?"""
        return hasattr(cls, "_instance") and cls._instance is not None
367 372
368 373
class LoggingConfigurable(Configurable):
    """A parent class for Configurables that log.

    Subclasses have a log trait, and the default behavior
    is to get the logger from the currently running Application
    via Application.instance().log.
    """

    # Logger trait; defaults to the running Application's logger (below).
    log = Instance('logging.Logger')
    def _log_default(self):
        # Imported here to avoid a circular import with application.py.
        from IPython.config.application import Application
        return Application.instance().log
381 386
382 387
@@ -1,718 +1,821 b''
1 1 """A simple configuration system.
2 2
3 3 Inheritance diagram:
4 4
5 5 .. inheritance-diagram:: IPython.config.loader
6 6 :parts: 3
7 7
8 8 Authors
9 9 -------
10 10 * Brian Granger
11 11 * Fernando Perez
12 12 * Min RK
13 13 """
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Copyright (C) 2008-2011 The IPython Development Team
17 17 #
18 18 # Distributed under the terms of the BSD License. The full license is in
19 19 # the file COPYING, distributed as part of this software.
20 20 #-----------------------------------------------------------------------------
21 21
22 22 #-----------------------------------------------------------------------------
23 23 # Imports
24 24 #-----------------------------------------------------------------------------
25 25
26 26 import __builtin__ as builtin_mod
27 27 import argparse
28 import copy
28 29 import os
29 30 import re
30 31 import sys
31 32
32 33 from IPython.utils.path import filefind, get_ipython_dir
33 34 from IPython.utils import py3compat, warn
34 35 from IPython.utils.encoding import DEFAULT_ENCODING
36 from IPython.utils.traitlets import HasTraits, List, Any, TraitError
35 37
36 38 #-----------------------------------------------------------------------------
37 39 # Exceptions
38 40 #-----------------------------------------------------------------------------
39 41
40 42
class ConfigError(Exception):
    """Base class for all errors raised by the config system."""
    pass
43 45
class ConfigLoaderError(ConfigError):
    """Raised when a ConfigLoader fails to load configuration."""
    pass
46 48
class ConfigFileNotFound(ConfigError):
    """Raised when a requested config file cannot be located."""
    pass
49 51
class ArgumentError(ConfigLoaderError):
    """Raised for invalid command-line arguments."""
    pass
52 54
53 55 #-----------------------------------------------------------------------------
54 56 # Argparse fix
55 57 #-----------------------------------------------------------------------------
56 58
57 59 # Unfortunately argparse by default prints help messages to stderr instead of
58 60 # stdout. This makes it annoying to capture long help screens at the command
59 61 # line, since one must know how to pipe stderr, which many users don't know how
60 62 # to do. So we override the print_help method with one that defaults to
61 63 # stdout and use our class instead.
62 64
class ArgumentParser(argparse.ArgumentParser):
    """Simple argparse subclass that prints help to stdout by default."""

    def print_help(self, file=None):
        # Default the output stream to stdout (see the module-level note
        # about capturing help text) unless an explicit stream is given.
        destination = sys.stdout if file is None else file
        return super(ArgumentParser, self).print_help(destination)

    print_help.__doc__ = argparse.ArgumentParser.print_help.__doc__
72 74
73 75 #-----------------------------------------------------------------------------
74 76 # Config class for holding config information
75 77 #-----------------------------------------------------------------------------
76 78
class LazyConfigValue(HasTraits):
    """Proxy object for exposing methods on configurable containers

    Exposes:

    - append, extend, insert on lists
    - update on dicts
    - update, add on sets
    """

    # Cached result of get_value(); None until first resolved.
    _value = None

    # list methods
    _extend = List()
    _prepend = List()

    def append(self, obj):
        """Record obj to be appended to the trait's initial list."""
        self._extend.append(obj)

    def extend(self, other):
        """Record a sequence to be appended to the trait's initial list."""
        self._extend.extend(other)

    def prepend(self, other):
        """like list.extend, but for the front"""
        self._prepend[:0] = other

    # deferred (index, value) pairs replayed by get_value()
    _inserts = List()
    def insert(self, index, other):
        if not isinstance(index, int):
            raise TypeError("An integer is required")
        self._inserts.append((index, other))

    # dict methods
    # update is used for both dict and set
    _update = Any()
    def update(self, other):
        # Lazily choose dict vs set semantics from the first value seen.
        if self._update is None:
            if isinstance(other, dict):
                self._update = {}
            else:
                self._update = set()
        self._update.update(other)

    # set methods
    def add(self, obj):
        """Record obj to be added to the trait's initial set."""
        self.update({obj})

    def get_value(self, initial):
        """construct the value from the initial one

        after applying any insert / extend / update changes
        """
        if self._value is not None:
            return self._value
        # deepcopy so the caller's initial container is never mutated
        value = copy.deepcopy(initial)
        if isinstance(value, list):
            for idx, obj in self._inserts:
                value.insert(idx, obj)
            value[:0] = self._prepend
            value.extend(self._extend)

        elif isinstance(value, dict):
            if self._update:
                value.update(self._update)
        elif isinstance(value, set):
            if self._update:
                value.update(self._update)
        self._value = value
        return value

    def to_dict(self):
        """return JSONable dict form of my data

        Currently update as dict or set, extend, prepend as lists, and inserts as list of tuples.
        """
        d = {}
        if self._update:
            d['update'] = self._update
        if self._extend:
            d['extend'] = self._extend
        if self._prepend:
            d['prepend'] = self._prepend
        # BUGFIX: this was `elif`, which silently dropped recorded inserts
        # whenever _prepend was non-empty; each key is independent.
        if self._inserts:
            d['inserts'] = self._inserts
        return d
164
77 165
class Config(dict):
    """An attribute based dict that can do smart merges."""

    def __init__(self, *args, **kwds):
        dict.__init__(self, *args, **kwds)
        # This sets self.__dict__ = self, but it has to be done this way
        # because we are also overriding __setattr__.
        dict.__setattr__(self, '__dict__', self)
        self._ensure_subconfig()

    def _ensure_subconfig(self):
        """ensure that sub-dicts that should be Config objects are

        casts dicts that are under section keys to Config objects,
        which is necessary for constructing Config objects from dict literals.
        """
        for key in self:
            obj = self[key]
            if self._is_section_key(key) \
                    and isinstance(obj, dict) \
                    and not isinstance(obj, Config):
                # __dict__ is self, so setattr stores into the dict
                dict.__setattr__(self, key, Config(obj))

    def _merge(self, other):
        """deprecated alias, use Config.merge()"""
        self.merge(other)

    def merge(self, other):
        """merge another config object into this one"""
        to_update = {}
        for k, v in other.iteritems():
            if k not in self:
                to_update[k] = v
            else: # I have this key
                if isinstance(v, Config) and isinstance(self[k], Config):
                    # Recursively merge common sub Configs
                    self[k].merge(v)
                else:
                    # Plain updates for non-Configs
                    to_update[k] = v

        self.update(to_update)

    def _is_section_key(self, key):
        # Section keys look like class names: leading uppercase, not private.
        if key[0].upper()==key[0] and not key.startswith('_'):
            return True
        else:
            return False

    def __contains__(self, key):
        # allow nested contains of the form `"Section.key" in config`
        if '.' in key:
            first, remainder = key.split('.', 1)
            if first not in self:
                return False
            return remainder in self[first]

        # we always have Sections
        if self._is_section_key(key):
            return True
        else:
            return super(Config, self).__contains__(key)
    # .has_key is deprecated for dictionaries.
    has_key = __contains__

    def _has_section(self, key):
        # Unlike __contains__, only reports sections that actually exist.
        if self._is_section_key(key):
            if super(Config, self).__contains__(key):
                return True
        return False

    def copy(self):
        return type(self)(dict.copy(self))

    def __copy__(self):
        return self.copy()

    def __deepcopy__(self, memo):
        import copy
        return type(self)(copy.deepcopy(self.items()))

    def __getitem__(self, key):
        # We cannot use directly self._is_section_key, because it triggers
        # infinite recursion on top of PyPy. Instead, we manually fish the
        # bound method.
        is_section_key = self.__class__._is_section_key.__get__(self)

        # Because we use this for an exec namespace, we need to delegate
        # the lookup of names in __builtin__ to itself. This means
        # that you can't have section or attribute names that are
        # builtins.
        try:
            return getattr(builtin_mod, key)
        except AttributeError:
            pass
        if is_section_key(key):
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                # auto-create missing sections so `c.Foo.bar = x` works
                c = Config()
                dict.__setitem__(self, key, c)
                return c
        else:
            try:
                return dict.__getitem__(self, key)
            except KeyError:
                # undefined non-section key: hand back a LazyConfigValue so
                # container methods (extend/update/...) can be recorded now
                # and applied to the trait's real value later
                v = LazyConfigValue()
                dict.__setitem__(self, key, v)
                return v

    def __setitem__(self, key, value):
        if self._is_section_key(key):
            if not isinstance(value, Config):
                raise ValueError('values whose keys begin with an uppercase '
                                 'char must be Config instances: %r, %r' % (key, value))
        else:
            # NOTE(review): as structured here, assigning a valid Config to a
            # section key stores nothing; sections appear to be created and
            # filled via __getitem__/merge instead — confirm this is intended.
            dict.__setitem__(self, key, value)

    def __getattr__(self, key):
        try:
            return self.__getitem__(key)
        except KeyError as e:
            raise AttributeError(e)

    def __setattr__(self, key, value):
        try:
            self.__setitem__(key, value)
        except KeyError as e:
            raise AttributeError(e)

    def __delattr__(self, key):
        try:
            dict.__delitem__(self, key)
        except KeyError as e:
            raise AttributeError(e)
200 303
201 304
202 305 #-----------------------------------------------------------------------------
203 306 # Config loading classes
204 307 #-----------------------------------------------------------------------------
205 308
206 309
class ConfigLoader(object):
    """An object for loading configurations from just about anywhere.

    The resulting configuration is packaged as a :class:`Struct`.

    Notes
    -----
    A :class:`ConfigLoader` has a single responsibility: read configuration
    from some source (a file, command line arguments) and hand the data back
    as a :class:`Struct`. Locating config files, applying default values and
    merging multiple configs are deliberately out of scope and must be
    handled elsewhere.
    """

    def __init__(self):
        """Initialize the loader with an empty config.

        Examples
        --------

        >>> cl = ConfigLoader()
        >>> config = cl.load_config()
        >>> config
        {}
        """
        self.clear()

    def clear(self):
        """Discard any previously loaded state."""
        self.config = Config()

    def load_config(self):
        """Load a config from somewhere, return a :class:`Config` instance.

        Usually, this will cause self.config to be set and then returned.
        However, in most cases, :meth:`ConfigLoader.clear` should be called
        to erase any previous state.
        """
        self.clear()
        return self.config
247 350
248 351
class FileConfigLoader(ConfigLoader):
    """Base class for file based configurations.

    As more file based config loaders are added, their shared
    logic should live here.
    """
    pass
256 359
257 360
class PyFileConfigLoader(FileConfigLoader):
    """A config loader for pure python files.

    This calls execfile on a plain python file and looks for attributes
    that are all caps. These attributes are added to the config Struct.
    """

    def __init__(self, filename, path=None):
        """Build a config loader for a filename and path.

        Parameters
        ----------
        filename : str
            The file name of the config file.
        path : str, list, tuple
            The path to search for the config file on, or a sequence of
            paths to try in order.
        """
        super(PyFileConfigLoader, self).__init__()
        self.filename = filename
        self.path = path
        # resolved absolute path; filled in by _find_file()
        self.full_filename = ''
        self.data = None

    def load_config(self):
        """Load the config from a file and return it as a Struct."""
        self.clear()
        try:
            self._find_file()
        except IOError as e:
            raise ConfigFileNotFound(str(e))
        self._read_file_as_dict()
        self._convert_to_config()
        return self.config

    def _find_file(self):
        """Try to find the file by searching the paths."""
        self.full_filename = filefind(self.filename, self.path)

    def _read_file_as_dict(self):
        """Load the config file into self.config, with recursive loading."""
        # This closure is made available in the namespace that is used
        # to exec the config file. It allows users to call
        # load_subconfig('myconfig.py') to load config files recursively.
        # It needs to be a closure because it has references to self.path
        # and self.config. The sub-config is loaded with the same path
        # as the parent, but it uses an empty config which is then merged
        # with the parents.

        # If a profile is specified, the config file will be loaded
        # from that profile

        def load_subconfig(fname, profile=None):
            # import here to prevent circular imports
            from IPython.core.profiledir import ProfileDir, ProfileDirError
            if profile is not None:
                try:
                    profile_dir = ProfileDir.find_profile_dir_by_name(
                            get_ipython_dir(),
                            profile,
                    )
                except ProfileDirError:
                    # missing profile: silently skip the sub-config
                    return
                path = profile_dir.location
            else:
                path = self.path
            loader = PyFileConfigLoader(fname, path)
            try:
                sub_config = loader.load_config()
            except ConfigFileNotFound:
                # Pass silently if the sub config is not there. This happens
                # when a user is using a profile, but not the default config.
                pass
            else:
                self.config.merge(sub_config)

        # Again, this needs to be a closure and should be used in config
        # files to get the config being loaded.
        def get_config():
            return self.config

        namespace = dict(
            load_subconfig=load_subconfig,
            get_config=get_config,
            __file__=self.full_filename,
        )
        fs_encoding = sys.getfilesystemencoding() or 'ascii'
        conf_filename = self.full_filename.encode(fs_encoding)
        py3compat.execfile(conf_filename, namespace)

    def _convert_to_config(self):
        if self.data is None:
            # NOTE(review): this instantiates ConfigLoaderError but never
            # raises it — looks like a missing `raise`; confirm intent.
            ConfigLoaderError('self.data does not exist')
351 454
352 455
class CommandLineConfigLoader(ConfigLoader):
    """A config loader for command line arguments.

    As we add more command line based loaders, the common logic should go
    here.
    """

    def _exec_config_str(self, lhs, rhs):
        """execute self.config.<lhs> = <rhs>

        * expands ~ with expanduser
        * tries to assign with raw eval, otherwise assigns with just the string,
          allowing `--C.a=foobar` and `--C.a="foobar"` to be equivalent. *Not*
          equivalent are `--C.a=4` and `--C.a='4'`.
        """
        rhs = os.path.expanduser(rhs)
        try:
            # Try to see if regular Python syntax will work. This
            # won't handle strings as the quote marks are removed
            # by the system shell.
            value = eval(rhs)
        except (NameError, SyntaxError):
            # This case happens if the rhs is a string.
            value = rhs

        # Python 2 exec statement; lhs originates from the command line.
        exec u'self.config.%s = value' % lhs

    def _load_flag(self, cfg):
        """update self.config from a flag, which can be a dict or Config"""
        if isinstance(cfg, (dict, Config)):
            # don't clobber whole config sections, update
            # each section from config:
            for sec,c in cfg.iteritems():
                self.config[sec].update(c)
        else:
            raise TypeError("Invalid flag: %r" % cfg)
389 492
# Module-level regexes used to classify command-line tokens.

# raw --identifier=value pattern
# but *also* accept '-' as wordsep, for aliases
# accepts: --foo=a
#          --Class.trait=value
#          --alias-name=value
# rejects: -foo=value
#          --foo
#          --Class.trait
kv_pattern = re.compile(r'\-\-[A-Za-z][\w\-]*(\.[\w\-]+)*\=.*')

# just flags, no assignments, with two *or one* leading '-'
# accepts: --foo
#          -foo-bar-again
# rejects: --anything=anything
#          --two.word

flag_pattern = re.compile(r'\-\-?\w+[\-\w]*$')
407 510
408 511 class KeyValueConfigLoader(CommandLineConfigLoader):
409 512 """A config loader that loads key value pairs from the command line.
410 513
411 514 This allows command line options to be gives in the following form::
412 515
413 516 ipython --profile="foo" --InteractiveShell.autocall=False
414 517 """
415 518
416 519 def __init__(self, argv=None, aliases=None, flags=None):
417 520 """Create a key value pair config loader.
418 521
419 522 Parameters
420 523 ----------
421 524 argv : list
422 525 A list that has the form of sys.argv[1:] which has unicode
423 526 elements of the form u"key=value". If this is None (default),
424 527 then sys.argv[1:] will be used.
425 528 aliases : dict
426 529 A dict of aliases for configurable traits.
427 530 Keys are the short aliases, Values are the resolved trait.
428 531 Of the form: `{'alias' : 'Configurable.trait'}`
429 532 flags : dict
430 533 A dict of flags, keyed by str name. Values can be Config objects,
431 534 dicts, or "key=value" strings. If Config or dict, when the flag
432 535 is triggered, The flag is loaded as `self.config.update(m)`.
433 536
434 537 Returns
435 538 -------
436 539 config : Config
437 540 The resulting Config object.
438 541
439 542 Examples
440 543 --------
441 544
442 545 >>> from IPython.config.loader import KeyValueConfigLoader
443 546 >>> cl = KeyValueConfigLoader()
444 547 >>> d = cl.load_config(["--A.name='brian'","--B.number=0"])
445 548 >>> sorted(d.items())
446 549 [('A', {'name': 'brian'}), ('B', {'number': 0})]
447 550 """
448 551 self.clear()
449 552 if argv is None:
450 553 argv = sys.argv[1:]
451 554 self.argv = argv
452 555 self.aliases = aliases or {}
453 556 self.flags = flags or {}
454 557
455 558
456 559 def clear(self):
457 560 super(KeyValueConfigLoader, self).clear()
458 561 self.extra_args = []
459 562
460 563
461 564 def _decode_argv(self, argv, enc=None):
462 565 """decode argv if bytes, using stdin.encoding, falling back on default enc"""
463 566 uargv = []
464 567 if enc is None:
465 568 enc = DEFAULT_ENCODING
466 569 for arg in argv:
467 570 if not isinstance(arg, unicode):
468 571 # only decode if not already decoded
469 572 arg = arg.decode(enc)
470 573 uargv.append(arg)
471 574 return uargv
472 575
473 576
474 577 def load_config(self, argv=None, aliases=None, flags=None):
475 578 """Parse the configuration and generate the Config object.
476 579
477 580 After loading, any arguments that are not key-value or
478 581 flags will be stored in self.extra_args - a list of
479 582 unparsed command-line arguments. This is used for
480 583 arguments such as input files or subcommands.
481 584
482 585 Parameters
483 586 ----------
484 587 argv : list, optional
485 588 A list that has the form of sys.argv[1:] which has unicode
486 589 elements of the form u"key=value". If this is None (default),
487 590 then self.argv will be used.
488 591 aliases : dict
489 592 A dict of aliases for configurable traits.
490 593 Keys are the short aliases, Values are the resolved trait.
491 594 Of the form: `{'alias' : 'Configurable.trait'}`
492 595 flags : dict
493 596 A dict of flags, keyed by str name. Values can be Config objects
494 597 or dicts. When the flag is triggered, The config is loaded as
495 598 `self.config.update(cfg)`.
496 599 """
497 600 self.clear()
498 601 if argv is None:
499 602 argv = self.argv
500 603 if aliases is None:
501 604 aliases = self.aliases
502 605 if flags is None:
503 606 flags = self.flags
504 607
505 608 # ensure argv is a list of unicode strings:
506 609 uargv = self._decode_argv(argv)
507 610 for idx,raw in enumerate(uargv):
508 611 # strip leading '-'
509 612 item = raw.lstrip('-')
510 613
511 614 if raw == '--':
512 615 # don't parse arguments after '--'
513 616 # this is useful for relaying arguments to scripts, e.g.
514 617 # ipython -i foo.py --matplotlib=qt -- args after '--' go-to-foo.py
515 618 self.extra_args.extend(uargv[idx+1:])
516 619 break
517 620
518 621 if kv_pattern.match(raw):
519 622 lhs,rhs = item.split('=',1)
520 623 # Substitute longnames for aliases.
521 624 if lhs in aliases:
522 625 lhs = aliases[lhs]
523 626 if '.' not in lhs:
524 627 # probably a mistyped alias, but not technically illegal
525 628 warn.warn("Unrecognized alias: '%s', it will probably have no effect."%lhs)
526 629 try:
527 630 self._exec_config_str(lhs, rhs)
528 631 except Exception:
529 632 raise ArgumentError("Invalid argument: '%s'" % raw)
530 633
531 634 elif flag_pattern.match(raw):
532 635 if item in flags:
533 636 cfg,help = flags[item]
534 637 self._load_flag(cfg)
535 638 else:
536 639 raise ArgumentError("Unrecognized flag: '%s'"%raw)
537 640 elif raw.startswith('-'):
538 641 kv = '--'+item
539 642 if kv_pattern.match(kv):
540 643 raise ArgumentError("Invalid argument: '%s', did you mean '%s'?"%(raw, kv))
541 644 else:
542 645 raise ArgumentError("Invalid argument: '%s'"%raw)
543 646 else:
544 647 # keep all args that aren't valid in a list,
545 648 # in case our parent knows what to do with them.
546 649 self.extra_args.append(item)
547 650 return self.config
548 651
549 652 class ArgParseConfigLoader(CommandLineConfigLoader):
550 653 """A loader that uses the argparse module to load from the command line."""
551 654
552 655 def __init__(self, argv=None, aliases=None, flags=None, *parser_args, **parser_kw):
553 656 """Create a config loader for use with argparse.
554 657
555 658 Parameters
556 659 ----------
557 660
558 661 argv : optional, list
559 662 If given, used to read command-line arguments from, otherwise
560 663 sys.argv[1:] is used.
561 664
562 665 parser_args : tuple
563 666 A tuple of positional arguments that will be passed to the
564 667 constructor of :class:`argparse.ArgumentParser`.
565 668
566 669 parser_kw : dict
567 670 A dict of keyword arguments that will be passed to the
568 671 constructor of :class:`argparse.ArgumentParser`.
569 672
570 673 Returns
571 674 -------
572 675 config : Config
573 676 The resulting Config object.
574 677 """
575 678 super(CommandLineConfigLoader, self).__init__()
576 679 self.clear()
577 680 if argv is None:
578 681 argv = sys.argv[1:]
579 682 self.argv = argv
580 683 self.aliases = aliases or {}
581 684 self.flags = flags or {}
582 685
583 686 self.parser_args = parser_args
584 687 self.version = parser_kw.pop("version", None)
585 688 kwargs = dict(argument_default=argparse.SUPPRESS)
586 689 kwargs.update(parser_kw)
587 690 self.parser_kw = kwargs
588 691
589 692 def load_config(self, argv=None, aliases=None, flags=None):
590 693 """Parse command line arguments and return as a Config object.
591 694
592 695 Parameters
593 696 ----------
594 697
595 698 argv : optional, list
596 699 If given, a list with the structure of sys.argv[1:] to parse
597 700 arguments from. If not given, the instance's self.argv attribute
598 701 (given at construction time) is used."""
599 702 self.clear()
600 703 if argv is None:
601 704 argv = self.argv
602 705 if aliases is None:
603 706 aliases = self.aliases
604 707 if flags is None:
605 708 flags = self.flags
606 709 self._create_parser(aliases, flags)
607 710 self._parse_args(argv)
608 711 self._convert_to_config()
609 712 return self.config
610 713
611 714 def get_extra_args(self):
612 715 if hasattr(self, 'extra_args'):
613 716 return self.extra_args
614 717 else:
615 718 return []
616 719
617 720 def _create_parser(self, aliases=None, flags=None):
618 721 self.parser = ArgumentParser(*self.parser_args, **self.parser_kw)
619 722 self._add_arguments(aliases, flags)
620 723
621 724 def _add_arguments(self, aliases=None, flags=None):
622 725 raise NotImplementedError("subclasses must implement _add_arguments")
623 726
624 727 def _parse_args(self, args):
625 728 """self.parser->self.parsed_data"""
626 729 # decode sys.argv to support unicode command-line options
627 730 enc = DEFAULT_ENCODING
628 731 uargs = [py3compat.cast_unicode(a, enc) for a in args]
629 732 self.parsed_data, self.extra_args = self.parser.parse_known_args(uargs)
630 733
631 734 def _convert_to_config(self):
632 735 """self.parsed_data->self.config"""
633 736 for k, v in vars(self.parsed_data).iteritems():
634 737 exec "self.config.%s = v"%k in locals(), globals()
635 738
636 739 class KVArgParseConfigLoader(ArgParseConfigLoader):
637 740 """A config loader that loads aliases and flags with argparse,
638 741 but will use KVLoader for the rest. This allows better parsing
639 742 of common args, such as `ipython -c 'print 5'`, but still gets
640 743 arbitrary config with `ipython --InteractiveShell.use_readline=False`"""
641 744
642 745 def _add_arguments(self, aliases=None, flags=None):
643 746 self.alias_flags = {}
644 747 # print aliases, flags
645 748 if aliases is None:
646 749 aliases = self.aliases
647 750 if flags is None:
648 751 flags = self.flags
649 752 paa = self.parser.add_argument
650 753 for key,value in aliases.iteritems():
651 754 if key in flags:
652 755 # flags
653 756 nargs = '?'
654 757 else:
655 758 nargs = None
656 759 if len(key) is 1:
657 760 paa('-'+key, '--'+key, type=unicode, dest=value, nargs=nargs)
658 761 else:
659 762 paa('--'+key, type=unicode, dest=value, nargs=nargs)
660 763 for key, (value, help) in flags.iteritems():
661 764 if key in self.aliases:
662 765 #
663 766 self.alias_flags[self.aliases[key]] = value
664 767 continue
665 768 if len(key) is 1:
666 769 paa('-'+key, '--'+key, action='append_const', dest='_flags', const=value)
667 770 else:
668 771 paa('--'+key, action='append_const', dest='_flags', const=value)
669 772
670 773 def _convert_to_config(self):
671 774 """self.parsed_data->self.config, parse unrecognized extra args via KVLoader."""
672 775 # remove subconfigs list from namespace before transforming the Namespace
673 776 if '_flags' in self.parsed_data:
674 777 subcs = self.parsed_data._flags
675 778 del self.parsed_data._flags
676 779 else:
677 780 subcs = []
678 781
679 782 for k, v in vars(self.parsed_data).iteritems():
680 783 if v is None:
681 784 # it was a flag that shares the name of an alias
682 785 subcs.append(self.alias_flags[k])
683 786 else:
684 787 # eval the KV assignment
685 788 self._exec_config_str(k, v)
686 789
687 790 for subc in subcs:
688 791 self._load_flag(subc)
689 792
690 793 if self.extra_args:
691 794 sub_parser = KeyValueConfigLoader()
692 795 sub_parser.load_config(self.extra_args)
693 796 self.config.merge(sub_parser.config)
694 797 self.extra_args = sub_parser.extra_args
695 798
696 799
697 800 def load_pyconfig_files(config_files, path):
698 801 """Load multiple Python config files, merging each of them in turn.
699 802
700 803 Parameters
701 804 ==========
702 805 config_files : list of str
703 806 List of config files names to load and merge into the config.
704 807 path : unicode
705 808 The full path to the location of the config files.
706 809 """
707 810 config = Config()
708 811 for cf in config_files:
709 812 loader = PyFileConfigLoader(cf, path=path)
710 813 try:
711 814 next_config = loader.load_config()
712 815 except ConfigFileNotFound:
713 816 pass
714 817 except:
715 818 raise
716 819 else:
717 820 config.merge(next_config)
718 821 return config
@@ -1,19 +1,13 b''
1 1 c = get_config()
2 2 app = c.InteractiveShellApp
3 3
4 4 # This can be used at any point in a config file to load a sub config
5 5 # and merge it into the current one.
6 6 load_subconfig('ipython_config.py', profile='default')
7 7
8 8 lines = """
9 9 from IPython.parallel import *
10 10 """
11 11
12 # You have to make sure that attributes that are containers already
13 # exist before using them. Simple assigning a new list will override
14 # all previous values.
15 if hasattr(app, 'exec_lines'):
16 app.exec_lines.append(lines)
17 else:
18 app.exec_lines = [lines]
12 app.exec_lines.append(lines)
19 13
@@ -1,21 +1,13 b''
1 1 c = get_config()
2 2 app = c.InteractiveShellApp
3 3
4 4 # This can be used at any point in a config file to load a sub config
5 5 # and merge it into the current one.
6 6 load_subconfig('ipython_config.py', profile='default')
7 7
8 8 lines = """
9 9 import cmath
10 10 from math import *
11 11 """
12 12
13 # You have to make sure that attributes that are containers already
14 # exist before using them. Simple assigning a new list will override
15 # all previous values.
16
17 if hasattr(app, 'exec_lines'):
18 app.exec_lines.append(lines)
19 else:
20 app.exec_lines = [lines]
21
13 app.exec_lines.append(lines)
@@ -1,30 +1,24 b''
1 1 c = get_config()
2 2 app = c.InteractiveShellApp
3 3
4 4 # This can be used at any point in a config file to load a sub config
5 5 # and merge it into the current one.
6 6 load_subconfig('ipython_config.py', profile='default')
7 7
8 8 c.PromptManager.in_template = r'{color.LightGreen}\u@\h{color.LightBlue}[{color.LightCyan}\Y1{color.LightBlue}]{color.Green}|\#> '
9 9 c.PromptManager.in2_template = r'{color.Green}|{color.LightGreen}\D{color.Green}> '
10 10 c.PromptManager.out_template = r'<\#> '
11 11
12 12 c.PromptManager.justify = True
13 13
14 14 c.InteractiveShell.separate_in = ''
15 15 c.InteractiveShell.separate_out = ''
16 16 c.InteractiveShell.separate_out2 = ''
17 17
18 18 c.PrefilterManager.multi_line_specials = True
19 19
20 20 lines = """
21 21 %rehashx
22 22 """
23 23
24 # You have to make sure that attributes that are containers already
25 # exist before using them. Simple assigning a new list will override
26 # all previous values.
27 if hasattr(app, 'exec_lines'):
28 app.exec_lines.append(lines)
29 else:
30 app.exec_lines = [lines]
24 app.exec_lines.append(lines)
@@ -1,30 +1,20 b''
1 1 c = get_config()
2 2 app = c.InteractiveShellApp
3 3
4 4 # This can be used at any point in a config file to load a sub config
5 5 # and merge it into the current one.
6 6 load_subconfig('ipython_config.py', profile='default')
7 7
8 8 lines = """
9 9 from __future__ import division
10 10 from sympy import *
11 11 x, y, z, t = symbols('x y z t')
12 12 k, m, n = symbols('k m n', integer=True)
13 13 f, g, h = symbols('f g h', cls=Function)
14 14 """
15 15
16 # You have to make sure that attributes that are containers already
17 # exist before using them. Simple assigning a new list will override
18 # all previous values.
19
20 if hasattr(app, 'exec_lines'):
21 app.exec_lines.append(lines)
22 else:
23 app.exec_lines = [lines]
16 app.exec_lines.append(lines)
24 17
25 18 # Load the sympy_printing extension to enable nice printing of sympy expr's.
26 if hasattr(app, 'extensions'):
27 app.extensions.append('sympyprinting')
28 else:
29 app.extensions = ['sympyprinting']
19 app.extensions.append('sympy.interactive.ipythonprinting')
30 20
@@ -1,283 +1,360 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for IPython.config.configurable
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Fernando Perez (design help)
9 9 """
10 10
11 11 #-----------------------------------------------------------------------------
12 12 # Copyright (C) 2008-2011 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 from unittest import TestCase
23 23
24 24 from IPython.config.configurable import (
25 25 Configurable,
26 26 SingletonConfigurable
27 27 )
28 28
29 29 from IPython.utils.traitlets import (
30 Integer, Float, Unicode
30 Integer, Float, Unicode, List, Dict, Set,
31 31 )
32 32
33 33 from IPython.config.loader import Config
34 34 from IPython.utils.py3compat import PY3
35 35
36 36 #-----------------------------------------------------------------------------
37 37 # Test cases
38 38 #-----------------------------------------------------------------------------
39 39
40 40
41 41 class MyConfigurable(Configurable):
42 42 a = Integer(1, config=True, help="The integer a.")
43 43 b = Float(1.0, config=True, help="The integer b.")
44 44 c = Unicode('no config')
45 45
46 46
47 47 mc_help=u"""MyConfigurable options
48 48 ----------------------
49 49 --MyConfigurable.a=<Integer>
50 50 Default: 1
51 51 The integer a.
52 52 --MyConfigurable.b=<Float>
53 53 Default: 1.0
54 54 The integer b."""
55 55
56 56 mc_help_inst=u"""MyConfigurable options
57 57 ----------------------
58 58 --MyConfigurable.a=<Integer>
59 59 Current: 5
60 60 The integer a.
61 61 --MyConfigurable.b=<Float>
62 62 Current: 4.0
63 63 The integer b."""
64 64
65 65 # On Python 3, the Integer trait is a synonym for Int
66 66 if PY3:
67 67 mc_help = mc_help.replace(u"<Integer>", u"<Int>")
68 68 mc_help_inst = mc_help_inst.replace(u"<Integer>", u"<Int>")
69 69
70 70 class Foo(Configurable):
71 71 a = Integer(0, config=True, help="The integer a.")
72 72 b = Unicode('nope', config=True)
73 73
74 74
75 75 class Bar(Foo):
76 76 b = Unicode('gotit', config=False, help="The string b.")
77 77 c = Float(config=True, help="The string c.")
78 78
79 79
80 80 class TestConfigurable(TestCase):
81 81
82 82 def test_default(self):
83 83 c1 = Configurable()
84 84 c2 = Configurable(config=c1.config)
85 85 c3 = Configurable(config=c2.config)
86 86 self.assertEqual(c1.config, c2.config)
87 87 self.assertEqual(c2.config, c3.config)
88 88
89 89 def test_custom(self):
90 90 config = Config()
91 91 config.foo = 'foo'
92 92 config.bar = 'bar'
93 93 c1 = Configurable(config=config)
94 94 c2 = Configurable(config=c1.config)
95 95 c3 = Configurable(config=c2.config)
96 96 self.assertEqual(c1.config, config)
97 97 self.assertEqual(c2.config, config)
98 98 self.assertEqual(c3.config, config)
99 99 # Test that copies are not made
100 100 self.assertTrue(c1.config is config)
101 101 self.assertTrue(c2.config is config)
102 102 self.assertTrue(c3.config is config)
103 103 self.assertTrue(c1.config is c2.config)
104 104 self.assertTrue(c2.config is c3.config)
105 105
106 106 def test_inheritance(self):
107 107 config = Config()
108 108 config.MyConfigurable.a = 2
109 109 config.MyConfigurable.b = 2.0
110 110 c1 = MyConfigurable(config=config)
111 111 c2 = MyConfigurable(config=c1.config)
112 112 self.assertEqual(c1.a, config.MyConfigurable.a)
113 113 self.assertEqual(c1.b, config.MyConfigurable.b)
114 114 self.assertEqual(c2.a, config.MyConfigurable.a)
115 115 self.assertEqual(c2.b, config.MyConfigurable.b)
116 116
117 117 def test_parent(self):
118 118 config = Config()
119 119 config.Foo.a = 10
120 120 config.Foo.b = "wow"
121 121 config.Bar.b = 'later'
122 122 config.Bar.c = 100.0
123 123 f = Foo(config=config)
124 124 b = Bar(config=f.config)
125 125 self.assertEqual(f.a, 10)
126 126 self.assertEqual(f.b, 'wow')
127 127 self.assertEqual(b.b, 'gotit')
128 128 self.assertEqual(b.c, 100.0)
129 129
130 130 def test_override1(self):
131 131 config = Config()
132 132 config.MyConfigurable.a = 2
133 133 config.MyConfigurable.b = 2.0
134 134 c = MyConfigurable(a=3, config=config)
135 135 self.assertEqual(c.a, 3)
136 136 self.assertEqual(c.b, config.MyConfigurable.b)
137 137 self.assertEqual(c.c, 'no config')
138 138
139 139 def test_override2(self):
140 140 config = Config()
141 141 config.Foo.a = 1
142 142 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
143 143 config.Bar.c = 10.0
144 144 c = Bar(config=config)
145 145 self.assertEqual(c.a, config.Foo.a)
146 146 self.assertEqual(c.b, 'gotit')
147 147 self.assertEqual(c.c, config.Bar.c)
148 148 c = Bar(a=2, b='and', c=20.0, config=config)
149 149 self.assertEqual(c.a, 2)
150 150 self.assertEqual(c.b, 'and')
151 151 self.assertEqual(c.c, 20.0)
152 152
153 153 def test_help(self):
154 154 self.assertEqual(MyConfigurable.class_get_help(), mc_help)
155 155
156 156 def test_help_inst(self):
157 157 inst = MyConfigurable(a=5, b=4)
158 158 self.assertEqual(MyConfigurable.class_get_help(inst), mc_help_inst)
159 159
160 160
161 161 class TestSingletonConfigurable(TestCase):
162 162
163 163 def test_instance(self):
164 164 class Foo(SingletonConfigurable): pass
165 165 self.assertEqual(Foo.initialized(), False)
166 166 foo = Foo.instance()
167 167 self.assertEqual(Foo.initialized(), True)
168 168 self.assertEqual(foo, Foo.instance())
169 169 self.assertEqual(SingletonConfigurable._instance, None)
170 170
171 171 def test_inheritance(self):
172 172 class Bar(SingletonConfigurable): pass
173 173 class Bam(Bar): pass
174 174 self.assertEqual(Bar.initialized(), False)
175 175 self.assertEqual(Bam.initialized(), False)
176 176 bam = Bam.instance()
177 177 bam == Bar.instance()
178 178 self.assertEqual(Bar.initialized(), True)
179 179 self.assertEqual(Bam.initialized(), True)
180 180 self.assertEqual(bam, Bam._instance)
181 181 self.assertEqual(bam, Bar._instance)
182 182 self.assertEqual(SingletonConfigurable._instance, None)
183 183
184 184
185 185 class MyParent(Configurable):
186 186 pass
187 187
188 188 class MyParent2(MyParent):
189 189 pass
190 190
191 191 class TestParentConfigurable(TestCase):
192 192
193 193 def test_parent_config(self):
194 194 cfg = Config({
195 195 'MyParent' : {
196 196 'MyConfigurable' : {
197 197 'b' : 2.0,
198 198 }
199 199 }
200 200 })
201 201 parent = MyParent(config=cfg)
202 202 myc = MyConfigurable(parent=parent)
203 203 self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b)
204 204
205 205 def test_parent_inheritance(self):
206 206 cfg = Config({
207 207 'MyParent' : {
208 208 'MyConfigurable' : {
209 209 'b' : 2.0,
210 210 }
211 211 }
212 212 })
213 213 parent = MyParent2(config=cfg)
214 214 myc = MyConfigurable(parent=parent)
215 215 self.assertEqual(myc.b, parent.config.MyParent.MyConfigurable.b)
216 216
217 217 def test_multi_parent(self):
218 218 cfg = Config({
219 219 'MyParent2' : {
220 220 'MyParent' : {
221 221 'MyConfigurable' : {
222 222 'b' : 2.0,
223 223 }
224 224 },
225 225 # this one shouldn't count
226 226 'MyConfigurable' : {
227 227 'b' : 3.0,
228 228 },
229 229 }
230 230 })
231 231 parent2 = MyParent2(config=cfg)
232 232 parent = MyParent(parent=parent2)
233 233 myc = MyConfigurable(parent=parent)
234 234 self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b)
235 235
236 236 def test_parent_priority(self):
237 237 cfg = Config({
238 238 'MyConfigurable' : {
239 239 'b' : 2.0,
240 240 },
241 241 'MyParent' : {
242 242 'MyConfigurable' : {
243 243 'b' : 3.0,
244 244 }
245 245 },
246 246 'MyParent2' : {
247 247 'MyConfigurable' : {
248 248 'b' : 4.0,
249 249 }
250 250 }
251 251 })
252 252 parent = MyParent2(config=cfg)
253 253 myc = MyConfigurable(parent=parent)
254 254 self.assertEqual(myc.b, parent.config.MyParent2.MyConfigurable.b)
255 255
256 256 def test_multi_parent_priority(self):
257 257 cfg = Config({
258 258 'MyConfigurable' : {
259 259 'b' : 2.0,
260 260 },
261 261 'MyParent' : {
262 262 'MyConfigurable' : {
263 263 'b' : 3.0,
264 264 }
265 265 },
266 266 'MyParent2' : {
267 267 'MyConfigurable' : {
268 268 'b' : 4.0,
269 269 }
270 270 },
271 271 'MyParent2' : {
272 272 'MyParent' : {
273 273 'MyConfigurable' : {
274 274 'b' : 5.0,
275 275 }
276 276 }
277 277 }
278 278 })
279 279 parent2 = MyParent2(config=cfg)
280 280 parent = MyParent2(parent=parent2)
281 281 myc = MyConfigurable(parent=parent)
282 282 self.assertEqual(myc.b, parent.config.MyParent2.MyParent.MyConfigurable.b)
283 283
284 class Containers(Configurable):
285 lis = List(config=True)
286 def _lis_default(self):
287 return [-1]
288
289 s = Set(config=True)
290 def _s_default(self):
291 return {'a'}
292
293 d = Dict(config=True)
294 def _d_default(self):
295 return {'a' : 'b'}
296
297 class TestConfigContainers(TestCase):
298 def test_extend(self):
299 c = Config()
300 c.Containers.lis.extend(range(5))
301 obj = Containers(config=c)
302 self.assertEqual(obj.lis, range(-1,5))
303
304 def test_insert(self):
305 c = Config()
306 c.Containers.lis.insert(0, 'a')
307 c.Containers.lis.insert(1, 'b')
308 obj = Containers(config=c)
309 self.assertEqual(obj.lis, ['a', 'b', -1])
310
311 def test_prepend(self):
312 c = Config()
313 c.Containers.lis.prepend([1,2])
314 c.Containers.lis.prepend([2,3])
315 obj = Containers(config=c)
316 self.assertEqual(obj.lis, [2,3,1,2,-1])
317
318 def test_prepend_extend(self):
319 c = Config()
320 c.Containers.lis.prepend([1,2])
321 c.Containers.lis.extend([2,3])
322 obj = Containers(config=c)
323 self.assertEqual(obj.lis, [1,2,-1,2,3])
324
325 def test_append_extend(self):
326 c = Config()
327 c.Containers.lis.append([1,2])
328 c.Containers.lis.extend([2,3])
329 obj = Containers(config=c)
330 self.assertEqual(obj.lis, [-1,[1,2],2,3])
331
332 def test_extend_append(self):
333 c = Config()
334 c.Containers.lis.extend([2,3])
335 c.Containers.lis.append([1,2])
336 obj = Containers(config=c)
337 self.assertEqual(obj.lis, [-1,2,3,[1,2]])
338
339 def test_insert_extend(self):
340 c = Config()
341 c.Containers.lis.insert(0, 1)
342 c.Containers.lis.extend([2,3])
343 obj = Containers(config=c)
344 self.assertEqual(obj.lis, [1,-1,2,3])
345
346 def test_set_update(self):
347 c = Config()
348 c.Containers.s.update({0,1,2})
349 c.Containers.s.update({3})
350 obj = Containers(config=c)
351 self.assertEqual(obj.s, {'a', 0, 1, 2, 3})
352
353 def test_dict_update(self):
354 c = Config()
355 c.Containers.d.update({'c' : 'd'})
356 c.Containers.d.update({'e' : 'f'})
357 obj = Containers(config=c)
358 self.assertEqual(obj.d, {'a':'b', 'c':'d', 'e':'f'})
359
360
@@ -1,282 +1,292 b''
1 1 # encoding: utf-8
2 2 """
3 3 Tests for IPython.config.loader
4 4
5 5 Authors:
6 6
7 7 * Brian Granger
8 8 * Fernando Perez (design help)
9 9 """
10 10
11 11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008 The IPython Development Team
13 13 #
14 14 # Distributed under the terms of the BSD License. The full license is in
15 15 # the file COPYING, distributed as part of this software.
16 16 #-----------------------------------------------------------------------------
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 import os
23 23 import sys
24 24 from tempfile import mkstemp
25 25 from unittest import TestCase
26 26
27 27 from nose import SkipTest
28 28
29 29 from IPython.testing.tools import mute_warn
30 30
31 31 from IPython.config.loader import (
32 32 Config,
33 33 PyFileConfigLoader,
34 34 KeyValueConfigLoader,
35 35 ArgParseConfigLoader,
36 36 KVArgParseConfigLoader,
37 37 ConfigError
38 38 )
39 39
40 40 #-----------------------------------------------------------------------------
41 41 # Actual tests
42 42 #-----------------------------------------------------------------------------
43 43
44 44
45 45 pyfile = """
46 46 c = get_config()
47 47 c.a=10
48 48 c.b=20
49 49 c.Foo.Bar.value=10
50 50 c.Foo.Bam.value=list(range(10)) # list() is just so it's the same on Python 3
51 51 c.D.C.value='hi there'
52 52 """
53 53
54 54 class TestPyFileCL(TestCase):
55 55
56 56 def test_basic(self):
57 57 fd, fname = mkstemp('.py')
58 58 f = os.fdopen(fd, 'w')
59 59 f.write(pyfile)
60 60 f.close()
61 61 # Unlink the file
62 62 cl = PyFileConfigLoader(fname)
63 63 config = cl.load_config()
64 64 self.assertEqual(config.a, 10)
65 65 self.assertEqual(config.b, 20)
66 66 self.assertEqual(config.Foo.Bar.value, 10)
67 67 self.assertEqual(config.Foo.Bam.value, range(10))
68 68 self.assertEqual(config.D.C.value, 'hi there')
69 69
70 70 class MyLoader1(ArgParseConfigLoader):
71 71 def _add_arguments(self, aliases=None, flags=None):
72 72 p = self.parser
73 73 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
74 74 p.add_argument('-b', dest='MyClass.bar', type=int)
75 75 p.add_argument('-n', dest='n', action='store_true')
76 76 p.add_argument('Global.bam', type=str)
77 77
78 78 class MyLoader2(ArgParseConfigLoader):
79 79 def _add_arguments(self, aliases=None, flags=None):
80 80 subparsers = self.parser.add_subparsers(dest='subparser_name')
81 81 subparser1 = subparsers.add_parser('1')
82 82 subparser1.add_argument('-x',dest='Global.x')
83 83 subparser2 = subparsers.add_parser('2')
84 84 subparser2.add_argument('y')
85 85
86 86 class TestArgParseCL(TestCase):
87 87
88 88 def test_basic(self):
89 89 cl = MyLoader1()
90 90 config = cl.load_config('-f hi -b 10 -n wow'.split())
91 91 self.assertEqual(config.Global.foo, 'hi')
92 92 self.assertEqual(config.MyClass.bar, 10)
93 93 self.assertEqual(config.n, True)
94 94 self.assertEqual(config.Global.bam, 'wow')
95 95 config = cl.load_config(['wow'])
96 96 self.assertEqual(config.keys(), ['Global'])
97 97 self.assertEqual(config.Global.keys(), ['bam'])
98 98 self.assertEqual(config.Global.bam, 'wow')
99 99
100 100 def test_add_arguments(self):
101 101 cl = MyLoader2()
102 102 config = cl.load_config('2 frobble'.split())
103 103 self.assertEqual(config.subparser_name, '2')
104 104 self.assertEqual(config.y, 'frobble')
105 105 config = cl.load_config('1 -x frobble'.split())
106 106 self.assertEqual(config.subparser_name, '1')
107 107 self.assertEqual(config.Global.x, 'frobble')
108 108
109 109 def test_argv(self):
110 110 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
111 111 config = cl.load_config()
112 112 self.assertEqual(config.Global.foo, 'hi')
113 113 self.assertEqual(config.MyClass.bar, 10)
114 114 self.assertEqual(config.n, True)
115 115 self.assertEqual(config.Global.bam, 'wow')
116 116
117 117
118 118 class TestKeyValueCL(TestCase):
119 119 klass = KeyValueConfigLoader
120 120
121 121 def test_basic(self):
122 122 cl = self.klass()
123 123 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
124 124 with mute_warn():
125 125 config = cl.load_config(argv)
126 126 self.assertEqual(config.a, 10)
127 127 self.assertEqual(config.b, 20)
128 128 self.assertEqual(config.Foo.Bar.value, 10)
129 129 self.assertEqual(config.Foo.Bam.value, range(10))
130 130 self.assertEqual(config.D.C.value, 'hi there')
131 131
132 132 def test_expanduser(self):
133 133 cl = self.klass()
134 134 argv = ['--a=~/1/2/3', '--b=~', '--c=~/', '--d="~/"']
135 135 with mute_warn():
136 136 config = cl.load_config(argv)
137 137 self.assertEqual(config.a, os.path.expanduser('~/1/2/3'))
138 138 self.assertEqual(config.b, os.path.expanduser('~'))
139 139 self.assertEqual(config.c, os.path.expanduser('~/'))
140 140 self.assertEqual(config.d, '~/')
141 141
142 142 def test_extra_args(self):
143 143 cl = self.klass()
144 144 with mute_warn():
145 145 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
146 146 self.assertEqual(cl.extra_args, ['b', 'd'])
147 147 self.assertEqual(config.a, 5)
148 148 self.assertEqual(config.c, 10)
149 149 with mute_warn():
150 150 config = cl.load_config(['--', '--a=5', '--c=10'])
151 151 self.assertEqual(cl.extra_args, ['--a=5', '--c=10'])
152 152
153 153 def test_unicode_args(self):
154 154 cl = self.klass()
155 155 argv = [u'--a=épsîlön']
156 156 with mute_warn():
157 157 config = cl.load_config(argv)
158 158 self.assertEqual(config.a, u'épsîlön')
159 159
160 160 def test_unicode_bytes_args(self):
161 161 uarg = u'--a=é'
162 162 try:
163 163 barg = uarg.encode(sys.stdin.encoding)
164 164 except (TypeError, UnicodeEncodeError):
165 165 raise SkipTest("sys.stdin.encoding can't handle 'é'")
166 166
167 167 cl = self.klass()
168 168 with mute_warn():
169 169 config = cl.load_config([barg])
170 170 self.assertEqual(config.a, u'é')
171 171
172 172 def test_unicode_alias(self):
173 173 cl = self.klass()
174 174 argv = [u'--a=épsîlön']
175 175 with mute_warn():
176 176 config = cl.load_config(argv, aliases=dict(a='A.a'))
177 177 self.assertEqual(config.A.a, u'épsîlön')
178 178
179 179
180 180 class TestArgParseKVCL(TestKeyValueCL):
181 181 klass = KVArgParseConfigLoader
182 182
183 183 def test_expanduser2(self):
184 184 cl = self.klass()
185 185 argv = ['-a', '~/1/2/3', '--b', "'~/1/2/3'"]
186 186 with mute_warn():
187 187 config = cl.load_config(argv, aliases=dict(a='A.a', b='A.b'))
188 188 self.assertEqual(config.A.a, os.path.expanduser('~/1/2/3'))
189 189 self.assertEqual(config.A.b, '~/1/2/3')
190 190
191 191 def test_eval(self):
192 192 cl = self.klass()
193 193 argv = ['-c', 'a=5']
194 194 with mute_warn():
195 195 config = cl.load_config(argv, aliases=dict(c='A.c'))
196 196 self.assertEqual(config.A.c, u"a=5")
197 197
198 198
199 199 class TestConfig(TestCase):
200 200
201 201 def test_setget(self):
202 202 c = Config()
203 203 c.a = 10
204 204 self.assertEqual(c.a, 10)
205 205 self.assertEqual('b' in c, False)
206 206
207 207 def test_auto_section(self):
208 208 c = Config()
209 209 self.assertEqual('A' in c, True)
210 210 self.assertEqual(c._has_section('A'), False)
211 211 A = c.A
212 212 A.foo = 'hi there'
213 213 self.assertEqual(c._has_section('A'), True)
214 214 self.assertEqual(c.A.foo, 'hi there')
215 215 del c.A
216 216 self.assertEqual(len(c.A.keys()),0)
217 217
218 218 def test_merge_doesnt_exist(self):
219 219 c1 = Config()
220 220 c2 = Config()
221 221 c2.bar = 10
222 222 c2.Foo.bar = 10
223 223 c1.merge(c2)
224 224 self.assertEqual(c1.Foo.bar, 10)
225 225 self.assertEqual(c1.bar, 10)
226 226 c2.Bar.bar = 10
227 227 c1.merge(c2)
228 228 self.assertEqual(c1.Bar.bar, 10)
229 229
230 230 def test_merge_exists(self):
231 231 c1 = Config()
232 232 c2 = Config()
233 233 c1.Foo.bar = 10
234 234 c1.Foo.bam = 30
235 235 c2.Foo.bar = 20
236 236 c2.Foo.wow = 40
237 237 c1.merge(c2)
238 238 self.assertEqual(c1.Foo.bam, 30)
239 239 self.assertEqual(c1.Foo.bar, 20)
240 240 self.assertEqual(c1.Foo.wow, 40)
241 241 c2.Foo.Bam.bam = 10
242 242 c1.merge(c2)
243 243 self.assertEqual(c1.Foo.Bam.bam, 10)
244 244
245 245 def test_deepcopy(self):
246 246 c1 = Config()
247 247 c1.Foo.bar = 10
248 248 c1.Foo.bam = 30
249 249 c1.a = 'asdf'
250 250 c1.b = range(10)
251 251 import copy
252 252 c2 = copy.deepcopy(c1)
253 253 self.assertEqual(c1, c2)
254 254 self.assertTrue(c1 is not c2)
255 255 self.assertTrue(c1.Foo is not c2.Foo)
256 256
257 257 def test_builtin(self):
258 258 c1 = Config()
259 259 exec 'foo = True' in c1
260 260 self.assertEqual(c1.foo, True)
261 261 c1.format = "json"
262 262
263 263 def test_fromdict(self):
264 264 c1 = Config({'Foo' : {'bar' : 1}})
265 265 self.assertEqual(c1.Foo.__class__, Config)
266 266 self.assertEqual(c1.Foo.bar, 1)
267 267
268 268 def test_fromdictmerge(self):
269 269 c1 = Config()
270 270 c2 = Config({'Foo' : {'bar' : 1}})
271 271 c1.merge(c2)
272 272 self.assertEqual(c1.Foo.__class__, Config)
273 273 self.assertEqual(c1.Foo.bar, 1)
274 274
275 275 def test_fromdictmerge2(self):
276 276 c1 = Config({'Foo' : {'baz' : 2}})
277 277 c2 = Config({'Foo' : {'bar' : 1}})
278 278 c1.merge(c2)
279 279 self.assertEqual(c1.Foo.__class__, Config)
280 280 self.assertEqual(c1.Foo.bar, 1)
281 281 self.assertEqual(c1.Foo.baz, 2)
282 self.assertRaises(AttributeError, getattr, c2.Foo, 'baz')
282 self.assertNotIn('baz', c2.Foo)
283
284 def test_contains(self):
285 c1 = Config({'Foo' : {'baz' : 2}})
286 c2 = Config({'Foo' : {'bar' : 1}})
287 self.assertIn('Foo', c1)
288 self.assertIn('Foo.baz', c1)
289 self.assertIn('Foo.bar', c2)
290 self.assertNotIn('Foo.bar', c1)
291
292
@@ -1,383 +1,381 b''
1 1 # encoding: utf-8
2 2 """
3 3 An application for IPython.
4 4
5 5 All top-level applications should use the classes in this module for
6 6 handling configuration and creating configurables.
7 7
8 8 The job of an :class:`Application` is to create the master configuration
9 9 object and then create the configurable objects, passing the config to them.
10 10
11 11 Authors:
12 12
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 * Min RK
16 16
17 17 """
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Copyright (C) 2008 The IPython Development Team
21 21 #
22 22 # Distributed under the terms of the BSD License. The full license is in
23 23 # the file COPYING, distributed as part of this software.
24 24 #-----------------------------------------------------------------------------
25 25
26 26 #-----------------------------------------------------------------------------
27 27 # Imports
28 28 #-----------------------------------------------------------------------------
29 29
30 30 import atexit
31 31 import errno
32 32 import glob
33 33 import logging
34 34 import os
35 35 import shutil
36 36 import sys
37 37
38 38 from IPython.config.application import Application, catch_config_error
39 39 from IPython.config.loader import ConfigFileNotFound
40 40 from IPython.core import release, crashhandler
41 41 from IPython.core.profiledir import ProfileDir, ProfileDirError
42 42 from IPython.utils.path import get_ipython_dir, get_ipython_package_dir
43 43 from IPython.utils.traitlets import List, Unicode, Type, Bool, Dict, Set, Instance
44 44
45 45 #-----------------------------------------------------------------------------
46 46 # Classes and functions
47 47 #-----------------------------------------------------------------------------
48 48
49 49
50 50 #-----------------------------------------------------------------------------
51 51 # Base Application Class
52 52 #-----------------------------------------------------------------------------
53 53
54 54 # aliases and flags
55 55
56 56 base_aliases = {
57 57 'profile-dir' : 'ProfileDir.location',
58 58 'profile' : 'BaseIPythonApplication.profile',
59 59 'ipython-dir' : 'BaseIPythonApplication.ipython_dir',
60 60 'log-level' : 'Application.log_level',
61 61 'config' : 'BaseIPythonApplication.extra_config_file',
62 62 }
63 63
64 64 base_flags = dict(
65 65 debug = ({'Application' : {'log_level' : logging.DEBUG}},
66 66 "set log level to logging.DEBUG (maximize logging output)"),
67 67 quiet = ({'Application' : {'log_level' : logging.CRITICAL}},
68 68 "set log level to logging.CRITICAL (minimize logging output)"),
69 69 init = ({'BaseIPythonApplication' : {
70 70 'copy_config_files' : True,
71 71 'auto_create' : True}
72 72 }, """Initialize profile with default config files. This is equivalent
73 73 to running `ipython profile create <profile>` prior to startup.
74 74 """)
75 75 )
76 76
77 77
78 78 class BaseIPythonApplication(Application):
79 79
80 80 name = Unicode(u'ipython')
81 81 description = Unicode(u'IPython: an enhanced interactive Python shell.')
82 82 version = Unicode(release.version)
83 83
84 84 aliases = Dict(base_aliases)
85 85 flags = Dict(base_flags)
86 86 classes = List([ProfileDir])
87 87
88 88 # Track whether the config_file has changed,
89 89 # because some logic happens only if we aren't using the default.
90 90 config_file_specified = Set()
91 91
92 92 config_file_name = Unicode()
93 93 def _config_file_name_default(self):
94 94 return self.name.replace('-','_') + u'_config.py'
95 95 def _config_file_name_changed(self, name, old, new):
96 96 if new != old:
97 97 self.config_file_specified.add(new)
98 98
99 99 # The directory that contains IPython's builtin profiles.
100 100 builtin_profile_dir = Unicode(
101 101 os.path.join(get_ipython_package_dir(), u'config', u'profile', u'default')
102 102 )
103 103
104 104 config_file_paths = List(Unicode)
105 105 def _config_file_paths_default(self):
106 106 return [os.getcwdu()]
107 107
108 108 extra_config_file = Unicode(config=True,
109 109 help="""Path to an extra config file to load.
110 110
111 111 If specified, load this config file in addition to any other IPython config.
112 112 """)
113 113 def _extra_config_file_changed(self, name, old, new):
114 114 try:
115 115 self.config_files.remove(old)
116 116 except ValueError:
117 117 pass
118 118 self.config_file_specified.add(new)
119 119 self.config_files.append(new)
120 120
121 121 profile = Unicode(u'default', config=True,
122 122 help="""The IPython profile to use."""
123 123 )
124 124
125 125 def _profile_changed(self, name, old, new):
126 126 self.builtin_profile_dir = os.path.join(
127 127 get_ipython_package_dir(), u'config', u'profile', new
128 128 )
129 129
130 130 ipython_dir = Unicode(config=True,
131 131 help="""
132 132 The name of the IPython directory. This directory is used for logging
133 133 configuration (through profiles), history storage, etc. The default
134 134 is usually $HOME/.ipython. This options can also be specified through
135 135 the environment variable IPYTHONDIR.
136 136 """
137 137 )
138 138 def _ipython_dir_default(self):
139 139 d = get_ipython_dir()
140 140 self._ipython_dir_changed('ipython_dir', d, d)
141 141 return d
142 142
143 143 _in_init_profile_dir = False
144 144 profile_dir = Instance(ProfileDir)
145 145 def _profile_dir_default(self):
146 146 # avoid recursion
147 147 if self._in_init_profile_dir:
148 148 return
149 149 # profile_dir requested early, force initialization
150 150 self.init_profile_dir()
151 151 return self.profile_dir
152 152
153 153 overwrite = Bool(False, config=True,
154 154 help="""Whether to overwrite existing config files when copying""")
155 155 auto_create = Bool(False, config=True,
156 156 help="""Whether to create profile dir if it doesn't exist""")
157 157
158 158 config_files = List(Unicode)
159 159 def _config_files_default(self):
160 160 return [self.config_file_name]
161 161
162 162 copy_config_files = Bool(False, config=True,
163 163 help="""Whether to install the default config files into the profile dir.
164 164 If a new profile is being created, and IPython contains config files for that
165 165 profile, then they will be staged into the new directory. Otherwise,
166 166 default config files will be automatically generated.
167 167 """)
168 168
169 169 verbose_crash = Bool(False, config=True,
170 170 help="""Create a massive crash report when IPython encounters what may be an
171 171 internal error. The default is to append a short message to the
172 172 usual traceback""")
173 173
174 174 # The class to use as the crash handler.
175 175 crash_handler_class = Type(crashhandler.CrashHandler)
176 176
177 177 @catch_config_error
178 178 def __init__(self, **kwargs):
179 179 super(BaseIPythonApplication, self).__init__(**kwargs)
180 180 # ensure current working directory exists
181 181 try:
182 182 directory = os.getcwdu()
183 183 except:
184 184 # raise exception
185 185 self.log.error("Current working directory doesn't exist.")
186 186 raise
187 187
188 188 #-------------------------------------------------------------------------
189 189 # Various stages of Application creation
190 190 #-------------------------------------------------------------------------
191 191
192 192 def init_crash_handler(self):
193 193 """Create a crash handler, typically setting sys.excepthook to it."""
194 194 self.crash_handler = self.crash_handler_class(self)
195 195 sys.excepthook = self.excepthook
196 196 def unset_crashhandler():
197 197 sys.excepthook = sys.__excepthook__
198 198 atexit.register(unset_crashhandler)
199 199
200 200 def excepthook(self, etype, evalue, tb):
201 201 """this is sys.excepthook after init_crashhandler
202 202
203 203 set self.verbose_crash=True to use our full crashhandler, instead of
204 204 a regular traceback with a short message (crash_handler_lite)
205 205 """
206 206
207 207 if self.verbose_crash:
208 208 return self.crash_handler(etype, evalue, tb)
209 209 else:
210 210 return crashhandler.crash_handler_lite(etype, evalue, tb)
211 211
212 212 def _ipython_dir_changed(self, name, old, new):
213 213 if old in sys.path:
214 214 sys.path.remove(old)
215 215 sys.path.append(os.path.abspath(new))
216 216 if not os.path.isdir(new):
217 217 os.makedirs(new, mode=0o777)
218 218 readme = os.path.join(new, 'README')
219 219 readme_src = os.path.join(get_ipython_package_dir(), u'config', u'profile', 'README')
220 220 if not os.path.exists(readme) and os.path.exists(readme_src):
221 221 shutil.copy(readme_src, readme)
222 222 for d in ('extensions', 'nbextensions'):
223 223 path = os.path.join(new, d)
224 224 if not os.path.exists(path):
225 225 try:
226 226 os.mkdir(path)
227 227 except OSError as e:
228 228 if e.errno != errno.EEXIST:
229 229 self.log.error("couldn't create path %s: %s", path, e)
230 230 self.log.debug("IPYTHONDIR set to: %s" % new)
231 231
232 232 def load_config_file(self, suppress_errors=True):
233 233 """Load the config file.
234 234
235 235 By default, errors in loading config are handled, and a warning
236 236 printed on screen. For testing, the suppress_errors option is set
237 237 to False, so errors will make tests fail.
238 238 """
239 239 self.log.debug("Searching path %s for config files", self.config_file_paths)
240 240 base_config = 'ipython_config.py'
241 241 self.log.debug("Attempting to load config file: %s" %
242 242 base_config)
243 243 try:
244 244 Application.load_config_file(
245 245 self,
246 246 base_config,
247 247 path=self.config_file_paths
248 248 )
249 249 except ConfigFileNotFound:
250 250 # ignore errors loading parent
251 251 self.log.debug("Config file %s not found", base_config)
252 252 pass
253 253
254 254 for config_file_name in self.config_files:
255 255 if not config_file_name or config_file_name == base_config:
256 256 continue
257 257 self.log.debug("Attempting to load config file: %s" %
258 258 self.config_file_name)
259 259 try:
260 260 Application.load_config_file(
261 261 self,
262 262 config_file_name,
263 263 path=self.config_file_paths
264 264 )
265 265 except ConfigFileNotFound:
266 266 # Only warn if the default config file was NOT being used.
267 267 if config_file_name in self.config_file_specified:
268 268 msg = self.log.warn
269 269 else:
270 270 msg = self.log.debug
271 271 msg("Config file not found, skipping: %s", config_file_name)
272 272 except:
273 273 # For testing purposes.
274 274 if not suppress_errors:
275 275 raise
276 276 self.log.warn("Error loading config file: %s" %
277 277 self.config_file_name, exc_info=True)
278 278
279 279 def init_profile_dir(self):
280 280 """initialize the profile dir"""
281 281 self._in_init_profile_dir = True
282 282 if self.profile_dir is not None:
283 283 # already ran
284 284 return
285 try:
286 # location explicitly specified:
287 location = self.config.ProfileDir.location
288 except AttributeError:
285 if 'ProfileDir.location' not in self.config:
289 286 # location not specified, find by profile name
290 287 try:
291 288 p = ProfileDir.find_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
292 289 except ProfileDirError:
293 290 # not found, maybe create it (always create default profile)
294 291 if self.auto_create or self.profile == 'default':
295 292 try:
296 293 p = ProfileDir.create_profile_dir_by_name(self.ipython_dir, self.profile, self.config)
297 294 except ProfileDirError:
298 295 self.log.fatal("Could not create profile: %r"%self.profile)
299 296 self.exit(1)
300 297 else:
301 298 self.log.info("Created profile dir: %r"%p.location)
302 299 else:
303 300 self.log.fatal("Profile %r not found."%self.profile)
304 301 self.exit(1)
305 302 else:
306 303 self.log.info("Using existing profile dir: %r"%p.location)
307 304 else:
305 location = self.config.ProfileDir.location
308 306 # location is fully specified
309 307 try:
310 308 p = ProfileDir.find_profile_dir(location, self.config)
311 309 except ProfileDirError:
312 310 # not found, maybe create it
313 311 if self.auto_create:
314 312 try:
315 313 p = ProfileDir.create_profile_dir(location, self.config)
316 314 except ProfileDirError:
317 315 self.log.fatal("Could not create profile directory: %r"%location)
318 316 self.exit(1)
319 317 else:
320 318 self.log.info("Creating new profile dir: %r"%location)
321 319 else:
322 320 self.log.fatal("Profile directory %r not found."%location)
323 321 self.exit(1)
324 322 else:
325 323 self.log.info("Using existing profile dir: %r"%location)
326 324
327 325 self.profile_dir = p
328 326 self.config_file_paths.append(p.location)
329 327 self._in_init_profile_dir = False
330 328
331 329 def init_config_files(self):
332 330 """[optionally] copy default config files into profile dir."""
333 331 # copy config files
334 332 path = self.builtin_profile_dir
335 333 if self.copy_config_files:
336 334 src = self.profile
337 335
338 336 cfg = self.config_file_name
339 337 if path and os.path.exists(os.path.join(path, cfg)):
340 338 self.log.warn("Staging %r from %s into %r [overwrite=%s]"%(
341 339 cfg, src, self.profile_dir.location, self.overwrite)
342 340 )
343 341 self.profile_dir.copy_config_file(cfg, path=path, overwrite=self.overwrite)
344 342 else:
345 343 self.stage_default_config_file()
346 344 else:
347 345 # Still stage *bundled* config files, but not generated ones
348 346 # This is necessary for `ipython profile=sympy` to load the profile
349 347 # on the first go
350 348 files = glob.glob(os.path.join(path, '*.py'))
351 349 for fullpath in files:
352 350 cfg = os.path.basename(fullpath)
353 351 if self.profile_dir.copy_config_file(cfg, path=path, overwrite=False):
354 352 # file was copied
355 353 self.log.warn("Staging bundled %s from %s into %r"%(
356 354 cfg, self.profile, self.profile_dir.location)
357 355 )
358 356
359 357
360 358 def stage_default_config_file(self):
361 359 """auto generate default config file, and stage it into the profile."""
362 360 s = self.generate_config_file()
363 361 fname = os.path.join(self.profile_dir.location, self.config_file_name)
364 362 if self.overwrite or not os.path.exists(fname):
365 363 self.log.warn("Generating default config file: %r"%(fname))
366 364 with open(fname, 'w') as f:
367 365 f.write(s)
368 366
369 367 @catch_config_error
370 368 def initialize(self, argv=None):
371 369 # don't hook up crash handler before parsing command-line
372 370 self.parse_command_line(argv)
373 371 self.init_crash_handler()
374 372 if self.subapp is not None:
375 373 # stop here if subapp is taking over
376 374 return
377 375 cl_config = self.config
378 376 self.init_profile_dir()
379 377 self.init_config_files()
380 378 self.load_config_file()
381 379 # enforce cl-opts override configfile opts:
382 380 self.update_config(cl_config)
383 381
@@ -1,547 +1,547 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython controller application.
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 from __future__ import with_statement
25 25
26 26 import json
27 27 import os
28 28 import stat
29 29 import sys
30 30
31 31 from multiprocessing import Process
32 32 from signal import signal, SIGINT, SIGABRT, SIGTERM
33 33
34 34 import zmq
35 35 from zmq.devices import ProcessMonitoredQueue
36 36 from zmq.log.handlers import PUBHandler
37 37
38 38 from IPython.core.profiledir import ProfileDir
39 39
40 40 from IPython.parallel.apps.baseapp import (
41 41 BaseParallelApplication,
42 42 base_aliases,
43 43 base_flags,
44 44 catch_config_error,
45 45 )
46 46 from IPython.utils.importstring import import_item
47 47 from IPython.utils.localinterfaces import localhost, public_ips
48 48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
49 49
50 50 from IPython.kernel.zmq.session import (
51 51 Session, session_aliases, session_flags, default_secure
52 52 )
53 53
54 54 from IPython.parallel.controller.heartmonitor import HeartMonitor
55 55 from IPython.parallel.controller.hub import HubFactory
56 56 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
57 57 from IPython.parallel.controller.dictdb import DictDB
58 58
59 59 from IPython.parallel.util import split_url, disambiguate_url, set_hwm
60 60
61 61 # conditional import of SQLiteDB / MongoDB backend class
62 62 real_dbs = []
63 63
64 64 try:
65 65 from IPython.parallel.controller.sqlitedb import SQLiteDB
66 66 except ImportError:
67 67 pass
68 68 else:
69 69 real_dbs.append(SQLiteDB)
70 70
71 71 try:
72 72 from IPython.parallel.controller.mongodb import MongoDB
73 73 except ImportError:
74 74 pass
75 75 else:
76 76 real_dbs.append(MongoDB)
77 77
78 78
79 79
80 80 #-----------------------------------------------------------------------------
81 81 # Module level variables
82 82 #-----------------------------------------------------------------------------
83 83
84 84
85 85 _description = """Start the IPython controller for parallel computing.
86 86
87 87 The IPython controller provides a gateway between the IPython engines and
88 88 clients. The controller needs to be started before the engines and can be
89 89 configured using command line options or using a cluster directory. Cluster
90 90 directories contain config, log and security files and are usually located in
91 91 your ipython directory and named as "profile_name". See the `profile`
92 92 and `profile-dir` options for details.
93 93 """
94 94
95 95 _examples = """
96 96 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
97 97 ipcontroller --scheme=pure # use the pure zeromq scheduler
98 98 """
99 99
100 100
101 101 #-----------------------------------------------------------------------------
102 102 # The main application
103 103 #-----------------------------------------------------------------------------
104 104 flags = {}
105 105 flags.update(base_flags)
106 106 flags.update({
107 107 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
108 108 'Use threads instead of processes for the schedulers'),
109 109 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
110 110 'use the SQLiteDB backend'),
111 111 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
112 112 'use the MongoDB backend'),
113 113 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
114 114 'use the in-memory DictDB backend'),
115 115 'nodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.NoDB'}},
116 116 """use dummy DB backend, which doesn't store any information.
117 117
118 118 This is the default as of IPython 0.13.
119 119
120 120 To enable delayed or repeated retrieval of results from the Hub,
121 121 select one of the true db backends.
122 122 """),
123 123 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
124 124 'reuse existing json connection files'),
125 125 'restore' : ({'IPControllerApp' : {'restore_engines' : True, 'reuse_files' : True}},
126 126 'Attempt to restore engines from a JSON file. '
127 127 'For use when resuming a crashed controller'),
128 128 })
129 129
130 130 flags.update(session_flags)
131 131
132 132 aliases = dict(
133 133 ssh = 'IPControllerApp.ssh_server',
134 134 enginessh = 'IPControllerApp.engine_ssh_server',
135 135 location = 'IPControllerApp.location',
136 136
137 137 url = 'HubFactory.url',
138 138 ip = 'HubFactory.ip',
139 139 transport = 'HubFactory.transport',
140 140 port = 'HubFactory.regport',
141 141
142 142 ping = 'HeartMonitor.period',
143 143
144 144 scheme = 'TaskScheduler.scheme_name',
145 145 hwm = 'TaskScheduler.hwm',
146 146 )
147 147 aliases.update(base_aliases)
148 148 aliases.update(session_aliases)
149 149
150 150 class IPControllerApp(BaseParallelApplication):
151 151
152 152 name = u'ipcontroller'
153 153 description = _description
154 154 examples = _examples
155 155 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, DictDB] + real_dbs
156 156
157 157 # change default to True
158 158 auto_create = Bool(True, config=True,
159 159 help="""Whether to create profile dir if it doesn't exist.""")
160 160
161 161 reuse_files = Bool(False, config=True,
162 162 help="""Whether to reuse existing json connection files.
163 163 If False, connection files will be removed on a clean exit.
164 164 """
165 165 )
166 166 restore_engines = Bool(False, config=True,
167 167 help="""Reload engine state from JSON file
168 168 """
169 169 )
170 170 ssh_server = Unicode(u'', config=True,
171 171 help="""ssh url for clients to use when connecting to the Controller
172 172 processes. It should be of the form: [user@]server[:port]. The
173 173 Controller's listening addresses must be accessible from the ssh server""",
174 174 )
175 175 engine_ssh_server = Unicode(u'', config=True,
176 176 help="""ssh url for engines to use when connecting to the Controller
177 177 processes. It should be of the form: [user@]server[:port]. The
178 178 Controller's listening addresses must be accessible from the ssh server""",
179 179 )
180 180 location = Unicode(u'', config=True,
181 181 help="""The external IP or domain name of the Controller, used for disambiguating
182 182 engine and client connections.""",
183 183 )
184 184 import_statements = List([], config=True,
185 185 help="import statements to be run at startup. Necessary in some environments"
186 186 )
187 187
188 188 use_threads = Bool(False, config=True,
189 189 help='Use threads instead of processes for the schedulers',
190 190 )
191 191
192 192 engine_json_file = Unicode('ipcontroller-engine.json', config=True,
193 193 help="JSON filename where engine connection info will be stored.")
194 194 client_json_file = Unicode('ipcontroller-client.json', config=True,
195 195 help="JSON filename where client connection info will be stored.")
196 196
197 197 def _cluster_id_changed(self, name, old, new):
198 198 super(IPControllerApp, self)._cluster_id_changed(name, old, new)
199 199 self.engine_json_file = "%s-engine.json" % self.name
200 200 self.client_json_file = "%s-client.json" % self.name
201 201
202 202
203 203 # internal
204 204 children = List()
205 205 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
206 206
207 207 def _use_threads_changed(self, name, old, new):
208 208 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
209 209
210 210 write_connection_files = Bool(True,
211 211 help="""Whether to write connection files to disk.
212 212 True in all cases other than runs with `reuse_files=True` *after the first*
213 213 """
214 214 )
215 215
216 216 aliases = Dict(aliases)
217 217 flags = Dict(flags)
218 218
219 219
220 220 def save_connection_dict(self, fname, cdict):
221 221 """save a connection dict to json file."""
222 222 c = self.config
223 223 url = cdict['registration']
224 224 location = cdict['location']
225 225
226 226 if not location:
227 227 if public_ips():
228 228 location = public_ips()[-1]
229 229 else:
230 230 self.log.warn("Could not identify this machine's IP, assuming %s."
231 231 " You may need to specify '--location=<external_ip_address>' to help"
232 232 " IPython decide when to connect via loopback." % localhost() )
233 233 location = localhost()
234 234 cdict['location'] = location
235 235 fname = os.path.join(self.profile_dir.security_dir, fname)
236 236 self.log.info("writing connection info to %s", fname)
237 237 with open(fname, 'w') as f:
238 238 f.write(json.dumps(cdict, indent=2))
239 239 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
240 240
241 241 def load_config_from_json(self):
242 242 """load config from existing json connector files."""
243 243 c = self.config
244 244 self.log.debug("loading config from JSON")
245 245
246 246 # load engine config
247 247
248 248 fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
249 249 self.log.info("loading connection info from %s", fname)
250 250 with open(fname) as f:
251 251 ecfg = json.loads(f.read())
252 252
253 253 # json gives unicode, Session.key wants bytes
254 254 c.Session.key = ecfg['key'].encode('ascii')
255 255
256 256 xport,ip = ecfg['interface'].split('://')
257 257
258 258 c.HubFactory.engine_ip = ip
259 259 c.HubFactory.engine_transport = xport
260 260
261 261 self.location = ecfg['location']
262 262 if not self.engine_ssh_server:
263 263 self.engine_ssh_server = ecfg['ssh']
264 264
265 265 # load client config
266 266
267 267 fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
268 268 self.log.info("loading connection info from %s", fname)
269 269 with open(fname) as f:
270 270 ccfg = json.loads(f.read())
271 271
272 272 for key in ('key', 'registration', 'pack', 'unpack', 'signature_scheme'):
273 273 assert ccfg[key] == ecfg[key], "mismatch between engine and client info: %r" % key
274 274
275 275 xport,addr = ccfg['interface'].split('://')
276 276
277 277 c.HubFactory.client_transport = xport
278 278 c.HubFactory.client_ip = ip
279 279 if not self.ssh_server:
280 280 self.ssh_server = ccfg['ssh']
281 281
282 282 # load port config:
283 283 c.HubFactory.regport = ecfg['registration']
284 284 c.HubFactory.hb = (ecfg['hb_ping'], ecfg['hb_pong'])
285 285 c.HubFactory.control = (ccfg['control'], ecfg['control'])
286 286 c.HubFactory.mux = (ccfg['mux'], ecfg['mux'])
287 287 c.HubFactory.task = (ccfg['task'], ecfg['task'])
288 288 c.HubFactory.iopub = (ccfg['iopub'], ecfg['iopub'])
289 289 c.HubFactory.notifier_port = ccfg['notification']
290 290
291 291 def cleanup_connection_files(self):
292 292 if self.reuse_files:
293 293 self.log.debug("leaving JSON connection files for reuse")
294 294 return
295 295 self.log.debug("cleaning up JSON connection files")
296 296 for f in (self.client_json_file, self.engine_json_file):
297 297 f = os.path.join(self.profile_dir.security_dir, f)
298 298 try:
299 299 os.remove(f)
300 300 except Exception as e:
301 301 self.log.error("Failed to cleanup connection file: %s", e)
302 302 else:
303 303 self.log.debug(u"removed %s", f)
304 304
305 305 def load_secondary_config(self):
306 306 """secondary config, loading from JSON and setting defaults"""
307 307 if self.reuse_files:
308 308 try:
309 309 self.load_config_from_json()
310 310 except (AssertionError,IOError) as e:
311 311 self.log.error("Could not load config from JSON: %s" % e)
312 312 else:
313 313 # successfully loaded config from JSON, and reuse=True
314 314 # no need to wite back the same file
315 315 self.write_connection_files = False
316 316
317 317 # switch Session.key default to secure
318 318 default_secure(self.config)
319 319 self.log.debug("Config changed")
320 320 self.log.debug(repr(self.config))
321 321
322 322 def init_hub(self):
323 323 c = self.config
324 324
325 325 self.do_import_statements()
326 326
327 327 try:
328 328 self.factory = HubFactory(config=c, log=self.log)
329 329 # self.start_logging()
330 330 self.factory.init_hub()
331 331 except TraitError:
332 332 raise
333 333 except Exception:
334 334 self.log.error("Couldn't construct the Controller", exc_info=True)
335 335 self.exit(1)
336 336
337 337 if self.write_connection_files:
338 338 # save to new json config files
339 339 f = self.factory
340 340 base = {
341 341 'key' : f.session.key.decode('ascii'),
342 342 'location' : self.location,
343 343 'pack' : f.session.packer,
344 344 'unpack' : f.session.unpacker,
345 345 'signature_scheme' : f.session.signature_scheme,
346 346 }
347 347
348 348 cdict = {'ssh' : self.ssh_server}
349 349 cdict.update(f.client_info)
350 350 cdict.update(base)
351 351 self.save_connection_dict(self.client_json_file, cdict)
352 352
353 353 edict = {'ssh' : self.engine_ssh_server}
354 354 edict.update(f.engine_info)
355 355 edict.update(base)
356 356 self.save_connection_dict(self.engine_json_file, edict)
357 357
358 358 fname = "engines%s.json" % self.cluster_id
359 359 self.factory.hub.engine_state_file = os.path.join(self.profile_dir.log_dir, fname)
360 360 if self.restore_engines:
361 361 self.factory.hub._load_engine_state()
362 362
363 363 def init_schedulers(self):
364 364 children = self.children
365 365 mq = import_item(str(self.mq_class))
366 366
367 367 f = self.factory
368 368 ident = f.session.bsession
369 369 # disambiguate url, in case of *
370 370 monitor_url = disambiguate_url(f.monitor_url)
371 371 # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
372 372 # IOPub relay (in a Process)
373 373 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
374 374 q.bind_in(f.client_url('iopub'))
375 375 q.setsockopt_in(zmq.IDENTITY, ident + b"_iopub")
376 376 q.bind_out(f.engine_url('iopub'))
377 377 q.setsockopt_out(zmq.SUBSCRIBE, b'')
378 378 q.connect_mon(monitor_url)
379 379 q.daemon=True
380 380 children.append(q)
381 381
382 382 # Multiplexer Queue (in a Process)
383 383 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
384 384
385 385 q.bind_in(f.client_url('mux'))
386 386 q.setsockopt_in(zmq.IDENTITY, b'mux_in')
387 387 q.bind_out(f.engine_url('mux'))
388 388 q.setsockopt_out(zmq.IDENTITY, b'mux_out')
389 389 q.connect_mon(monitor_url)
390 390 q.daemon=True
391 391 children.append(q)
392 392
393 393 # Control Queue (in a Process)
394 394 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
395 395 q.bind_in(f.client_url('control'))
396 396 q.setsockopt_in(zmq.IDENTITY, b'control_in')
397 397 q.bind_out(f.engine_url('control'))
398 398 q.setsockopt_out(zmq.IDENTITY, b'control_out')
399 399 q.connect_mon(monitor_url)
400 400 q.daemon=True
401 401 children.append(q)
402 try:
402 if 'TaskScheduler.scheme_name' in self.config:
403 403 scheme = self.config.TaskScheduler.scheme_name
404 except AttributeError:
404 else:
405 405 scheme = TaskScheduler.scheme_name.get_default_value()
406 406 # Task Queue (in a Process)
407 407 if scheme == 'pure':
408 408 self.log.warn("task::using pure DEALER Task scheduler")
409 409 q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
410 410 # q.setsockopt_out(zmq.HWM, hub.hwm)
411 411 q.bind_in(f.client_url('task'))
412 412 q.setsockopt_in(zmq.IDENTITY, b'task_in')
413 413 q.bind_out(f.engine_url('task'))
414 414 q.setsockopt_out(zmq.IDENTITY, b'task_out')
415 415 q.connect_mon(monitor_url)
416 416 q.daemon=True
417 417 children.append(q)
418 418 elif scheme == 'none':
419 419 self.log.warn("task::using no Task scheduler")
420 420
421 421 else:
422 422 self.log.info("task::using Python %s Task scheduler"%scheme)
423 423 sargs = (f.client_url('task'), f.engine_url('task'),
424 424 monitor_url, disambiguate_url(f.client_url('notification')),
425 425 disambiguate_url(f.client_url('registration')),
426 426 )
427 427 kwargs = dict(logname='scheduler', loglevel=self.log_level,
428 428 log_url = self.log_url, config=dict(self.config))
429 429 if 'Process' in self.mq_class:
430 430 # run the Python scheduler in a Process
431 431 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
432 432 q.daemon=True
433 433 children.append(q)
434 434 else:
435 435 # single-threaded Controller
436 436 kwargs['in_thread'] = True
437 437 launch_scheduler(*sargs, **kwargs)
438 438
439 439 # set unlimited HWM for all relay devices
440 440 if hasattr(zmq, 'SNDHWM'):
441 441 q = children[0]
442 442 q.setsockopt_in(zmq.RCVHWM, 0)
443 443 q.setsockopt_out(zmq.SNDHWM, 0)
444 444
445 445 for q in children[1:]:
446 446 if not hasattr(q, 'setsockopt_in'):
447 447 continue
448 448 q.setsockopt_in(zmq.SNDHWM, 0)
449 449 q.setsockopt_in(zmq.RCVHWM, 0)
450 450 q.setsockopt_out(zmq.SNDHWM, 0)
451 451 q.setsockopt_out(zmq.RCVHWM, 0)
452 452 q.setsockopt_mon(zmq.SNDHWM, 0)
453 453
454 454
455 455 def terminate_children(self):
456 456 child_procs = []
457 457 for child in self.children:
458 458 if isinstance(child, ProcessMonitoredQueue):
459 459 child_procs.append(child.launcher)
460 460 elif isinstance(child, Process):
461 461 child_procs.append(child)
462 462 if child_procs:
463 463 self.log.critical("terminating children...")
464 464 for child in child_procs:
465 465 try:
466 466 child.terminate()
467 467 except OSError:
468 468 # already dead
469 469 pass
470 470
471 471 def handle_signal(self, sig, frame):
472 472 self.log.critical("Received signal %i, shutting down", sig)
473 473 self.terminate_children()
474 474 self.loop.stop()
475 475
476 476 def init_signal(self):
477 477 for sig in (SIGINT, SIGABRT, SIGTERM):
478 478 signal(sig, self.handle_signal)
479 479
480 480 def do_import_statements(self):
481 481 statements = self.import_statements
482 482 for s in statements:
483 483 try:
484 484 self.log.msg("Executing statement: '%s'" % s)
485 485 exec s in globals(), locals()
486 486 except:
487 487 self.log.msg("Error running statement: %s" % s)
488 488
489 489 def forward_logging(self):
490 490 if self.log_url:
491 491 self.log.info("Forwarding logging to %s"%self.log_url)
492 492 context = zmq.Context.instance()
493 493 lsock = context.socket(zmq.PUB)
494 494 lsock.connect(self.log_url)
495 495 handler = PUBHandler(lsock)
496 496 handler.root_topic = 'controller'
497 497 handler.setLevel(self.log_level)
498 498 self.log.addHandler(handler)
499 499
500 500 @catch_config_error
501 501 def initialize(self, argv=None):
502 502 super(IPControllerApp, self).initialize(argv)
503 503 self.forward_logging()
504 504 self.load_secondary_config()
505 505 self.init_hub()
506 506 self.init_schedulers()
507 507
508 508 def start(self):
509 509 # Start the subprocesses:
510 510 self.factory.start()
511 511 # children must be started before signals are setup,
512 512 # otherwise signal-handling will fire multiple times
513 513 for child in self.children:
514 514 child.start()
515 515 self.init_signal()
516 516
517 517 self.write_pid_file(overwrite=True)
518 518
519 519 try:
520 520 self.factory.loop.start()
521 521 except KeyboardInterrupt:
522 522 self.log.critical("Interrupted, Exiting...\n")
523 523 finally:
524 524 self.cleanup_connection_files()
525 525
526 526
527 527 def launch_new_instance(*args, **kwargs):
528 528 """Create and run the IPython controller"""
529 529 if sys.platform == 'win32':
530 530 # make sure we don't get called from a multiprocessing subprocess
531 531 # this can result in infinite Controllers being started on Windows
532 532 # which doesn't have a proper fork, so multiprocessing is wonky
533 533
534 534 # this only comes up when IPython has been installed using vanilla
535 535 # setuptools, and *not* distribute.
536 536 import multiprocessing
537 537 p = multiprocessing.current_process()
538 538 # the main process has name 'MainProcess'
539 539 # subprocesses will have names like 'Process-1'
540 540 if p.name != 'MainProcess':
541 541 # we are a subprocess, don't start another Controller!
542 542 return
543 543 return IPControllerApp.launch_instance(*args, **kwargs)
544 544
545 545
546 546 if __name__ == '__main__':
547 547 launch_new_instance()
@@ -1,393 +1,384 b''
1 1 #!/usr/bin/env python
2 2 # encoding: utf-8
3 3 """
4 4 The IPython engine application
5 5
6 6 Authors:
7 7
8 8 * Brian Granger
9 9 * MinRK
10 10
11 11 """
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Copyright (C) 2008-2011 The IPython Development Team
15 15 #
16 16 # Distributed under the terms of the BSD License. The full license is in
17 17 # the file COPYING, distributed as part of this software.
18 18 #-----------------------------------------------------------------------------
19 19
20 20 #-----------------------------------------------------------------------------
21 21 # Imports
22 22 #-----------------------------------------------------------------------------
23 23
24 24 import json
25 25 import os
26 26 import sys
27 27 import time
28 28
29 29 import zmq
30 30 from zmq.eventloop import ioloop
31 31
32 32 from IPython.core.profiledir import ProfileDir
33 33 from IPython.parallel.apps.baseapp import (
34 34 BaseParallelApplication,
35 35 base_aliases,
36 36 base_flags,
37 37 catch_config_error,
38 38 )
39 39 from IPython.kernel.zmq.log import EnginePUBHandler
40 40 from IPython.kernel.zmq.ipkernel import Kernel
41 41 from IPython.kernel.zmq.kernelapp import IPKernelApp
42 42 from IPython.kernel.zmq.session import (
43 43 Session, session_aliases, session_flags
44 44 )
45 45 from IPython.kernel.zmq.zmqshell import ZMQInteractiveShell
46 46
47 47 from IPython.config.configurable import Configurable
48 48
49 49 from IPython.parallel.engine.engine import EngineFactory
50 50 from IPython.parallel.util import disambiguate_ip_address
51 51
52 52 from IPython.utils.importstring import import_item
53 53 from IPython.utils.py3compat import cast_bytes
54 54 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float, Instance
55 55
56 56
57 57 #-----------------------------------------------------------------------------
58 58 # Module level variables
59 59 #-----------------------------------------------------------------------------
60 60
61 61 _description = """Start an IPython engine for parallel computing.
62 62
63 63 IPython engines run in parallel and perform computations on behalf of a client
64 64 and controller. A controller needs to be started before the engines. The
65 65 engine can be configured using command line options or using a cluster
66 66 directory. Cluster directories contain config, log and security files and are
67 67 usually located in your ipython directory and named as "profile_name".
68 68 See the `profile` and `profile-dir` options for details.
69 69 """
70 70
71 71 _examples = """
72 72 ipengine --ip=192.168.0.1 --port=1000 # connect to hub at ip and port
73 73 ipengine --log-to-file --log-level=DEBUG # log to a file with DEBUG verbosity
74 74 """
75 75
76 76 #-----------------------------------------------------------------------------
77 77 # MPI configuration
78 78 #-----------------------------------------------------------------------------
79 79
80 80 mpi4py_init = """from mpi4py import MPI as mpi
81 81 mpi.size = mpi.COMM_WORLD.Get_size()
82 82 mpi.rank = mpi.COMM_WORLD.Get_rank()
83 83 """
84 84
85 85
86 86 pytrilinos_init = """from PyTrilinos import Epetra
87 87 class SimpleStruct:
88 88 pass
89 89 mpi = SimpleStruct()
90 90 mpi.rank = 0
91 91 mpi.size = 0
92 92 """
93 93
class MPI(Configurable):
    """Configurable for MPI initialization.

    Selecting a known flavor via ``use`` fills in ``init_script`` (unless
    one was already configured), which the engine later exec()s.
    """
    # which MPI flavor to enable; setting it triggers _use_changed below
    use = Unicode('', config=True,
        help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
        )

    def _use_changed(self, name, old, new):
        # load default init script if it's not set
        if not self.init_script:
            self.init_script = self.default_inits.get(new, '')

    # code executed at engine startup to initialize MPI
    init_script = Unicode('', config=True,
        help="Initialization code for MPI")

    # mapping of known flavor name -> default init script
    default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
        config=True)
110 110
111 111
112 112 #-----------------------------------------------------------------------------
113 113 # Main application
114 114 #-----------------------------------------------------------------------------
# Command-line aliases: short option name -> fully-qualified trait name.
aliases = dict(
    # IPEngineApp options
    file = 'IPEngineApp.url_file',
    c = 'IPEngineApp.startup_command',
    s = 'IPEngineApp.startup_script',

    # EngineFactory / connection options
    url = 'EngineFactory.url',
    ssh = 'EngineFactory.sshserver',
    sshkey = 'EngineFactory.sshkey',
    ip = 'EngineFactory.ip',
    transport = 'EngineFactory.transport',
    port = 'EngineFactory.regport',
    location = 'EngineFactory.location',

    timeout = 'EngineFactory.timeout',

    mpi = 'MPI.use',

)
# merge in the aliases/flags shared by all parallel apps and by Session
aliases.update(base_aliases)
aliases.update(session_aliases)
flags = {}
flags.update(base_flags)
flags.update(session_flags)
139 139 class IPEngineApp(BaseParallelApplication):
140 140
141 141 name = 'ipengine'
142 142 description = _description
143 143 examples = _examples
144 144 classes = List([ZMQInteractiveShell, ProfileDir, Session, EngineFactory, Kernel, MPI])
145 145
146 146 startup_script = Unicode(u'', config=True,
147 147 help='specify a script to be run at startup')
148 148 startup_command = Unicode('', config=True,
149 149 help='specify a command to be run at startup')
150 150
151 151 url_file = Unicode(u'', config=True,
152 152 help="""The full location of the file containing the connection information for
153 153 the controller. If this is not given, the file must be in the
154 154 security directory of the cluster directory. This location is
155 155 resolved using the `profile` or `profile_dir` options.""",
156 156 )
157 157 wait_for_url_file = Float(5, config=True,
158 158 help="""The maximum number of seconds to wait for url_file to exist.
159 159 This is useful for batch-systems and shared-filesystems where the
160 160 controller and engine are started at the same time and it
161 161 may take a moment for the controller to write the connector files.""")
162 162
163 163 url_file_name = Unicode(u'ipcontroller-engine.json', config=True)
164 164
165 165 def _cluster_id_changed(self, name, old, new):
166 166 if new:
167 167 base = 'ipcontroller-%s' % new
168 168 else:
169 169 base = 'ipcontroller'
170 170 self.url_file_name = "%s-engine.json" % base
171 171
172 172 log_url = Unicode('', config=True,
173 173 help="""The URL for the iploggerapp instance, for forwarding
174 174 logging to a central location.""")
175 175
176 176 # an IPKernelApp instance, used to setup listening for shell frontends
177 177 kernel_app = Instance(IPKernelApp)
178 178
179 179 aliases = Dict(aliases)
180 180 flags = Dict(flags)
181 181
    @property
    def kernel(self):
        """allow access to the Kernel object, so I look like IPKernelApp

        Delegates to the engine's kernel; only valid after init_engine().
        """
        return self.engine.kernel
186 186
187 187 def find_url_file(self):
188 188 """Set the url file.
189 189
190 190 Here we don't try to actually see if it exists for is valid as that
191 191 is hadled by the connection logic.
192 192 """
193 193 config = self.config
194 194 # Find the actual controller key file
195 195 if not self.url_file:
196 196 self.url_file = os.path.join(
197 197 self.profile_dir.security_dir,
198 198 self.url_file_name
199 199 )
200 200
    def load_connector_file(self):
        """load config from a JSON connector file,
        at a *lower* priority than command-line/config files.

        Reads the controller's connection JSON (self.url_file) and fills in
        Session/EngineFactory config.  Values already set by the user win
        for location/sshserver; URLs, serialization, and the key always come
        from the file.
        """

        self.log.info("Loading url_file %r", self.url_file)
        config = self.config

        with open(self.url_file) as f:
            d = json.loads(f.read())

        # allow hand-override of location for disambiguation
        # and ssh-server
        if 'EngineFactory.location' not in config:
            config.EngineFactory.location = d['location']
        if 'EngineFactory.sshserver' not in config:
            config.EngineFactory.sshserver = d.get('ssh')

        location = config.EngineFactory.location

        # rewrite the interface IP relative to our own location
        # (e.g. replace a remote IP with loopback when we're on the same host)
        proto, ip = d['interface'].split('://')
        ip = disambiguate_ip_address(ip, location)
        d['interface'] = '%s://%s' % (proto, ip)

        # DO NOT allow override of basic URLs, serialization, or key
        # JSON file takes top priority there
        config.Session.key = cast_bytes(d['key'])
        config.Session.signature_scheme = d['signature_scheme']

        config.EngineFactory.url = d['interface'] + ':%i' % d['registration']

        config.Session.packer = d['pack']
        config.Session.unpacker = d['unpack']

        self.log.debug("Config changed:")
        self.log.debug("%r", config)
        # keep the full dict around for EngineFactory
        self.connection_info = d
243 238
    def bind_kernel(self, **kwargs):
        """Promote engine to listening kernel, accessible to frontends.

        Builds an IPKernelApp around the engine's existing kernel and binds
        shell/iopub/stdin ports so regular IPython frontends can connect.
        Idempotent: does nothing if a kernel_app already exists.
        """
        if self.kernel_app is not None:
            return

        self.log.info("Opening ports for direct connections as an IPython kernel")

        kernel = self.kernel

        # reuse our own config/log/profile/session unless the caller overrides
        kwargs.setdefault('config', self.config)
        kwargs.setdefault('log', self.log)
        kwargs.setdefault('profile_dir', self.profile_dir)
        kwargs.setdefault('session', self.engine.session)

        app = self.kernel_app = IPKernelApp(**kwargs)

        # allow IPKernelApp.instance():
        IPKernelApp._instance = app

        app.init_connection_file()
        # relevant contents of init_sockets:

        app.shell_port = app._bind_socket(kernel.shell_streams[0], app.shell_port)
        app.log.debug("shell ROUTER Channel on port: %i", app.shell_port)

        app.iopub_port = app._bind_socket(kernel.iopub_socket, app.iopub_port)
        app.log.debug("iopub PUB Channel on port: %i", app.iopub_port)

        # the engine has no stdin socket of its own, so create one here
        kernel.stdin_socket = self.engine.context.socket(zmq.ROUTER)
        app.stdin_port = app._bind_socket(kernel.stdin_socket, app.stdin_port)
        app.log.debug("stdin ROUTER Channel on port: %i", app.stdin_port)

        # start the heartbeat, and log connection info:

        app.init_heartbeat()

        app.log_connection_info()
        app.write_connection_file()
282 277
283 278
284 279 def init_engine(self):
285 280 # This is the working dir by now.
286 281 sys.path.insert(0, '')
287 282 config = self.config
288 283 # print config
289 284 self.find_url_file()
290 285
291 286 # was the url manually specified?
292 287 keys = set(self.config.EngineFactory.keys())
293 288 keys = keys.union(set(self.config.RegistrationFactory.keys()))
294 289
295 290 if keys.intersection(set(['ip', 'url', 'port'])):
296 291 # Connection info was specified, don't wait for the file
297 292 url_specified = True
298 293 self.wait_for_url_file = 0
299 294 else:
300 295 url_specified = False
301 296
302 297 if self.wait_for_url_file and not os.path.exists(self.url_file):
303 298 self.log.warn("url_file %r not found", self.url_file)
304 299 self.log.warn("Waiting up to %.1f seconds for it to arrive.", self.wait_for_url_file)
305 300 tic = time.time()
306 301 while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
307 302 # wait for url_file to exist, or until time limit
308 303 time.sleep(0.1)
309 304
310 305 if os.path.exists(self.url_file):
311 306 self.load_connector_file()
312 307 elif not url_specified:
313 308 self.log.fatal("Fatal: url file never arrived: %s", self.url_file)
314 309 self.exit(1)
315 310
311 exec_lines = []
312 for app in ('IPKernelApp', 'InteractiveShellApp'):
313 if '%s.exec_lines' in config:
314 exec_lines = config.IPKernelApp.exec_lines = config[app].exec_lines
315 break
316 316
317 try:
318 exec_lines = config.IPKernelApp.exec_lines
319 except AttributeError:
320 try:
321 exec_lines = config.InteractiveShellApp.exec_lines
322 except AttributeError:
323 exec_lines = config.IPKernelApp.exec_lines = []
324 try:
325 exec_files = config.IPKernelApp.exec_files
326 except AttributeError:
327 try:
328 exec_files = config.InteractiveShellApp.exec_files
329 except AttributeError:
330 exec_files = config.IPKernelApp.exec_files = []
317 exec_files = []
318 for app in ('IPKernelApp', 'InteractiveShellApp'):
319 if '%s.exec_files' in config:
320 exec_files = config.IPKernelApp.exec_files = config[app].exec_files
321 break
331 322
332 323 if self.startup_script:
333 324 exec_files.append(self.startup_script)
334 325 if self.startup_command:
335 326 exec_lines.append(self.startup_command)
336 327
337 328 # Create the underlying shell class and Engine
338 329 # shell_class = import_item(self.master_config.Global.shell_class)
339 330 # print self.config
340 331 try:
341 332 self.engine = EngineFactory(config=config, log=self.log,
342 333 connection_info=self.connection_info,
343 334 )
344 335 except:
345 336 self.log.error("Couldn't start the Engine", exc_info=True)
346 337 self.exit(1)
347 338
    def forward_logging(self):
        """Forward this engine's log records over zmq PUB to log_url, if set."""
        if self.log_url:
            self.log.info("Forwarding logging to %s", self.log_url)
            # reuse the engine's zmq context for the logging socket
            context = self.engine.context
            lsock = context.socket(zmq.PUB)
            lsock.connect(self.log_url)
            handler = EnginePUBHandler(self.engine, lsock)
            handler.setLevel(self.log_level)
            self.log.addHandler(handler)
357 348
    def init_mpi(self):
        """Run the configured MPI init script, binding `mpi` at module level.

        Best-effort: any failure of the init script leaves the module-level
        `mpi` as None rather than aborting engine startup.
        """
        global mpi
        self.mpi = MPI(parent=self)

        mpi_import_statement = self.mpi.init_script
        if mpi_import_statement:
            try:
                self.log.info("Initializing MPI:")
                self.log.info(mpi_import_statement)
                # Python 2 exec statement; the script is expected to define
                # `mpi` in this module's globals
                exec mpi_import_statement in globals()
            except:
                # deliberately broad: MPI init failure must not kill the engine
                mpi = None
        else:
            mpi = None
372 363
    @catch_config_error
    def initialize(self, argv=None):
        """Parse config/argv and set up MPI, the engine, and log forwarding.

        MPI must be initialized before the engine in case the init script
        is needed during engine construction.
        """
        super(IPEngineApp, self).initialize(argv)
        self.init_mpi()
        self.init_engine()
        self.forward_logging()
379 370
    def start(self):
        """Start the engine and block in its event loop until interrupted."""
        self.engine.start()
        try:
            self.engine.loop.start()
        except KeyboardInterrupt:
            self.log.critical("Engine Interrupted, shutting down...\n")
386 377
387 378
388 379 launch_new_instance = IPEngineApp.launch_instance
389 380
390 381
391 382 if __name__ == '__main__':
392 383 launch_new_instance()
393 384
@@ -1,1422 +1,1421 b''
1 1 """The IPython Controller Hub with 0MQ
2 2 This is the master object that handles connections from engines and clients,
3 3 and monitors traffic through the various queues.
4 4
5 5 Authors:
6 6
7 7 * Min RK
8 8 """
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2010-2011 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19 from __future__ import print_function
20 20
21 21 import json
22 22 import os
23 23 import sys
24 24 import time
25 25 from datetime import datetime
26 26
27 27 import zmq
28 28 from zmq.eventloop import ioloop
29 29 from zmq.eventloop.zmqstream import ZMQStream
30 30
31 31 # internal:
32 32 from IPython.utils.importstring import import_item
33 33 from IPython.utils.localinterfaces import localhost
34 34 from IPython.utils.py3compat import cast_bytes
35 35 from IPython.utils.traitlets import (
36 36 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
37 37 )
38 38
39 39 from IPython.parallel import error, util
40 40 from IPython.parallel.factory import RegistrationFactory
41 41
42 42 from IPython.kernel.zmq.session import SessionFactory
43 43
44 44 from .heartmonitor import HeartMonitor
45 45
46 46 #-----------------------------------------------------------------------------
47 47 # Code
48 48 #-----------------------------------------------------------------------------
49 49
50 50 def _passer(*args, **kwargs):
51 51 return
52 52
53 53 def _printer(*args, **kwargs):
54 54 print (args)
55 55 print (kwargs)
56 56
def empty_record():
    """Return an empty dict with all record keys.

    Every field is None except stdout/stderr, which accumulate text and so
    start as empty strings.  Key order matches the TaskRecord layout.
    """
    record = {key: None for key in (
        'msg_id', 'header', 'metadata', 'content', 'buffers', 'submitted',
        'client_uuid', 'engine_uuid', 'started', 'completed', 'resubmitted',
        'received', 'result_header', 'result_metadata', 'result_content',
        'result_buffers', 'queue', 'pyin', 'pyout', 'pyerr',
    )}
    record['stdout'] = ''
    record['stderr'] = ''
    return record
83 83
def init_record(msg):
    """Initialize a TaskRecord based on a request.

    Copies the submission-time fields out of `msg`; the remaining lifecycle
    fields start as None (stdout/stderr as empty strings) and are filled in
    as the task progresses.
    """
    header = msg['header']
    record = {
        'msg_id' : header['msg_id'],
        'header' : header,
        'content': msg['content'],
        'metadata': msg['metadata'],
        'buffers': msg['buffers'],
        'submitted': header['date'],
    }
    # lifecycle fields, populated later
    for pending in ('client_uuid', 'engine_uuid', 'started', 'completed',
                    'resubmitted', 'received', 'result_header',
                    'result_metadata', 'result_content', 'result_buffers',
                    'queue', 'pyin', 'pyout', 'pyerr'):
        record[pending] = None
    record['stdout'] = ''
    record['stderr'] = ''
    return record
111 111
112 112
class EngineConnector(HasTraits):
    """A simple object for accessing the various zmq connections of an object.
    Attributes are:
    id (int): engine ID
    uuid (unicode): engine UUID
    pending: set of msg_ids
    stallback: DelayedCallback for stalled registration
    """

    # integer engine ID assigned by the Hub
    id = Integer(0)
    # engine UUID, as reported at registration
    uuid = Unicode()
    # msg_ids currently outstanding on this engine
    pending = Set()
    # timer that fires if registration stalls
    stallback = Instance(ioloop.DelayedCallback)
126 126
127 127
# short names accepted for HubFactory.db_class, mapped to full import paths
_db_shortcuts = {
    'sqlitedb' : 'IPython.parallel.controller.sqlitedb.SQLiteDB',
    'mongodb' : 'IPython.parallel.controller.mongodb.MongoDB',
    'dictdb' : 'IPython.parallel.controller.dictdb.DictDB',
    'nodb' : 'IPython.parallel.controller.dictdb.NoDB',
}
134 134
135 135 class HubFactory(RegistrationFactory):
136 136 """The Configurable for setting up a Hub."""
137 137
138 138 # port-pairs for monitoredqueues:
139 139 hb = Tuple(Integer,Integer,config=True,
140 140 help="""PUB/ROUTER Port pair for Engine heartbeats""")
141 141 def _hb_default(self):
142 142 return tuple(util.select_random_ports(2))
143 143
144 144 mux = Tuple(Integer,Integer,config=True,
145 145 help="""Client/Engine Port pair for MUX queue""")
146 146
147 147 def _mux_default(self):
148 148 return tuple(util.select_random_ports(2))
149 149
150 150 task = Tuple(Integer,Integer,config=True,
151 151 help="""Client/Engine Port pair for Task queue""")
152 152 def _task_default(self):
153 153 return tuple(util.select_random_ports(2))
154 154
155 155 control = Tuple(Integer,Integer,config=True,
156 156 help="""Client/Engine Port pair for Control queue""")
157 157
158 158 def _control_default(self):
159 159 return tuple(util.select_random_ports(2))
160 160
161 161 iopub = Tuple(Integer,Integer,config=True,
162 162 help="""Client/Engine Port pair for IOPub relay""")
163 163
164 164 def _iopub_default(self):
165 165 return tuple(util.select_random_ports(2))
166 166
167 167 # single ports:
168 168 mon_port = Integer(config=True,
169 169 help="""Monitor (SUB) port for queue traffic""")
170 170
171 171 def _mon_port_default(self):
172 172 return util.select_random_ports(1)[0]
173 173
174 174 notifier_port = Integer(config=True,
175 175 help="""PUB port for sending engine status notifications""")
176 176
177 177 def _notifier_port_default(self):
178 178 return util.select_random_ports(1)[0]
179 179
180 180 engine_ip = Unicode(config=True,
181 181 help="IP on which to listen for engine connections. [default: loopback]")
182 182 def _engine_ip_default(self):
183 183 return localhost()
184 184 engine_transport = Unicode('tcp', config=True,
185 185 help="0MQ transport for engine connections. [default: tcp]")
186 186
187 187 client_ip = Unicode(config=True,
188 188 help="IP on which to listen for client connections. [default: loopback]")
189 189 client_transport = Unicode('tcp', config=True,
190 190 help="0MQ transport for client connections. [default : tcp]")
191 191
192 192 monitor_ip = Unicode(config=True,
193 193 help="IP on which to listen for monitor messages. [default: loopback]")
194 194 monitor_transport = Unicode('tcp', config=True,
195 195 help="0MQ transport for monitor messages. [default : tcp]")
196 196
197 197 _client_ip_default = _monitor_ip_default = _engine_ip_default
198 198
199 199
200 200 monitor_url = Unicode('')
201 201
202 202 db_class = DottedObjectName('NoDB',
203 203 config=True, help="""The class to use for the DB backend
204 204
205 205 Options include:
206 206
207 207 SQLiteDB: SQLite
208 208 MongoDB : use MongoDB
209 209 DictDB : in-memory storage (fastest, but be mindful of memory growth of the Hub)
210 210 NoDB : disable database altogether (default)
211 211
212 212 """)
213 213
214 214 # not configurable
215 215 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
216 216 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
217 217
218 218 def _ip_changed(self, name, old, new):
219 219 self.engine_ip = new
220 220 self.client_ip = new
221 221 self.monitor_ip = new
222 222 self._update_monitor_url()
223 223
224 224 def _update_monitor_url(self):
225 225 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
226 226
227 227 def _transport_changed(self, name, old, new):
228 228 self.engine_transport = new
229 229 self.client_transport = new
230 230 self.monitor_transport = new
231 231 self._update_monitor_url()
232 232
233 233 def __init__(self, **kwargs):
234 234 super(HubFactory, self).__init__(**kwargs)
235 235 self._update_monitor_url()
236 236
237 237
238 238 def construct(self):
239 239 self.init_hub()
240 240
241 241 def start(self):
242 242 self.heartmonitor.start()
243 243 self.log.info("Heartmonitor started")
244 244
245 245 def client_url(self, channel):
246 246 """return full zmq url for a named client channel"""
247 247 return "%s://%s:%i" % (self.client_transport, self.client_ip, self.client_info[channel])
248 248
249 249 def engine_url(self, channel):
250 250 """return full zmq url for a named engine channel"""
251 251 return "%s://%s:%i" % (self.engine_transport, self.engine_ip, self.engine_info[channel])
252 252
    def init_hub(self):
        """construct Hub object

        Builds the connection-info dicts for engines and clients, binds all
        of the Hub-side zmq sockets (registration, heartbeat, notification,
        monitor, resubmit), connects the DB backend, and finally constructs
        the Hub itself.
        """

        ctx = self.context
        loop = self.loop
        # use the configured task scheme if any, else the TaskScheduler default
        if 'TaskScheduler.scheme_name' in self.config:
            scheme = self.config.TaskScheduler.scheme_name
        else:
            from .scheduler import TaskScheduler
            scheme = TaskScheduler.scheme_name.get_default_value()

        # build connection dicts
        # NOTE(review): the locals `engine` and `client` are never used below
        # (only the self.* attributes are); they look vestigial
        engine = self.engine_info = {
            'interface' : "%s://%s" % (self.engine_transport, self.engine_ip),
            'registration' : self.regport,
            'control' : self.control[1],
            'mux' : self.mux[1],
            'hb_ping' : self.hb[0],
            'hb_pong' : self.hb[1],
            'task' : self.task[1],
            'iopub' : self.iopub[1],
            }

        client = self.client_info = {
            'interface' : "%s://%s" % (self.client_transport, self.client_ip),
            'registration' : self.regport,
            'control' : self.control[0],
            'mux' : self.mux[0],
            'task' : self.task[0],
            'task_scheme' : scheme,
            'iopub' : self.iopub[0],
            'notification' : self.notifier_port,
            }

        self.log.debug("Hub engine addrs: %s", self.engine_info)
        self.log.debug("Hub client addrs: %s", self.client_info)

        # Registrar socket
        q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
        util.set_hwm(q, 0)
        q.bind(self.client_url('registration'))
        self.log.info("Hub listening on %s for registration.", self.client_url('registration'))
        # bind a second time when engines connect on a different IP
        if self.client_ip != self.engine_ip:
            q.bind(self.engine_url('registration'))
            self.log.info("Hub listening on %s for registration.", self.engine_url('registration'))

        ### Engine connections ###

        # heartbeat
        hpub = ctx.socket(zmq.PUB)
        hpub.bind(self.engine_url('hb_ping'))
        hrep = ctx.socket(zmq.ROUTER)
        util.set_hwm(hrep, 0)
        hrep.bind(self.engine_url('hb_pong'))
        self.heartmonitor = HeartMonitor(loop=loop, parent=self, log=self.log,
                                pingstream=ZMQStream(hpub,loop),
                                pongstream=ZMQStream(hrep,loop)
                            )

        ### Client connections ###

        # Notifier socket
        n = ZMQStream(ctx.socket(zmq.PUB), loop)
        n.bind(self.client_url('notification'))

        ### build and launch the queues ###

        # monitor socket: subscribe to everything
        sub = ctx.socket(zmq.SUB)
        sub.setsockopt(zmq.SUBSCRIBE, b"")
        sub.bind(self.monitor_url)
        sub.bind('inproc://monitor')
        sub = ZMQStream(sub, loop)

        # connect the db
        db_class = _db_shortcuts.get(self.db_class.lower(), self.db_class)
        self.log.info('Hub using DB backend: %r', (db_class.split('.')[-1]))
        self.db = import_item(str(db_class))(session=self.session.session,
                                            parent=self, log=self.log)
        # NOTE(review): presumably gives the DB backend time to come up;
        # the 0.25s value is unexplained
        time.sleep(.25)

        # resubmit stream
        r = ZMQStream(ctx.socket(zmq.DEALER), loop)
        url = util.disambiguate_url(self.client_url('task'))
        r.connect(url)

        self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
                query=q, notifier=n, resubmit=r, db=self.db,
                engine_info=self.engine_info, client_info=self.client_info,
                log=self.log)
344 343
345 344
346 345 class Hub(SessionFactory):
347 346 """The IPython Controller Hub with 0MQ connections
348 347
349 348 Parameters
350 349 ==========
351 350 loop: zmq IOLoop instance
352 351 session: Session object
353 352 <removed> context: zmq context for creating new connections (?)
354 353 queue: ZMQStream for monitoring the command queue (SUB)
355 354 query: ZMQStream for engine registration and client queries requests (ROUTER)
356 355 heartbeat: HeartMonitor object checking the pulse of the engines
357 356 notifier: ZMQStream for broadcasting engine registration changes (PUB)
358 357 db: connection to db for out of memory logging of commands
359 358 NotImplemented
360 359 engine_info: dict of zmq connection information for engines to connect
361 360 to the queues.
362 361 client_info: dict of zmq connection information for engines to connect
363 362 to the queues.
364 363 """
365 364
366 365 engine_state_file = Unicode()
367 366
368 367 # internal data structures:
369 368 ids=Set() # engine IDs
370 369 keytable=Dict()
371 370 by_ident=Dict()
372 371 engines=Dict()
373 372 clients=Dict()
374 373 hearts=Dict()
375 374 pending=Set()
376 375 queues=Dict() # pending msg_ids keyed by engine_id
377 376 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
378 377 completed=Dict() # completed msg_ids keyed by engine_id
379 378 all_completed=Set() # completed msg_ids keyed by engine_id
380 379 dead_engines=Set() # completed msg_ids keyed by engine_id
381 380 unassigned=Set() # set of task msg_ds not yet assigned a destination
382 381 incoming_registrations=Dict()
383 382 registration_timeout=Integer()
384 383 _idcounter=Integer(0)
385 384
386 385 # objects from constructor:
387 386 query=Instance(ZMQStream)
388 387 monitor=Instance(ZMQStream)
389 388 notifier=Instance(ZMQStream)
390 389 resubmit=Instance(ZMQStream)
391 390 heartmonitor=Instance(HeartMonitor)
392 391 db=Instance(object)
393 392 client_info=Dict()
394 393 engine_info=Dict()
395 394
396 395
397 396 def __init__(self, **kwargs):
398 397 """
399 398 # universal:
400 399 loop: IOLoop for creating future connections
401 400 session: streamsession for sending serialized data
402 401 # engine:
403 402 queue: ZMQStream for monitoring queue messages
404 403 query: ZMQStream for engine+client registration and client requests
405 404 heartbeat: HeartMonitor object for tracking engines
406 405 # extra:
407 406 db: ZMQStream for db connection (NotImplemented)
408 407 engine_info: zmq address/protocol dict for engine connections
409 408 client_info: zmq address/protocol dict for client connections
410 409 """
411 410
412 411 super(Hub, self).__init__(**kwargs)
413 412 self.registration_timeout = max(10000, 5*self.heartmonitor.period)
414 413
415 414 # register our callbacks
416 415 self.query.on_recv(self.dispatch_query)
417 416 self.monitor.on_recv(self.dispatch_monitor_traffic)
418 417
419 418 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
420 419 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
421 420
422 421 self.monitor_handlers = {b'in' : self.save_queue_request,
423 422 b'out': self.save_queue_result,
424 423 b'intask': self.save_task_request,
425 424 b'outtask': self.save_task_result,
426 425 b'tracktask': self.save_task_destination,
427 426 b'incontrol': _passer,
428 427 b'outcontrol': _passer,
429 428 b'iopub': self.save_iopub_message,
430 429 }
431 430
432 431 self.query_handlers = {'queue_request': self.queue_status,
433 432 'result_request': self.get_results,
434 433 'history_request': self.get_history,
435 434 'db_request': self.db_query,
436 435 'purge_request': self.purge_results,
437 436 'load_request': self.check_load,
438 437 'resubmit_request': self.resubmit_task,
439 438 'shutdown_request': self.shutdown_request,
440 439 'registration_request' : self.register_engine,
441 440 'unregistration_request' : self.unregister_engine,
442 441 'connection_request': self.connection_request,
443 442 }
444 443
445 444 # ignore resubmit replies
446 445 self.resubmit.on_recv(lambda msg: None, copy=False)
447 446
448 447 self.log.info("hub::created hub")
449 448
    @property
    def _next_id(self):
        """generate a new engine ID.

        No longer reuse old ids, just count from 0."""
        newid = self._idcounter
        self._idcounter += 1
        return newid
464 463
465 464 #-----------------------------------------------------------------------------
466 465 # message validation
467 466 #-----------------------------------------------------------------------------
468 467
    def _validate_targets(self, targets):
        """turn any valid targets argument into a list of integer ids

        Accepts None (meaning all engines), a single id/identity, or a list
        of them.  Raw string identities are mapped to integer ids via
        by_ident.  Raises IndexError for unknown engines or when no engines
        are registered.
        """
        if targets is None:
            # default to all
            return self.ids

        if isinstance(targets, (int,str,unicode)):
            # only one target specified
            targets = [targets]
        _targets = []
        for t in targets:
            # map raw identities to ids
            if isinstance(t, (str,unicode)):
                t = self.by_ident.get(cast_bytes(t), t)
            _targets.append(t)
        targets = _targets
        bad_targets = [ t for t in targets if t not in self.ids ]
        if bad_targets:
            raise IndexError("No Such Engine: %r" % bad_targets)
        if not targets:
            raise IndexError("No Engines Registered")
        return targets
491 490
492 491 #-----------------------------------------------------------------------------
493 492 # dispatch methods (1 per stream)
494 493 #-----------------------------------------------------------------------------
495 494
496 495
    @util.log_errors
    def dispatch_monitor_traffic(self, msg):
        """all ME and Task queue messages come through here, as well as
        IOPub traffic.

        The first frame of ``msg`` is the monitor topic, which selects
        the handler from ``self.monitor_handlers``; the remaining frames
        are split into routing identities and the message proper.
        """
        self.log.debug("monitor traffic: %r", msg[0])
        switch = msg[0]
        try:
            # separate routing identities from the serialized message frames
            idents, msg = self.session.feed_identities(msg[1:])
        except ValueError:
            idents=[]
        if not idents:
            self.log.error("Monitor message without topic: %r", msg)
            return
        # dispatch on the topic frame; unknown topics are logged and dropped
        handler = self.monitor_handlers.get(switch, None)
        if handler is not None:
            handler(idents, msg)
        else:
            self.log.error("Unrecognized monitor topic: %r", switch)
515 514
516 515
    @util.log_errors
    def dispatch_query(self, msg):
        """Route registration requests and queries from clients.

        Malformed messages and unknown msg_types are answered with a
        ``hub_error`` reply containing the wrapped exception.
        """
        try:
            idents, msg = self.session.feed_identities(msg)
        except ValueError:
            idents = []
        if not idents:
            self.log.error("Bad Query Message: %r", msg)
            return
        # first identity frame is the requesting client
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            content = error.wrap_exception()
            self.log.error("Bad Query Message: %r", msg, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return
        # print client_id, header, parent, content
        #switch on message type:
        msg_type = msg['header']['msg_type']
        self.log.info("client::client %r requested %r", client_id, msg_type)
        handler = self.query_handlers.get(msg_type, None)
        try:
            # the assert converts a missing handler into an exception we
            # can wrap and return to the client
            assert handler is not None, "Bad Message Type: %r" % msg_type
        except:
            content = error.wrap_exception()
            self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
            self.session.send(self.query, "hub_error", ident=client_id,
                    content=content)
            return

        else:
            handler(idents, msg)
552 551
553 552 def dispatch_db(self, msg):
554 553 """"""
555 554 raise NotImplementedError
556 555
557 556 #---------------------------------------------------------------------------
558 557 # handler methods (1 per event)
559 558 #---------------------------------------------------------------------------
560 559
561 560 #----------------------- Heartbeat --------------------------------------
562 561
563 562 def handle_new_heart(self, heart):
564 563 """handler to attach to heartbeater.
565 564 Called when a new heart starts to beat.
566 565 Triggers completion of registration."""
567 566 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
568 567 if heart not in self.incoming_registrations:
569 568 self.log.info("heartbeat::ignoring new heart: %r", heart)
570 569 else:
571 570 self.finish_registration(heart)
572 571
573 572
574 573 def handle_heart_failure(self, heart):
575 574 """handler to attach to heartbeater.
576 575 called when a previously registered heart fails to respond to beat request.
577 576 triggers unregistration"""
578 577 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
579 578 eid = self.hearts.get(heart, None)
580 579 uuid = self.engines[eid].uuid
581 580 if eid is None or self.keytable[eid] in self.dead_engines:
582 581 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
583 582 else:
584 583 self.unregister_engine(heart, dict(content=dict(id=eid, queue=uuid)))
585 584
586 585 #----------------------- MUX Queue Traffic ------------------------------
587 586
    def save_queue_request(self, idents, msg):
        """Record the submission of a MUX-queue request in the task DB.

        Expects at least two identity frames: the target queue (engine)
        identity followed by the submitting client identity.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered", queue_id)
            self.log.debug("queue:: valid are: %r", self.by_ident.keys())
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = msg['header']['session']
        record['queue'] = 'mux'

        try:
            # it's posible iopub arrived first:
            existing = self.db.get_record(msg_id)
            # merge with any existing record, preferring existing values
            # and warning on conflicts
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no prior record: create one
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
634 633
    def save_queue_result(self, idents, msg):
        """Record the reply to a MUX-queue request, updating its task record."""
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r",
                    queue_id, client_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # a reply with no parent header cannot be matched to a request
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            # normal completion: move from pending to completed bookkeeping
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
            self.log.info("queue::request %r completed on %s", msg_id, eid)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r", msg_id)
            return
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        md = msg['metadata']
        completed = rheader['date']
        started = md.get('started', None)
        result = {
            'result_header' : rheader,
            'result_metadata': md,
            'result_content': msg['content'],
            'received': datetime.now(),
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
687 686
688 687
689 688 #--------------------- Task Queue Traffic ------------------------------
690 689
    def save_task_request(self, idents, msg):
        """Save the submission of a task.

        Adds the message to ``pending`` and ``unassigned`` (no engine has
        picked it up yet) and creates or merges its DB record.
        """
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r",
                    client_id, msg, exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = msg['header']['session']
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's posible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                # still check content,header which should not change
                # but are not expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no prior record: create one
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
741 740
    def save_task_result(self, idents, msg):
        """save the result of a completed task.

        Moves the task out of ``pending``/``unassigned`` bookkeeping and
        stores the reply (header, metadata, content, buffers) in the DB.
        """
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::invalid task result message send to %r: %r",
                    client_id, msg, exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            # print msg
            self.log.warn("Task %r had no parent!", msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        md = msg['metadata']
        engine_uuid = md.get('engine', u'')
        eid = self.by_ident.get(cast_bytes(engine_uuid), None)

        status = md.get('status', None)

        if msg_id in self.pending:
            self.log.info("task::task %r finished on %s", msg_id, eid)
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                # aborted tasks don't count toward the engine's completed list
                if status != 'aborted':
                    self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = md.get('started', None)
            result = {
                'result_header' : header,
                'result_metadata': msg['metadata'],
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'received' : datetime.now(),
                'engine_uuid': engine_uuid,
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r", msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished", msg_id)
797 796
    def save_task_destination(self, idents, msg):
        """Record which engine a scheduled task was assigned to."""
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("task::invalid task tracking message", exc_info=True)
            return
        content = msg['content']
        # print (content)
        msg_id = content['msg_id']
        engine_uuid = content['engine_id']
        eid = self.by_ident[cast_bytes(engine_uuid)]

        self.log.info("task::task %r arrived on %r", msg_id, eid)
        # the task now has a destination, so it is no longer unassigned
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)
        # else:
        #     self.log.debug("task::task %r not listed as MIA?!"%(msg_id))

        self.tasks[eid].append(msg_id)
        # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
        try:
            self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
        except Exception:
            self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
822 821
823 822
824 823 def mia_task_request(self, idents, msg):
825 824 raise NotImplementedError
826 825 client_id = idents[0]
827 826 # content = dict(mia=self.mia,status='ok')
828 827 # self.session.send('mia_reply', content=content, idents=client_id)
829 828
830 829
831 830 #--------------------- IOPub Traffic ------------------------------
832 831
    def save_iopub_message(self, topics, msg):
        """save an iopub message into the db

        Streams are appended to any existing content; other message types
        overwrite their corresponding record field.
        """
        # print (topics)
        try:
            msg = self.session.unserialize(msg, content=True)
        except Exception:
            self.log.error("iopub::invalid IOPub message", exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.warn("iopub::IOPub message lacks parent: %r", msg)
            return
        msg_id = parent['msg_id']
        msg_type = msg['header']['msg_type']
        content = msg['content']

        # ensure msg_id is in db
        try:
            rec = self.db.get_record(msg_id)
        except KeyError:
            rec = empty_record()
            rec['msg_id'] = msg_id
            self.db.add_record(msg_id, rec)
        # stream
        d = {}
        if msg_type == 'stream':
            # append to whatever output has already been recorded
            name = content['name']
            s = rec[name] or ''
            d[name] = s + content['data']

        elif msg_type == 'pyerr':
            d['pyerr'] = content
        elif msg_type == 'pyin':
            d['pyin'] = content['code']
        elif msg_type in ('display_data', 'pyout'):
            d[msg_type] = content
        elif msg_type == 'status':
            # status messages carry no persistent state
            pass
        elif msg_type == 'data_pub':
            self.log.info("ignored data_pub message for %s" % msg_id)
        else:
            self.log.warn("unhandled iopub msg_type: %r", msg_type)

        if not d:
            return

        try:
            self.db.update_record(msg_id, d)
        except Exception:
            self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
884 883
885 884
886 885
887 886 #-------------------------------------------------------------------------
888 887 # Registration requests
889 888 #-------------------------------------------------------------------------
890 889
    def connection_request(self, client_id, msg):
        """Reply with connection addresses for clients."""
        self.log.info("client::client %r connected", client_id)
        content = dict(status='ok')
        # map engine id -> uuid for live engines only;
        # keys are stringified so the dict is JSON-serializable
        jsonable = {}
        for k,v in self.keytable.iteritems():
            if v not in self.dead_engines:
                jsonable[str(k)] = v
        content['engines'] = jsonable
        self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
901 900
    def register_engine(self, reg, msg):
        """Register a new engine.

        Validates that the requested uuid is not already in use (by a
        registered engine or a pending registration), replies with a
        ``registration_reply``, and on success either finishes
        registration immediately (heart already beating) or arms a
        timeout to purge a registration whose heart never starts.

        Returns the new engine id.
        """
        content = msg['content']
        try:
            uuid = content['uuid']
        except KeyError:
            self.log.error("registration::queue not specified", exc_info=True)
            return

        eid = self._next_id

        self.log.debug("registration::register_engine(%i, %r)", eid, uuid)

        content = dict(id=eid,status='ok',hb_period=self.heartmonitor.period)
        # check if requesting available IDs:
        if cast_bytes(uuid) in self.by_ident:
            try:
                raise KeyError("uuid %r in use" % uuid)
            except:
                content = error.wrap_exception()
                self.log.error("uuid %r in use", uuid, exc_info=True)
        else:
            # also reject uuids that collide with pending registrations
            for h, ec in self.incoming_registrations.iteritems():
                if uuid == h:
                    try:
                        raise KeyError("heart_id %r in use" % uuid)
                    except:
                        self.log.error("heart_id %r in use", uuid, exc_info=True)
                        content = error.wrap_exception()
                    break
                elif uuid == ec.uuid:
                    try:
                        raise KeyError("uuid %r in use" % uuid)
                    except:
                        self.log.error("uuid %r in use", uuid, exc_info=True)
                        content = error.wrap_exception()
                    break

        msg = self.session.send(self.query, "registration_reply",
                content=content,
                ident=reg)

        heart = cast_bytes(uuid)

        if content['status'] == 'ok':
            if heart in self.heartmonitor.hearts:
                # already beating
                self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid)
                self.finish_registration(heart)
            else:
                # purge the pending registration if no beat arrives in time
                purge = lambda : self._purge_stalled_registration(heart)
                dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
                dc.start()
                self.incoming_registrations[heart] = EngineConnector(id=eid,uuid=uuid,stallback=dc)
        else:
            self.log.error("registration::registration %i failed: %r", eid, content['evalue'])

        return eid
960 959
    def unregister_engine(self, ident, msg):
        """Unregister an engine that explicitly requested to leave.

        The engine is only marked dead (added to ``dead_engines``), not
        removed from the tables; stranded messages are handled after
        ``registration_timeout`` via a delayed callback.
        """
        try:
            eid = msg['content']['id']
        except:
            self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
            return
        self.log.info("registration::unregister_engine(%r)", eid)
        # print (eid)
        uuid = self.keytable[eid]
        content=dict(id=eid, uuid=uuid)
        self.dead_engines.add(uuid)
        # self.ids.remove(eid)
        # uuid = self.keytable.pop(eid)
        #
        # ec = self.engines.pop(eid)
        # self.hearts.pop(ec.heartbeat)
        # self.by_ident.pop(ec.queue)
        # self.completed.pop(eid)
        handleit = lambda : self._handle_stranded_msgs(eid, uuid)
        dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
        dc.start()
        ############## TODO: HANDLE IT ################

        self._save_engine_state()

        if self.notifier:
            self.session.send(self.notifier, "unregistration_notification", content=content)
989 988
    def _handle_stranded_msgs(self, eid, uuid):
        """Handle messages known to be on an engine when the engine unregisters.

        It is possible that this will fire prematurely - that is, an engine will
        go down after completing a result, and the client will be notified
        that the result failed and later receive the actual result.
        """

        outstanding = self.queues[eid]

        for msg_id in outstanding:
            # NOTE(review): assumes every outstanding msg_id is still in
            # self.pending; pending.remove would raise KeyError otherwise.
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            # record an EngineError as the task's result
            try:
                raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
            except:
                content = error.wrap_exception()
            # build a fake header:
            header = {}
            header['engine'] = uuid
            header['date'] = datetime.now()
            rec = dict(result_content=content, result_header=header, result_buffers=[])
            rec['completed'] = header['date']
            rec['engine_uuid'] = uuid
            try:
                self.db.update_record(msg_id, rec)
            except Exception:
                self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
1018 1017
1019 1018
    def finish_registration(self, heart):
        """Second half of engine registration, called after our HeartMonitor
        has received a beat from the Engine's Heart.

        Moves the EngineConnector from ``incoming_registrations`` into the
        live-engine tables and notifies clients.
        """
        try:
            ec = self.incoming_registrations.pop(heart)
        except KeyError:
            self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
            return
        self.log.info("registration::finished registering engine %i:%s", ec.id, ec.uuid)
        # cancel the stalled-registration timeout, if one was armed
        if ec.stallback is not None:
            ec.stallback.stop()
        eid = ec.id
        self.ids.add(eid)
        self.keytable[eid] = ec.uuid
        self.engines[eid] = ec
        self.by_ident[cast_bytes(ec.uuid)] = ec.id
        self.queues[eid] = list()
        self.tasks[eid] = list()
        self.completed[eid] = list()
        self.hearts[heart] = eid
        content = dict(id=eid, uuid=self.engines[eid].uuid)
        if self.notifier:
            self.session.send(self.notifier, "registration_notification", content=content)
        self.log.info("engine::Engine Connected: %i", eid)

        self._save_engine_state()
1046 1045
    def _purge_stalled_registration(self, heart):
        """Drop a pending registration whose heart never started beating."""
        if heart in self.incoming_registrations:
            ec = self.incoming_registrations.pop(heart)
            self.log.info("registration::purging stalled registration: %i", ec.id)
        else:
            # registration already completed or was purged; nothing to do
            pass
1053 1052
1054 1053 #-------------------------------------------------------------------------
1055 1054 # Engine State
1056 1055 #-------------------------------------------------------------------------
1057 1056
1058 1057
    def _cleanup_engine_state_file(self):
        """cleanup engine state mapping

        Removes the JSON engine-state file, if present; removal failures
        are logged but not raised.
        """

        if os.path.exists(self.engine_state_file):
            self.log.debug("cleaning up engine state: %s", self.engine_state_file)
            try:
                os.remove(self.engine_state_file)
            except IOError:
                self.log.error("Couldn't cleanup file: %s", self.engine_state_file, exc_info=True)
1068 1067
1069 1068
    def _save_engine_state(self):
        """save engine mapping to JSON file

        Persists the live (non-dead) engine id->uuid map and the next
        engine id, so the hub can restore engines after a restart.
        """
        if not self.engine_state_file:
            # persistence disabled
            return
        self.log.debug("save engine state to %s" % self.engine_state_file)
        state = {}
        engines = {}
        for eid, ec in self.engines.iteritems():
            if ec.uuid not in self.dead_engines:
                engines[eid] = ec.uuid

        state['engines'] = engines

        state['next_id'] = self._idcounter

        with open(self.engine_state_file, 'w') as f:
            json.dump(state, f)
1087 1086
1088 1087
    def _load_engine_state(self):
        """load engine mapping from JSON file

        Re-registers each persisted engine as if its heart were already
        beating, with client notifications suppressed during the replay.
        """
        if not os.path.exists(self.engine_state_file):
            return

        self.log.info("loading engine state from %s" % self.engine_state_file)

        with open(self.engine_state_file) as f:
            state = json.load(f)

        # suppress registration notifications while replaying
        save_notifier = self.notifier
        self.notifier = None
        for eid, uuid in state['engines'].iteritems():
            heart = uuid.encode('ascii')
            # start with this heart as current and beating:
            self.heartmonitor.responses.add(heart)
            self.heartmonitor.hearts.add(heart)

            self.incoming_registrations[heart] = EngineConnector(id=int(eid), uuid=uuid)
            self.finish_registration(heart)

        self.notifier = save_notifier

        self._idcounter = state['next_id']
1113 1112
1114 1113 #-------------------------------------------------------------------------
1115 1114 # Client Requests
1116 1115 #-------------------------------------------------------------------------
1117 1116
    def shutdown_request(self, client_id, msg):
        """handle shutdown request.

        Acknowledges the requesting client, broadcasts a shutdown notice,
        then exits the process after a one-second grace period.
        """
        self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
        # also notify other clients of shutdown
        self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
        dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
        dc.start()
1125 1124
    def _shutdown(self):
        """Terminate the hub process (after replies have been flushed)."""
        self.log.info("hub::hub shutting down.")
        # brief pause to let the final messages go out before exiting
        time.sleep(0.1)
        sys.exit(0)
1130 1129
1131 1130
    def check_load(self, client_id, msg):
        """Reply with the current queue+task load of the requested engines."""
        content = msg['content']
        try:
            targets = content['targets']
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return

        content = dict(status='ok')
        # loads = {}
        for t in targets:
            # load = pending MUX requests + pending tasks on that engine
            content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
        self.session.send(self.query, "load_reply", content=content, ident=client_id)
1148 1147
1149 1148
    def queue_status(self, client_id, msg):
        """Return the Queue status of one or more targets.
        if verbose: return the msg_ids
        else: return len of each type.
        keys: queue (pending MUX jobs)
            tasks (pending Task jobs)
            completed (finished jobs from both queues)"""
        content = msg['content']
        targets = content['targets']
        try:
            targets = self._validate_targets(targets)
        except:
            content = error.wrap_exception()
            self.session.send(self.query, "hub_error",
                    content=content, ident=client_id)
            return
        # read verbose flag before rebinding `content` to the reply dict
        verbose = content.get('verbose', False)
        content = dict(status='ok')
        for t in targets:
            queue = self.queues[t]
            completed = self.completed[t]
            tasks = self.tasks[t]
            if not verbose:
                # non-verbose replies only report counts, not msg_ids
                queue = len(queue)
                completed = len(completed)
                tasks = len(tasks)
            content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
        content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
        # print (content)
        self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1180 1179
    def purge_results(self, client_id, msg):
        """Purge results from memory. This method is more valuable before we move
        to a DB based message storage mechanism.

        Accepts ``msg_ids`` ('all' or a list; pending messages are
        rejected) and/or ``engine_ids`` whose completed records should be
        dropped.
        """
        content = msg['content']
        self.log.info("Dropping records with %s", content)
        msg_ids = content.get('msg_ids', [])
        reply = dict(status='ok')
        if msg_ids == 'all':
            try:
                self.db.drop_matching_records(dict(completed={'$ne':None}))
            except Exception:
                reply = error.wrap_exception()
        else:
            # refuse to purge messages that are still pending
            pending = filter(lambda m: m in self.pending, msg_ids)
            if pending:
                try:
                    raise IndexError("msg pending: %r" % pending[0])
                except:
                    reply = error.wrap_exception()
            else:
                try:
                    self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
                except Exception:
                    reply = error.wrap_exception()

        if reply['status'] == 'ok':
            eids = content.get('engine_ids', [])
            for eid in eids:
                if eid not in self.engines:
                    try:
                        raise IndexError("No such engine: %i" % eid)
                    except:
                        reply = error.wrap_exception()
                    break
                uid = self.engines[eid].uuid
                try:
                    self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
                except Exception:
                    reply = error.wrap_exception()
                    break

        self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1223 1222
1224 1223 def resubmit_task(self, client_id, msg):
1225 1224 """Resubmit one or more tasks."""
1226 1225 def finish(reply):
1227 1226 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1228 1227
1229 1228 content = msg['content']
1230 1229 msg_ids = content['msg_ids']
1231 1230 reply = dict(status='ok')
1232 1231 try:
1233 1232 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1234 1233 'header', 'content', 'buffers'])
1235 1234 except Exception:
1236 1235 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1237 1236 return finish(error.wrap_exception())
1238 1237
1239 1238 # validate msg_ids
1240 1239 found_ids = [ rec['msg_id'] for rec in records ]
1241 1240 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1242 1241 if len(records) > len(msg_ids):
1243 1242 try:
1244 1243 raise RuntimeError("DB appears to be in an inconsistent state."
1245 1244 "More matching records were found than should exist")
1246 1245 except Exception:
1247 1246 return finish(error.wrap_exception())
1248 1247 elif len(records) < len(msg_ids):
1249 1248 missing = [ m for m in msg_ids if m not in found_ids ]
1250 1249 try:
1251 1250 raise KeyError("No such msg(s): %r" % missing)
1252 1251 except KeyError:
1253 1252 return finish(error.wrap_exception())
1254 1253 elif pending_ids:
1255 1254 pass
1256 1255 # no need to raise on resubmit of pending task, now that we
1257 1256 # resubmit under new ID, but do we want to raise anyway?
1258 1257 # msg_id = invalid_ids[0]
1259 1258 # try:
1260 1259 # raise ValueError("Task(s) %r appears to be inflight" % )
1261 1260 # except Exception:
1262 1261 # return finish(error.wrap_exception())
1263 1262
1264 1263 # mapping of original IDs to resubmitted IDs
1265 1264 resubmitted = {}
1266 1265
1267 1266 # send the messages
1268 1267 for rec in records:
1269 1268 header = rec['header']
1270 1269 msg = self.session.msg(header['msg_type'], parent=header)
1271 1270 msg_id = msg['msg_id']
1272 1271 msg['content'] = rec['content']
1273 1272
1274 1273 # use the old header, but update msg_id and timestamp
1275 1274 fresh = msg['header']
1276 1275 header['msg_id'] = fresh['msg_id']
1277 1276 header['date'] = fresh['date']
1278 1277 msg['header'] = header
1279 1278
1280 1279 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1281 1280
1282 1281 resubmitted[rec['msg_id']] = msg_id
1283 1282 self.pending.add(msg_id)
1284 1283 msg['buffers'] = rec['buffers']
1285 1284 try:
1286 1285 self.db.add_record(msg_id, init_record(msg))
1287 1286 except Exception:
1288 1287 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1289 1288 return finish(error.wrap_exception())
1290 1289
1291 1290 finish(dict(status='ok', resubmitted=resubmitted))
1292 1291
1293 1292 # store the new IDs in the Task DB
1294 1293 for msg_id, resubmit_id in resubmitted.iteritems():
1295 1294 try:
1296 1295 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1297 1296 except Exception:
1298 1297 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1299 1298
1300 1299
    def _extract_record(self, rec):
        """decompose a TaskRecord dict into subsection of reply for get_result

        Returns ``(content, buffers)`` where content gathers the headers,
        metadata, result content, and an ``io`` sub-dict of the captured
        streams.
        """
        io_dict = {}
        for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
            io_dict[key] = rec[key]
        content = {
            'header': rec['header'],
            'metadata': rec['metadata'],
            'result_metadata': rec['result_metadata'],
            'result_header' : rec['result_header'],
            'result_content': rec['result_content'],
            'received' : rec['received'],
            'io' : io_dict,
        }
        if rec['result_buffers']:
            # coerce buffers to bytes for transmission
            buffers = map(bytes, rec['result_buffers'])
        else:
            buffers = []

        return content, buffers
1321 1320
1322 1321 def get_results(self, client_id, msg):
1323 1322 """Get the result of 1 or more messages."""
1324 1323 content = msg['content']
1325 1324 msg_ids = sorted(set(content['msg_ids']))
1326 1325 statusonly = content.get('status_only', False)
1327 1326 pending = []
1328 1327 completed = []
1329 1328 content = dict(status='ok')
1330 1329 content['pending'] = pending
1331 1330 content['completed'] = completed
1332 1331 buffers = []
1333 1332 if not statusonly:
1334 1333 try:
1335 1334 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1336 1335 # turn match list into dict, for faster lookup
1337 1336 records = {}
1338 1337 for rec in matches:
1339 1338 records[rec['msg_id']] = rec
1340 1339 except Exception:
1341 1340 content = error.wrap_exception()
1342 1341 self.session.send(self.query, "result_reply", content=content,
1343 1342 parent=msg, ident=client_id)
1344 1343 return
1345 1344 else:
1346 1345 records = {}
1347 1346 for msg_id in msg_ids:
1348 1347 if msg_id in self.pending:
1349 1348 pending.append(msg_id)
1350 1349 elif msg_id in self.all_completed:
1351 1350 completed.append(msg_id)
1352 1351 if not statusonly:
1353 1352 c,bufs = self._extract_record(records[msg_id])
1354 1353 content[msg_id] = c
1355 1354 buffers.extend(bufs)
1356 1355 elif msg_id in records:
1357 1356 if rec['completed']:
1358 1357 completed.append(msg_id)
1359 1358 c,bufs = self._extract_record(records[msg_id])
1360 1359 content[msg_id] = c
1361 1360 buffers.extend(bufs)
1362 1361 else:
1363 1362 pending.append(msg_id)
1364 1363 else:
1365 1364 try:
1366 1365 raise KeyError('No such message: '+msg_id)
1367 1366 except:
1368 1367 content = error.wrap_exception()
1369 1368 break
1370 1369 self.session.send(self.query, "result_reply", content=content,
1371 1370 parent=msg, ident=client_id,
1372 1371 buffers=buffers)
1373 1372
    def get_history(self, client_id, msg):
        """Get a list of all msg_ids in our DB records"""
        try:
            msg_ids = self.db.get_history()
        except Exception as e:
            content = error.wrap_exception()
        else:
            content = dict(status='ok', history=msg_ids)

        self.session.send(self.query, "history_reply", content=content,
                parent=msg, ident=client_id)
1385 1384
    def db_query(self, client_id, msg):
        """Perform a raw query on the task record database.

        Buffers are stripped out of the records and sent as raw message
        frames; ``buffer_lens``/``result_buffer_lens`` tell the client how
        to reassemble them (only populated if the respective key was
        explicitly requested).
        """
        content = msg['content']
        query = content.get('query', {})
        keys = content.get('keys', None)
        buffers = []
        empty = list()
        try:
            records = self.db.find_records(query, keys)
        except Exception as e:
            content = error.wrap_exception()
        else:
            # extract buffers from reply content:
            if keys is not None:
                buffer_lens = [] if 'buffers' in keys else None
                result_buffer_lens = [] if 'result_buffers' in keys else None
            else:
                buffer_lens = None
                result_buffer_lens = None

            for rec in records:
                # buffers may be None, so double check
                b = rec.pop('buffers', empty) or empty
                if buffer_lens is not None:
                    buffer_lens.append(len(b))
                buffers.extend(b)
                rb = rec.pop('result_buffers', empty) or empty
                if result_buffer_lens is not None:
                    result_buffer_lens.append(len(rb))
                buffers.extend(rb)
            content = dict(status='ok', records=records, buffer_lens=buffer_lens,
                                    result_buffer_lens=result_buffer_lens)
        # self.log.debug (content)
        self.session.send(self.query, "db_reply", content=content,
                            parent=msg, ident=client_id,
                            buffers=buffers)
1422 1421
@@ -1,382 +1,374 b''
1 1 """ A minimal application using the Qt console-style IPython frontend.
2 2
3 3 This is not a complete console app, as subprocess will not be able to receive
4 4 input, there is no real readline support, among other limitations.
5 5
6 6 Authors:
7 7
8 8 * Evan Patterson
9 9 * Min RK
10 10 * Erik Tollerud
11 11 * Fernando Perez
12 12 * Bussonnier Matthias
13 13 * Thomas Kluyver
14 14 * Paul Ivanov
15 15
16 16 """
17 17
18 18 #-----------------------------------------------------------------------------
19 19 # Imports
20 20 #-----------------------------------------------------------------------------
21 21
22 22 # stdlib imports
23 23 import os
24 24 import signal
25 25 import sys
26 26
27 27 # If run on Windows, install an exception hook which pops up a
28 28 # message box. Pythonw.exe hides the console, so without this
29 29 # the application silently fails to load.
30 30 #
31 31 # We always install this handler, because the expectation is for
32 32 # qtconsole to bring up a GUI even if called from the console.
33 33 # The old handler is called, so the exception is printed as well.
34 34 # If desired, check for pythonw with an additional condition
35 35 # (sys.executable.lower().find('pythonw.exe') >= 0).
36 36 if os.name == 'nt':
37 37 old_excepthook = sys.excepthook
38 38
39 39 def gui_excepthook(exctype, value, tb):
40 40 try:
41 41 import ctypes, traceback
42 42 MB_ICONERROR = 0x00000010L
43 43 title = u'Error starting IPython QtConsole'
44 44 msg = u''.join(traceback.format_exception(exctype, value, tb))
45 45 ctypes.windll.user32.MessageBoxW(0, msg, title, MB_ICONERROR)
46 46 finally:
47 47 # Also call the old exception hook to let it do
48 48 # its thing too.
49 49 old_excepthook(exctype, value, tb)
50 50
51 51 sys.excepthook = gui_excepthook
52 52
53 53 # System library imports
54 54 from IPython.external.qt import QtCore, QtGui
55 55
56 56 # Local imports
57 57 from IPython.config.application import catch_config_error
58 58 from IPython.core.application import BaseIPythonApplication
59 59 from IPython.qt.console.ipython_widget import IPythonWidget
60 60 from IPython.qt.console.rich_ipython_widget import RichIPythonWidget
61 61 from IPython.qt.console import styles
62 62 from IPython.qt.console.mainwindow import MainWindow
63 63 from IPython.qt.client import QtKernelClient
64 64 from IPython.qt.manager import QtKernelManager
65 65 from IPython.utils.traitlets import (
66 66 Dict, Unicode, CBool, Any
67 67 )
68 68
69 69 from IPython.consoleapp import (
70 70 IPythonConsoleApp, app_aliases, app_flags, flags, aliases
71 71 )
72 72
73 73 #-----------------------------------------------------------------------------
74 74 # Network Constants
75 75 #-----------------------------------------------------------------------------
76 76
77 77 from IPython.utils.localinterfaces import is_local_ip
78 78
79 79 #-----------------------------------------------------------------------------
80 80 # Globals
81 81 #-----------------------------------------------------------------------------
82 82
83 83 _examples = """
84 84 ipython qtconsole # start the qtconsole
85 85 ipython qtconsole --matplotlib=inline # start with matplotlib inline plotting mode
86 86 """
87 87
88 88 #-----------------------------------------------------------------------------
89 89 # Aliases and Flags
90 90 #-----------------------------------------------------------------------------
91 91
# Frontend-specific flags, layered on top of a copy of the shared set.
flags = dict(flags)
qt_flags = {
    'plain': ({'IPythonQtConsoleApp': {'plain': True}},
              "Disable rich text support."),
}
qt_flags.update(app_flags)          # flags from the Console Mixin
flags.update(qt_flags)              # merge frontend flags into the full set

# Same layering for the front/backend aliases.
aliases = dict(aliases)
qt_aliases = dict(
    style = 'IPythonWidget.syntax_style',
    stylesheet = 'IPythonQtConsoleApp.stylesheet',
    colors = 'ZMQInteractiveShell.colors',

    editor = 'IPythonWidget.editor',
    paging = 'ConsoleWidget.paging',
)
qt_aliases.update(app_aliases)      # aliases from the Console Mixin
qt_aliases.update({'gui-completion':'ConsoleWidget.gui_completion'})
aliases.update(qt_aliases)          # merge frontend aliases into the full set

# Reduce flags & aliases to name sets used for scrubbing frontend args
# from the backend; 'colors' is shared with the backend, so keep it.
qt_aliases = set(qt_aliases)
qt_aliases.remove('colors')
qt_flags = set(qt_flags)
125 125
126 126 #-----------------------------------------------------------------------------
127 127 # Classes
128 128 #-----------------------------------------------------------------------------
129 129
130 130 #-----------------------------------------------------------------------------
131 131 # IPythonQtConsole
132 132 #-----------------------------------------------------------------------------
133 133
134 134
class IPythonQtConsoleApp(BaseIPythonApplication, IPythonConsoleApp):
    """Application object for the Qt console frontend."""

    name = 'ipython-qtconsole'

    description = """
        The IPython QtConsole.
        
        This launches a Console-style application using Qt.  It is not a full
        console, in that launched terminal subprocesses will not be able to accept
        input.
        
        The QtConsole supports various extra features beyond the Terminal IPython
        shell, such as inline plotting with matplotlib, via:
        
            ipython qtconsole --matplotlib=inline
        
        as well as saving your session as HTML, and printing the output.
        
    """
    examples = _examples

    classes = [IPythonWidget] + IPythonConsoleApp.classes
    flags = Dict(flags)
    aliases = Dict(aliases)
    frontend_flags = Any(qt_flags)
    frontend_aliases = Any(qt_aliases)
    kernel_client_class = QtKernelClient
    kernel_manager_class = QtKernelManager

    stylesheet = Unicode('', config=True,
        help="path to a custom CSS stylesheet")

    hide_menubar = CBool(False, config=True,
        help="Start the console window with the menu bar hidden.")

    maximize = CBool(False, config=True,
        help="Start the console window maximized.")

    plain = CBool(False, config=True,
        help="Use a plaintext widget instead of rich text (plain can't print/save).")

    def _plain_changed(self, name, old, new):
        # keep the widget kind and factory in sync with the ``plain`` trait
        self.config.ConsoleWidget.kind = 'plain' if new else 'rich'
        self.widget_factory = IPythonWidget if new else RichIPythonWidget

    # the factory for creating a widget
    widget_factory = Any(RichIPythonWidget)
185 185
def parse_command_line(self, argv=None):
    """Parse argv, then collect the kernel-bound arguments for later launch."""
    super(IPythonQtConsoleApp, self).parse_command_line(argv)
    self.build_kernel_argv(argv)
189 189
190 190
def new_frontend_master(self):
    """ Create and return new frontend attached to new kernel, launched on localhost.
    """
    km = self.kernel_manager_class(
        connection_file=self._new_connection_file(),
        parent=self,
        autorestart=True,
    )
    # launch the kernel, forwarding any kernel-bound CLI arguments
    km.start_kernel(extra_arguments=self.kernel_argv)
    km.client_factory = self.kernel_client_class
    kc = km.client()
    kc.start_channels(shell=True, iopub=True)

    widget = self.widget_factory(config=self.config, local_kernel=True)
    self.init_colors(widget)
    widget.kernel_manager = km
    widget.kernel_client = kc
    # a master frontend owns its kernel: closing it is allowed and confirmed
    widget._existing = False
    widget._may_close = True
    widget._confirm_exit = self.confirm_exit
    return widget
215 215
def new_frontend_slave(self, current_widget):
    """Create and return a new frontend attached to an existing kernel.

    Parameters
    ----------
    current_widget : IPythonWidget
        The IPythonWidget whose kernel this frontend is to share
    """
    kc = self.kernel_client_class(
        connection_file=current_widget.kernel_client.connection_file,
        config=self.config,
    )
    kc.load_connection_file()
    kc.start_channels()

    widget = self.widget_factory(config=self.config, local_kernel=False)
    self.init_colors(widget)
    # a slave frontend must never shut down the shared kernel
    widget._existing = True
    widget._may_close = False
    widget._confirm_exit = False
    widget.kernel_client = kc
    widget.kernel_manager = current_widget.kernel_manager
    return widget
239 239
def init_qt_app(self):
    """Create the QApplication.

    Kept separate from init_qt_elements because it must run first.
    """
    self.app = QtGui.QApplication([])
243 243
def init_qt_elements(self):
    """Build the frontend widget, main window, and menu bar."""
    base_path = os.path.abspath(os.path.dirname(__file__))
    icon_path = os.path.join(base_path, 'resources', 'icon', 'IPythonConsole.svg')
    self.app.icon = QtGui.QIcon(icon_path)
    QtGui.QApplication.setWindowIcon(self.app.icon)

    # a freshly-launched kernel is always local; an existing one may not be
    local_kernel = (not self.existing) or is_local_ip(self.ip)
    self.widget = self.widget_factory(config=self.config,
                                      local_kernel=local_kernel)
    self.init_colors(self.widget)
    self.widget._existing = self.existing
    self.widget._may_close = not self.existing
    self.widget._confirm_exit = self.confirm_exit

    self.widget.kernel_manager = self.kernel_manager
    self.widget.kernel_client = self.kernel_client
    self.window = MainWindow(self.app,
                            confirm_exit=self.confirm_exit,
                            new_frontend_factory=self.new_frontend_master,
                            slave_frontend_factory=self.new_frontend_slave,
                        )
    self.window.log = self.log
    self.window.add_tab_with_frontend(self.widget)
    self.window.init_menu_bar()

    # OS X always shows a menu bar, so only hide it elsewhere
    if sys.platform != 'darwin' and self.hide_menubar:
        self.window.menuBar().setVisible(False)

    self.window.setWindowTitle('IPython')
277 277
def init_colors(self, widget):
    """Configure the coloring of the widget"""
    # Note: This will be dramatically simplified when colors
    # are removed from the backend.

    # Only read values the user explicitly configured; Config objects
    # auto-create sections on attribute access, so test containment first.
    cfg = self.config
    colors = cfg.ZMQInteractiveShell.colors if 'ZMQInteractiveShell.colors' in cfg else None
    style = cfg.IPythonWidget.syntax_style if 'IPythonWidget.syntax_style' in cfg else None
    sheet = cfg.IPythonWidget.style_sheet if 'IPythonWidget.style_sheet' in cfg else None

    # normalize colors down to the known labels: lightbg / linux / nocolor
    if colors:
        c = colors.lower()
        if c in ('lightbg', 'light'):
            colors = 'lightbg'
        elif c in ('dark', 'linux'):
            colors = 'linux'
        else:
            colors = 'nocolor'
    elif style:
        if style == 'bw':
            colors = 'nocolor'
        else:
            colors = 'linux' if styles.dark_style(style) else 'lightbg'
    else:
        colors = None

    # Configure the style
    if style:
        widget.style_sheet = styles.sheet_from_template(style, colors)
        widget.syntax_style = style
        widget._syntax_style_changed()
        widget._style_sheet_changed()
    elif colors:
        # use a default dark/light/bw style
        widget.set_default_style(colors=colors)

    if self.stylesheet:
        # an explicit stylesheet path overrides any configured sheet
        if os.path.isfile(self.stylesheet):
            with open(self.stylesheet) as f:
                sheet = f.read()
        else:
            raise IOError("Stylesheet %r not found." % self.stylesheet)
    if sheet:
        widget.style_sheet = sheet
        widget._style_sheet_changed()
336 328
337 329
def init_signal(self):
    """allow clean shutdown on sigint"""
    signal.signal(signal.SIGINT, lambda sig, frame: self.exit(-2))
    # QApplication blocks until a real Qt event fires (which can require
    # mouse movement), so wake the interpreter every 200 ms with a timer.
    # timer trick from http://stackoverflow.com/q/4938723/938949
    timer = QtCore.QTimer()
    timer.timeout.connect(lambda: None)
    timer.start(200)
    # hold onto ref, so the timer doesn't get cleaned up
    self._sigint_timer = timer
350 342
@catch_config_error
def initialize(self, argv=None):
    """Initialize Qt, the base application, the console mixin, then the UI."""
    # the QApplication must exist before any Qt widgets are created
    self.init_qt_app()
    super(IPythonQtConsoleApp, self).initialize(argv)
    IPythonConsoleApp.initialize(self, argv)
    self.init_qt_elements()
    self.init_signal()
358 350
def start(self):
    """Show the main window and run the Qt event loop."""
    # draw the window
    self.window.showMaximized() if self.maximize else self.window.show()
    self.window.raise_()

    # Start the application main loop.
    self.app.exec_()
370 362
371 363 #-----------------------------------------------------------------------------
372 364 # Main entry point
373 365 #-----------------------------------------------------------------------------
374 366
def main():
    """Entry point: build, initialize, and run the Qt console application."""
    app = IPythonQtConsoleApp()
    app.initialize()
    app.start()


if __name__ == '__main__':
    main()
General Comments 0
You need to be logged in to leave comments. Login now