##// END OF EJS Templates
Fix rpmlint: non-executable-script...
Thomas Spura -
Show More
@@ -1,328 +1,327 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 A base class for objects that are configurable.
3 A base class for objects that are configurable.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10 * Min RK
9 * Min RK
11 """
10 """
12
11
13 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
15 #
14 #
16 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19
18
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21 # Imports
20 # Imports
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23
22
24 import datetime
23 import datetime
25 from copy import deepcopy
24 from copy import deepcopy
26
25
27 from loader import Config
26 from loader import Config
28 from IPython.utils.traitlets import HasTraits, Instance
27 from IPython.utils.traitlets import HasTraits, Instance
29 from IPython.utils.text import indent, wrap_paragraphs
28 from IPython.utils.text import indent, wrap_paragraphs
30
29
31
30
32 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
33 # Helper classes for Configurables
32 # Helper classes for Configurables
34 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
35
34
36
35
37 class ConfigurableError(Exception):
36 class ConfigurableError(Exception):
38 pass
37 pass
39
38
40
39
41 class MultipleInstanceError(ConfigurableError):
40 class MultipleInstanceError(ConfigurableError):
42 pass
41 pass
43
42
44 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
45 # Configurable implementation
44 # Configurable implementation
46 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
47
46
48 class Configurable(HasTraits):
47 class Configurable(HasTraits):
49
48
50 config = Instance(Config,(),{})
49 config = Instance(Config,(),{})
51 created = None
50 created = None
52
51
53 def __init__(self, **kwargs):
52 def __init__(self, **kwargs):
54 """Create a configurable given a config config.
53 """Create a configurable given a config config.
55
54
56 Parameters
55 Parameters
57 ----------
56 ----------
58 config : Config
57 config : Config
59 If this is empty, default values are used. If config is a
58 If this is empty, default values are used. If config is a
60 :class:`Config` instance, it will be used to configure the
59 :class:`Config` instance, it will be used to configure the
61 instance.
60 instance.
62
61
63 Notes
62 Notes
64 -----
63 -----
65 Subclasses of Configurable must call the :meth:`__init__` method of
64 Subclasses of Configurable must call the :meth:`__init__` method of
66 :class:`Configurable` *before* doing anything else and using
65 :class:`Configurable` *before* doing anything else and using
67 :func:`super`::
66 :func:`super`::
68
67
69 class MyConfigurable(Configurable):
68 class MyConfigurable(Configurable):
70 def __init__(self, config=None):
69 def __init__(self, config=None):
71 super(MyConfigurable, self).__init__(config)
70 super(MyConfigurable, self).__init__(config)
72 # Then any other code you need to finish initialization.
71 # Then any other code you need to finish initialization.
73
72
74 This ensures that instances will be configured properly.
73 This ensures that instances will be configured properly.
75 """
74 """
76 config = kwargs.pop('config', None)
75 config = kwargs.pop('config', None)
77 if config is not None:
76 if config is not None:
78 # We used to deepcopy, but for now we are trying to just save
77 # We used to deepcopy, but for now we are trying to just save
79 # by reference. This *could* have side effects as all components
78 # by reference. This *could* have side effects as all components
80 # will share config. In fact, I did find such a side effect in
79 # will share config. In fact, I did find such a side effect in
81 # _config_changed below. If a config attribute value was a mutable type
80 # _config_changed below. If a config attribute value was a mutable type
82 # all instances of a component were getting the same copy, effectively
81 # all instances of a component were getting the same copy, effectively
83 # making that a class attribute.
82 # making that a class attribute.
84 # self.config = deepcopy(config)
83 # self.config = deepcopy(config)
85 self.config = config
84 self.config = config
86 # This should go second so individual keyword arguments override
85 # This should go second so individual keyword arguments override
87 # the values in config.
86 # the values in config.
88 super(Configurable, self).__init__(**kwargs)
87 super(Configurable, self).__init__(**kwargs)
89 self.created = datetime.datetime.now()
88 self.created = datetime.datetime.now()
90
89
91 #-------------------------------------------------------------------------
90 #-------------------------------------------------------------------------
92 # Static trait notifiations
91 # Static trait notifiations
93 #-------------------------------------------------------------------------
92 #-------------------------------------------------------------------------
94
93
95 def _config_changed(self, name, old, new):
94 def _config_changed(self, name, old, new):
96 """Update all the class traits having ``config=True`` as metadata.
95 """Update all the class traits having ``config=True`` as metadata.
97
96
98 For any class trait with a ``config`` metadata attribute that is
97 For any class trait with a ``config`` metadata attribute that is
99 ``True``, we update the trait with the value of the corresponding
98 ``True``, we update the trait with the value of the corresponding
100 config entry.
99 config entry.
101 """
100 """
102 # Get all traits with a config metadata entry that is True
101 # Get all traits with a config metadata entry that is True
103 traits = self.traits(config=True)
102 traits = self.traits(config=True)
104
103
105 # We auto-load config section for this class as well as any parent
104 # We auto-load config section for this class as well as any parent
106 # classes that are Configurable subclasses. This starts with Configurable
105 # classes that are Configurable subclasses. This starts with Configurable
107 # and works down the mro loading the config for each section.
106 # and works down the mro loading the config for each section.
108 section_names = [cls.__name__ for cls in \
107 section_names = [cls.__name__ for cls in \
109 reversed(self.__class__.__mro__) if
108 reversed(self.__class__.__mro__) if
110 issubclass(cls, Configurable) and issubclass(self.__class__, cls)]
109 issubclass(cls, Configurable) and issubclass(self.__class__, cls)]
111
110
112 for sname in section_names:
111 for sname in section_names:
113 # Don't do a blind getattr as that would cause the config to
112 # Don't do a blind getattr as that would cause the config to
114 # dynamically create the section with name self.__class__.__name__.
113 # dynamically create the section with name self.__class__.__name__.
115 if new._has_section(sname):
114 if new._has_section(sname):
116 my_config = new[sname]
115 my_config = new[sname]
117 for k, v in traits.iteritems():
116 for k, v in traits.iteritems():
118 # Don't allow traitlets with config=True to start with
117 # Don't allow traitlets with config=True to start with
119 # uppercase. Otherwise, they are confused with Config
118 # uppercase. Otherwise, they are confused with Config
120 # subsections. But, developers shouldn't have uppercase
119 # subsections. But, developers shouldn't have uppercase
121 # attributes anyways! (PEP 6)
120 # attributes anyways! (PEP 6)
122 if k[0].upper()==k[0] and not k.startswith('_'):
121 if k[0].upper()==k[0] and not k.startswith('_'):
123 raise ConfigurableError('Configurable traitlets with '
122 raise ConfigurableError('Configurable traitlets with '
124 'config=True must start with a lowercase so they are '
123 'config=True must start with a lowercase so they are '
125 'not confused with Config subsections: %s.%s' % \
124 'not confused with Config subsections: %s.%s' % \
126 (self.__class__.__name__, k))
125 (self.__class__.__name__, k))
127 try:
126 try:
128 # Here we grab the value from the config
127 # Here we grab the value from the config
129 # If k has the naming convention of a config
128 # If k has the naming convention of a config
130 # section, it will be auto created.
129 # section, it will be auto created.
131 config_value = my_config[k]
130 config_value = my_config[k]
132 except KeyError:
131 except KeyError:
133 pass
132 pass
134 else:
133 else:
135 # print "Setting %s.%s from %s.%s=%r" % \
134 # print "Setting %s.%s from %s.%s=%r" % \
136 # (self.__class__.__name__,k,sname,k,config_value)
135 # (self.__class__.__name__,k,sname,k,config_value)
137 # We have to do a deepcopy here if we don't deepcopy the entire
136 # We have to do a deepcopy here if we don't deepcopy the entire
138 # config object. If we don't, a mutable config_value will be
137 # config object. If we don't, a mutable config_value will be
139 # shared by all instances, effectively making it a class attribute.
138 # shared by all instances, effectively making it a class attribute.
140 setattr(self, k, deepcopy(config_value))
139 setattr(self, k, deepcopy(config_value))
141
140
142 @classmethod
141 @classmethod
143 def class_get_help(cls):
142 def class_get_help(cls):
144 """Get the help string for this class in ReST format."""
143 """Get the help string for this class in ReST format."""
145 cls_traits = cls.class_traits(config=True)
144 cls_traits = cls.class_traits(config=True)
146 final_help = []
145 final_help = []
147 final_help.append(u'%s options' % cls.__name__)
146 final_help.append(u'%s options' % cls.__name__)
148 final_help.append(len(final_help[0])*u'-')
147 final_help.append(len(final_help[0])*u'-')
149 for k,v in cls.class_traits(config=True).iteritems():
148 for k,v in cls.class_traits(config=True).iteritems():
150 help = cls.class_get_trait_help(v)
149 help = cls.class_get_trait_help(v)
151 final_help.append(help)
150 final_help.append(help)
152 return '\n'.join(final_help)
151 return '\n'.join(final_help)
153
152
154 @classmethod
153 @classmethod
155 def class_get_trait_help(cls, trait):
154 def class_get_trait_help(cls, trait):
156 """Get the help string for a single trait."""
155 """Get the help string for a single trait."""
157 lines = []
156 lines = []
158 header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__)
157 header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__)
159 lines.append(header)
158 lines.append(header)
160 try:
159 try:
161 dvr = repr(trait.get_default_value())
160 dvr = repr(trait.get_default_value())
162 except Exception:
161 except Exception:
163 dvr = None # ignore defaults we can't construct
162 dvr = None # ignore defaults we can't construct
164 if dvr is not None:
163 if dvr is not None:
165 if len(dvr) > 64:
164 if len(dvr) > 64:
166 dvr = dvr[:61]+'...'
165 dvr = dvr[:61]+'...'
167 lines.append(indent('Default: %s'%dvr, 4))
166 lines.append(indent('Default: %s'%dvr, 4))
168 if 'Enum' in trait.__class__.__name__:
167 if 'Enum' in trait.__class__.__name__:
169 # include Enum choices
168 # include Enum choices
170 lines.append(indent('Choices: %r'%(trait.values,)))
169 lines.append(indent('Choices: %r'%(trait.values,)))
171
170
172 help = trait.get_metadata('help')
171 help = trait.get_metadata('help')
173 if help is not None:
172 if help is not None:
174 help = '\n'.join(wrap_paragraphs(help, 76))
173 help = '\n'.join(wrap_paragraphs(help, 76))
175 lines.append(indent(help, 4))
174 lines.append(indent(help, 4))
176 return '\n'.join(lines)
175 return '\n'.join(lines)
177
176
178 @classmethod
177 @classmethod
179 def class_print_help(cls):
178 def class_print_help(cls):
180 """Get the help string for a single trait and print it."""
179 """Get the help string for a single trait and print it."""
181 print cls.class_get_help()
180 print cls.class_get_help()
182
181
183 @classmethod
182 @classmethod
184 def class_config_section(cls):
183 def class_config_section(cls):
185 """Get the config class config section"""
184 """Get the config class config section"""
186 def c(s):
185 def c(s):
187 """return a commented, wrapped block."""
186 """return a commented, wrapped block."""
188 s = '\n\n'.join(wrap_paragraphs(s, 78))
187 s = '\n\n'.join(wrap_paragraphs(s, 78))
189
188
190 return '# ' + s.replace('\n', '\n# ')
189 return '# ' + s.replace('\n', '\n# ')
191
190
192 # section header
191 # section header
193 breaker = '#' + '-'*78
192 breaker = '#' + '-'*78
194 s = "# %s configuration"%cls.__name__
193 s = "# %s configuration"%cls.__name__
195 lines = [breaker, s, breaker, '']
194 lines = [breaker, s, breaker, '']
196 # get the description trait
195 # get the description trait
197 desc = cls.class_traits().get('description')
196 desc = cls.class_traits().get('description')
198 if desc:
197 if desc:
199 desc = desc.default_value
198 desc = desc.default_value
200 else:
199 else:
201 # no description trait, use __doc__
200 # no description trait, use __doc__
202 desc = getattr(cls, '__doc__', '')
201 desc = getattr(cls, '__doc__', '')
203 if desc:
202 if desc:
204 lines.append(c(desc))
203 lines.append(c(desc))
205 lines.append('')
204 lines.append('')
206
205
207 parents = []
206 parents = []
208 for parent in cls.mro():
207 for parent in cls.mro():
209 # only include parents that are not base classes
208 # only include parents that are not base classes
210 # and are not the class itself
209 # and are not the class itself
211 if issubclass(parent, Configurable) and \
210 if issubclass(parent, Configurable) and \
212 not parent in (Configurable, SingletonConfigurable, cls):
211 not parent in (Configurable, SingletonConfigurable, cls):
213 parents.append(parent)
212 parents.append(parent)
214
213
215 if parents:
214 if parents:
216 pstr = ', '.join([ p.__name__ for p in parents ])
215 pstr = ', '.join([ p.__name__ for p in parents ])
217 lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr)))
216 lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr)))
218 lines.append('')
217 lines.append('')
219
218
220 for name,trait in cls.class_traits(config=True).iteritems():
219 for name,trait in cls.class_traits(config=True).iteritems():
221 help = trait.get_metadata('help') or ''
220 help = trait.get_metadata('help') or ''
222 lines.append(c(help))
221 lines.append(c(help))
223 lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value()))
222 lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value()))
224 lines.append('')
223 lines.append('')
225 return '\n'.join(lines)
224 return '\n'.join(lines)
226
225
227
226
228
227
229 class SingletonConfigurable(Configurable):
228 class SingletonConfigurable(Configurable):
230 """A configurable that only allows one instance.
229 """A configurable that only allows one instance.
231
230
232 This class is for classes that should only have one instance of itself
231 This class is for classes that should only have one instance of itself
233 or *any* subclass. To create and retrieve such a class use the
232 or *any* subclass. To create and retrieve such a class use the
234 :meth:`SingletonConfigurable.instance` method.
233 :meth:`SingletonConfigurable.instance` method.
235 """
234 """
236
235
237 _instance = None
236 _instance = None
238
237
239 @classmethod
238 @classmethod
240 def _walk_mro(cls):
239 def _walk_mro(cls):
241 """Walk the cls.mro() for parent classes that are also singletons
240 """Walk the cls.mro() for parent classes that are also singletons
242
241
243 For use in instance()
242 For use in instance()
244 """
243 """
245
244
246 for subclass in cls.mro():
245 for subclass in cls.mro():
247 if issubclass(cls, subclass) and \
246 if issubclass(cls, subclass) and \
248 issubclass(subclass, SingletonConfigurable) and \
247 issubclass(subclass, SingletonConfigurable) and \
249 subclass != SingletonConfigurable:
248 subclass != SingletonConfigurable:
250 yield subclass
249 yield subclass
251
250
252 @classmethod
251 @classmethod
253 def clear_instance(cls):
252 def clear_instance(cls):
254 """unset _instance for this class and singleton parents.
253 """unset _instance for this class and singleton parents.
255 """
254 """
256 if not cls.initialized():
255 if not cls.initialized():
257 return
256 return
258 for subclass in cls._walk_mro():
257 for subclass in cls._walk_mro():
259 if isinstance(subclass._instance, cls):
258 if isinstance(subclass._instance, cls):
260 # only clear instances that are instances
259 # only clear instances that are instances
261 # of the calling class
260 # of the calling class
262 subclass._instance = None
261 subclass._instance = None
263
262
264 @classmethod
263 @classmethod
265 def instance(cls, *args, **kwargs):
264 def instance(cls, *args, **kwargs):
266 """Returns a global instance of this class.
265 """Returns a global instance of this class.
267
266
268 This method create a new instance if none have previously been created
267 This method create a new instance if none have previously been created
269 and returns a previously created instance is one already exists.
268 and returns a previously created instance is one already exists.
270
269
271 The arguments and keyword arguments passed to this method are passed
270 The arguments and keyword arguments passed to this method are passed
272 on to the :meth:`__init__` method of the class upon instantiation.
271 on to the :meth:`__init__` method of the class upon instantiation.
273
272
274 Examples
273 Examples
275 --------
274 --------
276
275
277 Create a singleton class using instance, and retrieve it::
276 Create a singleton class using instance, and retrieve it::
278
277
279 >>> from IPython.config.configurable import SingletonConfigurable
278 >>> from IPython.config.configurable import SingletonConfigurable
280 >>> class Foo(SingletonConfigurable): pass
279 >>> class Foo(SingletonConfigurable): pass
281 >>> foo = Foo.instance()
280 >>> foo = Foo.instance()
282 >>> foo == Foo.instance()
281 >>> foo == Foo.instance()
283 True
282 True
284
283
285 Create a subclass that is retrived using the base class instance::
284 Create a subclass that is retrived using the base class instance::
286
285
287 >>> class Bar(SingletonConfigurable): pass
286 >>> class Bar(SingletonConfigurable): pass
288 >>> class Bam(Bar): pass
287 >>> class Bam(Bar): pass
289 >>> bam = Bam.instance()
288 >>> bam = Bam.instance()
290 >>> bam == Bar.instance()
289 >>> bam == Bar.instance()
291 True
290 True
292 """
291 """
293 # Create and save the instance
292 # Create and save the instance
294 if cls._instance is None:
293 if cls._instance is None:
295 inst = cls(*args, **kwargs)
294 inst = cls(*args, **kwargs)
296 # Now make sure that the instance will also be returned by
295 # Now make sure that the instance will also be returned by
297 # parent classes' _instance attribute.
296 # parent classes' _instance attribute.
298 for subclass in cls._walk_mro():
297 for subclass in cls._walk_mro():
299 subclass._instance = inst
298 subclass._instance = inst
300
299
301 if isinstance(cls._instance, cls):
300 if isinstance(cls._instance, cls):
302 return cls._instance
301 return cls._instance
303 else:
302 else:
304 raise MultipleInstanceError(
303 raise MultipleInstanceError(
305 'Multiple incompatible subclass instances of '
304 'Multiple incompatible subclass instances of '
306 '%s are being created.' % cls.__name__
305 '%s are being created.' % cls.__name__
307 )
306 )
308
307
309 @classmethod
308 @classmethod
310 def initialized(cls):
309 def initialized(cls):
311 """Has an instance been created?"""
310 """Has an instance been created?"""
312 return hasattr(cls, "_instance") and cls._instance is not None
311 return hasattr(cls, "_instance") and cls._instance is not None
313
312
314
313
315 class LoggingConfigurable(Configurable):
314 class LoggingConfigurable(Configurable):
316 """A parent class for Configurables that log.
315 """A parent class for Configurables that log.
317
316
318 Subclasses have a log trait, and the default behavior
317 Subclasses have a log trait, and the default behavior
319 is to get the logger from the currently running Application
318 is to get the logger from the currently running Application
320 via Application.instance().log.
319 via Application.instance().log.
321 """
320 """
322
321
323 log = Instance('logging.Logger')
322 log = Instance('logging.Logger')
324 def _log_default(self):
323 def _log_default(self):
325 from IPython.config.application import Application
324 from IPython.config.application import Application
326 return Application.instance().log
325 return Application.instance().log
327
326
328
327
@@ -1,166 +1,165 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Tests for IPython.config.configurable
3 Tests for IPython.config.configurable
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez (design help)
8 * Fernando Perez (design help)
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2010 The IPython Development Team
12 # Copyright (C) 2008-2010 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 from unittest import TestCase
22 from unittest import TestCase
24
23
25 from IPython.config.configurable import (
24 from IPython.config.configurable import (
26 Configurable,
25 Configurable,
27 SingletonConfigurable
26 SingletonConfigurable
28 )
27 )
29
28
30 from IPython.utils.traitlets import (
29 from IPython.utils.traitlets import (
31 Int, Float, Unicode
30 Int, Float, Unicode
32 )
31 )
33
32
34 from IPython.config.loader import Config
33 from IPython.config.loader import Config
35
34
36
35
37 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
38 # Test cases
37 # Test cases
39 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
40
39
41
40
42 class MyConfigurable(Configurable):
41 class MyConfigurable(Configurable):
43 a = Int(1, config=True, help="The integer a.")
42 a = Int(1, config=True, help="The integer a.")
44 b = Float(1.0, config=True, help="The integer b.")
43 b = Float(1.0, config=True, help="The integer b.")
45 c = Unicode('no config')
44 c = Unicode('no config')
46
45
47
46
48 mc_help=u"""MyConfigurable options
47 mc_help=u"""MyConfigurable options
49 ----------------------
48 ----------------------
50 --MyConfigurable.a=<Int>
49 --MyConfigurable.a=<Int>
51 Default: 1
50 Default: 1
52 The integer a.
51 The integer a.
53 --MyConfigurable.b=<Float>
52 --MyConfigurable.b=<Float>
54 Default: 1.0
53 Default: 1.0
55 The integer b."""
54 The integer b."""
56
55
57 class Foo(Configurable):
56 class Foo(Configurable):
58 a = Int(0, config=True, help="The integer a.")
57 a = Int(0, config=True, help="The integer a.")
59 b = Unicode('nope', config=True)
58 b = Unicode('nope', config=True)
60
59
61
60
62 class Bar(Foo):
61 class Bar(Foo):
63 b = Unicode('gotit', config=False, help="The string b.")
62 b = Unicode('gotit', config=False, help="The string b.")
64 c = Float(config=True, help="The string c.")
63 c = Float(config=True, help="The string c.")
65
64
66
65
67 class TestConfigurable(TestCase):
66 class TestConfigurable(TestCase):
68
67
69 def test_default(self):
68 def test_default(self):
70 c1 = Configurable()
69 c1 = Configurable()
71 c2 = Configurable(config=c1.config)
70 c2 = Configurable(config=c1.config)
72 c3 = Configurable(config=c2.config)
71 c3 = Configurable(config=c2.config)
73 self.assertEquals(c1.config, c2.config)
72 self.assertEquals(c1.config, c2.config)
74 self.assertEquals(c2.config, c3.config)
73 self.assertEquals(c2.config, c3.config)
75
74
76 def test_custom(self):
75 def test_custom(self):
77 config = Config()
76 config = Config()
78 config.foo = 'foo'
77 config.foo = 'foo'
79 config.bar = 'bar'
78 config.bar = 'bar'
80 c1 = Configurable(config=config)
79 c1 = Configurable(config=config)
81 c2 = Configurable(config=c1.config)
80 c2 = Configurable(config=c1.config)
82 c3 = Configurable(config=c2.config)
81 c3 = Configurable(config=c2.config)
83 self.assertEquals(c1.config, config)
82 self.assertEquals(c1.config, config)
84 self.assertEquals(c2.config, config)
83 self.assertEquals(c2.config, config)
85 self.assertEquals(c3.config, config)
84 self.assertEquals(c3.config, config)
86 # Test that copies are not made
85 # Test that copies are not made
87 self.assert_(c1.config is config)
86 self.assert_(c1.config is config)
88 self.assert_(c2.config is config)
87 self.assert_(c2.config is config)
89 self.assert_(c3.config is config)
88 self.assert_(c3.config is config)
90 self.assert_(c1.config is c2.config)
89 self.assert_(c1.config is c2.config)
91 self.assert_(c2.config is c3.config)
90 self.assert_(c2.config is c3.config)
92
91
93 def test_inheritance(self):
92 def test_inheritance(self):
94 config = Config()
93 config = Config()
95 config.MyConfigurable.a = 2
94 config.MyConfigurable.a = 2
96 config.MyConfigurable.b = 2.0
95 config.MyConfigurable.b = 2.0
97 c1 = MyConfigurable(config=config)
96 c1 = MyConfigurable(config=config)
98 c2 = MyConfigurable(config=c1.config)
97 c2 = MyConfigurable(config=c1.config)
99 self.assertEquals(c1.a, config.MyConfigurable.a)
98 self.assertEquals(c1.a, config.MyConfigurable.a)
100 self.assertEquals(c1.b, config.MyConfigurable.b)
99 self.assertEquals(c1.b, config.MyConfigurable.b)
101 self.assertEquals(c2.a, config.MyConfigurable.a)
100 self.assertEquals(c2.a, config.MyConfigurable.a)
102 self.assertEquals(c2.b, config.MyConfigurable.b)
101 self.assertEquals(c2.b, config.MyConfigurable.b)
103
102
104 def test_parent(self):
103 def test_parent(self):
105 config = Config()
104 config = Config()
106 config.Foo.a = 10
105 config.Foo.a = 10
107 config.Foo.b = "wow"
106 config.Foo.b = "wow"
108 config.Bar.b = 'later'
107 config.Bar.b = 'later'
109 config.Bar.c = 100.0
108 config.Bar.c = 100.0
110 f = Foo(config=config)
109 f = Foo(config=config)
111 b = Bar(config=f.config)
110 b = Bar(config=f.config)
112 self.assertEquals(f.a, 10)
111 self.assertEquals(f.a, 10)
113 self.assertEquals(f.b, 'wow')
112 self.assertEquals(f.b, 'wow')
114 self.assertEquals(b.b, 'gotit')
113 self.assertEquals(b.b, 'gotit')
115 self.assertEquals(b.c, 100.0)
114 self.assertEquals(b.c, 100.0)
116
115
117 def test_override1(self):
116 def test_override1(self):
118 config = Config()
117 config = Config()
119 config.MyConfigurable.a = 2
118 config.MyConfigurable.a = 2
120 config.MyConfigurable.b = 2.0
119 config.MyConfigurable.b = 2.0
121 c = MyConfigurable(a=3, config=config)
120 c = MyConfigurable(a=3, config=config)
122 self.assertEquals(c.a, 3)
121 self.assertEquals(c.a, 3)
123 self.assertEquals(c.b, config.MyConfigurable.b)
122 self.assertEquals(c.b, config.MyConfigurable.b)
124 self.assertEquals(c.c, 'no config')
123 self.assertEquals(c.c, 'no config')
125
124
126 def test_override2(self):
125 def test_override2(self):
127 config = Config()
126 config = Config()
128 config.Foo.a = 1
127 config.Foo.a = 1
129 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
128 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
130 config.Bar.c = 10.0
129 config.Bar.c = 10.0
131 c = Bar(config=config)
130 c = Bar(config=config)
132 self.assertEquals(c.a, config.Foo.a)
131 self.assertEquals(c.a, config.Foo.a)
133 self.assertEquals(c.b, 'gotit')
132 self.assertEquals(c.b, 'gotit')
134 self.assertEquals(c.c, config.Bar.c)
133 self.assertEquals(c.c, config.Bar.c)
135 c = Bar(a=2, b='and', c=20.0, config=config)
134 c = Bar(a=2, b='and', c=20.0, config=config)
136 self.assertEquals(c.a, 2)
135 self.assertEquals(c.a, 2)
137 self.assertEquals(c.b, 'and')
136 self.assertEquals(c.b, 'and')
138 self.assertEquals(c.c, 20.0)
137 self.assertEquals(c.c, 20.0)
139
138
140 def test_help(self):
139 def test_help(self):
141 self.assertEquals(MyConfigurable.class_get_help(), mc_help)
140 self.assertEquals(MyConfigurable.class_get_help(), mc_help)
142
141
143
142
144 class TestSingletonConfigurable(TestCase):
143 class TestSingletonConfigurable(TestCase):
145
144
146 def test_instance(self):
145 def test_instance(self):
147 from IPython.config.configurable import SingletonConfigurable
146 from IPython.config.configurable import SingletonConfigurable
148 class Foo(SingletonConfigurable): pass
147 class Foo(SingletonConfigurable): pass
149 self.assertEquals(Foo.initialized(), False)
148 self.assertEquals(Foo.initialized(), False)
150 foo = Foo.instance()
149 foo = Foo.instance()
151 self.assertEquals(Foo.initialized(), True)
150 self.assertEquals(Foo.initialized(), True)
152 self.assertEquals(foo, Foo.instance())
151 self.assertEquals(foo, Foo.instance())
153 self.assertEquals(SingletonConfigurable._instance, None)
152 self.assertEquals(SingletonConfigurable._instance, None)
154
153
155 def test_inheritance(self):
154 def test_inheritance(self):
156 class Bar(SingletonConfigurable): pass
155 class Bar(SingletonConfigurable): pass
157 class Bam(Bar): pass
156 class Bam(Bar): pass
158 self.assertEquals(Bar.initialized(), False)
157 self.assertEquals(Bar.initialized(), False)
159 self.assertEquals(Bam.initialized(), False)
158 self.assertEquals(Bam.initialized(), False)
160 bam = Bam.instance()
159 bam = Bam.instance()
161 bam == Bar.instance()
160 bam == Bar.instance()
162 self.assertEquals(Bar.initialized(), True)
161 self.assertEquals(Bar.initialized(), True)
163 self.assertEquals(Bam.initialized(), True)
162 self.assertEquals(Bam.initialized(), True)
164 self.assertEquals(bam, Bam._instance)
163 self.assertEquals(bam, Bam._instance)
165 self.assertEquals(bam, Bar._instance)
164 self.assertEquals(bam, Bar._instance)
166 self.assertEquals(SingletonConfigurable._instance, None)
165 self.assertEquals(SingletonConfigurable._instance, None)
@@ -1,226 +1,225 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Tests for IPython.config.loader
3 Tests for IPython.config.loader
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez (design help)
8 * Fernando Perez (design help)
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2009 The IPython Development Team
12 # Copyright (C) 2008-2009 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 import os
22 import os
24 import sys
23 import sys
25 from tempfile import mkstemp
24 from tempfile import mkstemp
26 from unittest import TestCase
25 from unittest import TestCase
27
26
28 from nose import SkipTest
27 from nose import SkipTest
29
28
30 from IPython.testing.tools import mute_warn
29 from IPython.testing.tools import mute_warn
31
30
32 from IPython.utils.traitlets import Int, Unicode
31 from IPython.utils.traitlets import Int, Unicode
33 from IPython.config.configurable import Configurable
32 from IPython.config.configurable import Configurable
34 from IPython.config.loader import (
33 from IPython.config.loader import (
35 Config,
34 Config,
36 PyFileConfigLoader,
35 PyFileConfigLoader,
37 KeyValueConfigLoader,
36 KeyValueConfigLoader,
38 ArgParseConfigLoader,
37 ArgParseConfigLoader,
39 ConfigError
38 ConfigError
40 )
39 )
41
40
42 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
43 # Actual tests
42 # Actual tests
44 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
45
44
46
45
47 pyfile = """
46 pyfile = """
48 c = get_config()
47 c = get_config()
49 c.a=10
48 c.a=10
50 c.b=20
49 c.b=20
51 c.Foo.Bar.value=10
50 c.Foo.Bar.value=10
52 c.Foo.Bam.value=range(10)
51 c.Foo.Bam.value=range(10)
53 c.D.C.value='hi there'
52 c.D.C.value='hi there'
54 """
53 """
55
54
56 class TestPyFileCL(TestCase):
55 class TestPyFileCL(TestCase):
57
56
58 def test_basic(self):
57 def test_basic(self):
59 fd, fname = mkstemp('.py')
58 fd, fname = mkstemp('.py')
60 f = os.fdopen(fd, 'w')
59 f = os.fdopen(fd, 'w')
61 f.write(pyfile)
60 f.write(pyfile)
62 f.close()
61 f.close()
63 # Unlink the file
62 # Unlink the file
64 cl = PyFileConfigLoader(fname)
63 cl = PyFileConfigLoader(fname)
65 config = cl.load_config()
64 config = cl.load_config()
66 self.assertEquals(config.a, 10)
65 self.assertEquals(config.a, 10)
67 self.assertEquals(config.b, 20)
66 self.assertEquals(config.b, 20)
68 self.assertEquals(config.Foo.Bar.value, 10)
67 self.assertEquals(config.Foo.Bar.value, 10)
69 self.assertEquals(config.Foo.Bam.value, range(10))
68 self.assertEquals(config.Foo.Bam.value, range(10))
70 self.assertEquals(config.D.C.value, 'hi there')
69 self.assertEquals(config.D.C.value, 'hi there')
71
70
72 class MyLoader1(ArgParseConfigLoader):
71 class MyLoader1(ArgParseConfigLoader):
73 def _add_arguments(self):
72 def _add_arguments(self):
74 p = self.parser
73 p = self.parser
75 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
74 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
76 p.add_argument('-b', dest='MyClass.bar', type=int)
75 p.add_argument('-b', dest='MyClass.bar', type=int)
77 p.add_argument('-n', dest='n', action='store_true')
76 p.add_argument('-n', dest='n', action='store_true')
78 p.add_argument('Global.bam', type=str)
77 p.add_argument('Global.bam', type=str)
79
78
80 class MyLoader2(ArgParseConfigLoader):
79 class MyLoader2(ArgParseConfigLoader):
81 def _add_arguments(self):
80 def _add_arguments(self):
82 subparsers = self.parser.add_subparsers(dest='subparser_name')
81 subparsers = self.parser.add_subparsers(dest='subparser_name')
83 subparser1 = subparsers.add_parser('1')
82 subparser1 = subparsers.add_parser('1')
84 subparser1.add_argument('-x',dest='Global.x')
83 subparser1.add_argument('-x',dest='Global.x')
85 subparser2 = subparsers.add_parser('2')
84 subparser2 = subparsers.add_parser('2')
86 subparser2.add_argument('y')
85 subparser2.add_argument('y')
87
86
88 class TestArgParseCL(TestCase):
87 class TestArgParseCL(TestCase):
89
88
90 def test_basic(self):
89 def test_basic(self):
91 cl = MyLoader1()
90 cl = MyLoader1()
92 config = cl.load_config('-f hi -b 10 -n wow'.split())
91 config = cl.load_config('-f hi -b 10 -n wow'.split())
93 self.assertEquals(config.Global.foo, 'hi')
92 self.assertEquals(config.Global.foo, 'hi')
94 self.assertEquals(config.MyClass.bar, 10)
93 self.assertEquals(config.MyClass.bar, 10)
95 self.assertEquals(config.n, True)
94 self.assertEquals(config.n, True)
96 self.assertEquals(config.Global.bam, 'wow')
95 self.assertEquals(config.Global.bam, 'wow')
97 config = cl.load_config(['wow'])
96 config = cl.load_config(['wow'])
98 self.assertEquals(config.keys(), ['Global'])
97 self.assertEquals(config.keys(), ['Global'])
99 self.assertEquals(config.Global.keys(), ['bam'])
98 self.assertEquals(config.Global.keys(), ['bam'])
100 self.assertEquals(config.Global.bam, 'wow')
99 self.assertEquals(config.Global.bam, 'wow')
101
100
102 def test_add_arguments(self):
101 def test_add_arguments(self):
103 cl = MyLoader2()
102 cl = MyLoader2()
104 config = cl.load_config('2 frobble'.split())
103 config = cl.load_config('2 frobble'.split())
105 self.assertEquals(config.subparser_name, '2')
104 self.assertEquals(config.subparser_name, '2')
106 self.assertEquals(config.y, 'frobble')
105 self.assertEquals(config.y, 'frobble')
107 config = cl.load_config('1 -x frobble'.split())
106 config = cl.load_config('1 -x frobble'.split())
108 self.assertEquals(config.subparser_name, '1')
107 self.assertEquals(config.subparser_name, '1')
109 self.assertEquals(config.Global.x, 'frobble')
108 self.assertEquals(config.Global.x, 'frobble')
110
109
111 def test_argv(self):
110 def test_argv(self):
112 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
111 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
113 config = cl.load_config()
112 config = cl.load_config()
114 self.assertEquals(config.Global.foo, 'hi')
113 self.assertEquals(config.Global.foo, 'hi')
115 self.assertEquals(config.MyClass.bar, 10)
114 self.assertEquals(config.MyClass.bar, 10)
116 self.assertEquals(config.n, True)
115 self.assertEquals(config.n, True)
117 self.assertEquals(config.Global.bam, 'wow')
116 self.assertEquals(config.Global.bam, 'wow')
118
117
119
118
120 class TestKeyValueCL(TestCase):
119 class TestKeyValueCL(TestCase):
121
120
122 def test_basic(self):
121 def test_basic(self):
123 cl = KeyValueConfigLoader()
122 cl = KeyValueConfigLoader()
124 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
123 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
125 with mute_warn():
124 with mute_warn():
126 config = cl.load_config(argv)
125 config = cl.load_config(argv)
127 self.assertEquals(config.a, 10)
126 self.assertEquals(config.a, 10)
128 self.assertEquals(config.b, 20)
127 self.assertEquals(config.b, 20)
129 self.assertEquals(config.Foo.Bar.value, 10)
128 self.assertEquals(config.Foo.Bar.value, 10)
130 self.assertEquals(config.Foo.Bam.value, range(10))
129 self.assertEquals(config.Foo.Bam.value, range(10))
131 self.assertEquals(config.D.C.value, 'hi there')
130 self.assertEquals(config.D.C.value, 'hi there')
132
131
133 def test_extra_args(self):
132 def test_extra_args(self):
134 cl = KeyValueConfigLoader()
133 cl = KeyValueConfigLoader()
135 with mute_warn():
134 with mute_warn():
136 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
135 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
137 self.assertEquals(cl.extra_args, ['b', 'd'])
136 self.assertEquals(cl.extra_args, ['b', 'd'])
138 self.assertEquals(config.a, 5)
137 self.assertEquals(config.a, 5)
139 self.assertEquals(config.c, 10)
138 self.assertEquals(config.c, 10)
140 with mute_warn():
139 with mute_warn():
141 config = cl.load_config(['--', '--a=5', '--c=10'])
140 config = cl.load_config(['--', '--a=5', '--c=10'])
142 self.assertEquals(cl.extra_args, ['--a=5', '--c=10'])
141 self.assertEquals(cl.extra_args, ['--a=5', '--c=10'])
143
142
144 def test_unicode_args(self):
143 def test_unicode_args(self):
145 cl = KeyValueConfigLoader()
144 cl = KeyValueConfigLoader()
146 argv = [u'--a=épsîlön']
145 argv = [u'--a=épsîlön']
147 with mute_warn():
146 with mute_warn():
148 config = cl.load_config(argv)
147 config = cl.load_config(argv)
149 self.assertEquals(config.a, u'épsîlön')
148 self.assertEquals(config.a, u'épsîlön')
150
149
151 def test_unicode_bytes_args(self):
150 def test_unicode_bytes_args(self):
152 uarg = u'--a=é'
151 uarg = u'--a=é'
153 try:
152 try:
154 barg = uarg.encode(sys.stdin.encoding)
153 barg = uarg.encode(sys.stdin.encoding)
155 except (TypeError, UnicodeEncodeError):
154 except (TypeError, UnicodeEncodeError):
156 raise SkipTest("sys.stdin.encoding can't handle 'é'")
155 raise SkipTest("sys.stdin.encoding can't handle 'é'")
157
156
158 cl = KeyValueConfigLoader()
157 cl = KeyValueConfigLoader()
159 with mute_warn():
158 with mute_warn():
160 config = cl.load_config([barg])
159 config = cl.load_config([barg])
161 self.assertEquals(config.a, u'é')
160 self.assertEquals(config.a, u'é')
162
161
163
162
164 class TestConfig(TestCase):
163 class TestConfig(TestCase):
165
164
166 def test_setget(self):
165 def test_setget(self):
167 c = Config()
166 c = Config()
168 c.a = 10
167 c.a = 10
169 self.assertEquals(c.a, 10)
168 self.assertEquals(c.a, 10)
170 self.assertEquals(c.has_key('b'), False)
169 self.assertEquals(c.has_key('b'), False)
171
170
172 def test_auto_section(self):
171 def test_auto_section(self):
173 c = Config()
172 c = Config()
174 self.assertEquals(c.has_key('A'), True)
173 self.assertEquals(c.has_key('A'), True)
175 self.assertEquals(c._has_section('A'), False)
174 self.assertEquals(c._has_section('A'), False)
176 A = c.A
175 A = c.A
177 A.foo = 'hi there'
176 A.foo = 'hi there'
178 self.assertEquals(c._has_section('A'), True)
177 self.assertEquals(c._has_section('A'), True)
179 self.assertEquals(c.A.foo, 'hi there')
178 self.assertEquals(c.A.foo, 'hi there')
180 del c.A
179 del c.A
181 self.assertEquals(len(c.A.keys()),0)
180 self.assertEquals(len(c.A.keys()),0)
182
181
183 def test_merge_doesnt_exist(self):
182 def test_merge_doesnt_exist(self):
184 c1 = Config()
183 c1 = Config()
185 c2 = Config()
184 c2 = Config()
186 c2.bar = 10
185 c2.bar = 10
187 c2.Foo.bar = 10
186 c2.Foo.bar = 10
188 c1._merge(c2)
187 c1._merge(c2)
189 self.assertEquals(c1.Foo.bar, 10)
188 self.assertEquals(c1.Foo.bar, 10)
190 self.assertEquals(c1.bar, 10)
189 self.assertEquals(c1.bar, 10)
191 c2.Bar.bar = 10
190 c2.Bar.bar = 10
192 c1._merge(c2)
191 c1._merge(c2)
193 self.assertEquals(c1.Bar.bar, 10)
192 self.assertEquals(c1.Bar.bar, 10)
194
193
195 def test_merge_exists(self):
194 def test_merge_exists(self):
196 c1 = Config()
195 c1 = Config()
197 c2 = Config()
196 c2 = Config()
198 c1.Foo.bar = 10
197 c1.Foo.bar = 10
199 c1.Foo.bam = 30
198 c1.Foo.bam = 30
200 c2.Foo.bar = 20
199 c2.Foo.bar = 20
201 c2.Foo.wow = 40
200 c2.Foo.wow = 40
202 c1._merge(c2)
201 c1._merge(c2)
203 self.assertEquals(c1.Foo.bam, 30)
202 self.assertEquals(c1.Foo.bam, 30)
204 self.assertEquals(c1.Foo.bar, 20)
203 self.assertEquals(c1.Foo.bar, 20)
205 self.assertEquals(c1.Foo.wow, 40)
204 self.assertEquals(c1.Foo.wow, 40)
206 c2.Foo.Bam.bam = 10
205 c2.Foo.Bam.bam = 10
207 c1._merge(c2)
206 c1._merge(c2)
208 self.assertEquals(c1.Foo.Bam.bam, 10)
207 self.assertEquals(c1.Foo.Bam.bam, 10)
209
208
210 def test_deepcopy(self):
209 def test_deepcopy(self):
211 c1 = Config()
210 c1 = Config()
212 c1.Foo.bar = 10
211 c1.Foo.bar = 10
213 c1.Foo.bam = 30
212 c1.Foo.bam = 30
214 c1.a = 'asdf'
213 c1.a = 'asdf'
215 c1.b = range(10)
214 c1.b = range(10)
216 import copy
215 import copy
217 c2 = copy.deepcopy(c1)
216 c2 = copy.deepcopy(c1)
218 self.assertEquals(c1, c2)
217 self.assertEquals(c1, c2)
219 self.assert_(c1 is not c2)
218 self.assert_(c1 is not c2)
220 self.assert_(c1.Foo is not c2.Foo)
219 self.assert_(c1.Foo is not c2.Foo)
221
220
222 def test_builtin(self):
221 def test_builtin(self):
223 c1 = Config()
222 c1 = Config()
224 exec 'foo = True' in c1
223 exec 'foo = True' in c1
225 self.assertEquals(c1.foo, True)
224 self.assertEquals(c1.foo, True)
226 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
225 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
@@ -1,264 +1,263 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 System command aliases.
3 System command aliases.
5
4
6 Authors:
5 Authors:
7
6
8 * Fernando Perez
7 * Fernando Perez
9 * Brian Granger
8 * Brian Granger
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2010 The IPython Development Team
12 # Copyright (C) 2008-2010 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License.
14 # Distributed under the terms of the BSD License.
16 #
15 #
17 # The full license is in the file COPYING.txt, distributed with this software.
16 # The full license is in the file COPYING.txt, distributed with this software.
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19
18
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21 # Imports
20 # Imports
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23
22
24 import __builtin__
23 import __builtin__
25 import keyword
24 import keyword
26 import os
25 import os
27 import re
26 import re
28 import sys
27 import sys
29
28
30 from IPython.config.configurable import Configurable
29 from IPython.config.configurable import Configurable
31 from IPython.core.splitinput import split_user_input
30 from IPython.core.splitinput import split_user_input
32
31
33 from IPython.utils.traitlets import List, Instance
32 from IPython.utils.traitlets import List, Instance
34 from IPython.utils.autoattr import auto_attr
33 from IPython.utils.autoattr import auto_attr
35 from IPython.utils.warn import warn, error
34 from IPython.utils.warn import warn, error
36
35
37 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
38 # Utilities
37 # Utilities
39 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
40
39
41 # This is used as the pattern for calls to split_user_input.
40 # This is used as the pattern for calls to split_user_input.
42 shell_line_split = re.compile(r'^(\s*)(\S*\s*)(.*$)')
41 shell_line_split = re.compile(r'^(\s*)(\S*\s*)(.*$)')
43
42
44 def default_aliases():
43 def default_aliases():
45 """Return list of shell aliases to auto-define.
44 """Return list of shell aliases to auto-define.
46 """
45 """
47 # Note: the aliases defined here should be safe to use on a kernel
46 # Note: the aliases defined here should be safe to use on a kernel
48 # regardless of what frontend it is attached to. Frontends that use a
47 # regardless of what frontend it is attached to. Frontends that use a
49 # kernel in-process can define additional aliases that will only work in
48 # kernel in-process can define additional aliases that will only work in
50 # their case. For example, things like 'less' or 'clear' that manipulate
49 # their case. For example, things like 'less' or 'clear' that manipulate
51 # the terminal should NOT be declared here, as they will only work if the
50 # the terminal should NOT be declared here, as they will only work if the
52 # kernel is running inside a true terminal, and not over the network.
51 # kernel is running inside a true terminal, and not over the network.
53
52
54 if os.name == 'posix':
53 if os.name == 'posix':
55 default_aliases = [('mkdir', 'mkdir'), ('rmdir', 'rmdir'),
54 default_aliases = [('mkdir', 'mkdir'), ('rmdir', 'rmdir'),
56 ('mv', 'mv -i'), ('rm', 'rm -i'), ('cp', 'cp -i'),
55 ('mv', 'mv -i'), ('rm', 'rm -i'), ('cp', 'cp -i'),
57 ('cat', 'cat'),
56 ('cat', 'cat'),
58 ]
57 ]
59 # Useful set of ls aliases. The GNU and BSD options are a little
58 # Useful set of ls aliases. The GNU and BSD options are a little
60 # different, so we make aliases that provide as similar as possible
59 # different, so we make aliases that provide as similar as possible
61 # behavior in ipython, by passing the right flags for each platform
60 # behavior in ipython, by passing the right flags for each platform
62 if sys.platform.startswith('linux'):
61 if sys.platform.startswith('linux'):
63 ls_aliases = [('ls', 'ls -F --color'),
62 ls_aliases = [('ls', 'ls -F --color'),
64 # long ls
63 # long ls
65 ('ll', 'ls -F -o --color'),
64 ('ll', 'ls -F -o --color'),
66 # ls normal files only
65 # ls normal files only
67 ('lf', 'ls -F -o --color %l | grep ^-'),
66 ('lf', 'ls -F -o --color %l | grep ^-'),
68 # ls symbolic links
67 # ls symbolic links
69 ('lk', 'ls -F -o --color %l | grep ^l'),
68 ('lk', 'ls -F -o --color %l | grep ^l'),
70 # directories or links to directories,
69 # directories or links to directories,
71 ('ldir', 'ls -F -o --color %l | grep /$'),
70 ('ldir', 'ls -F -o --color %l | grep /$'),
72 # things which are executable
71 # things which are executable
73 ('lx', 'ls -F -o --color %l | grep ^-..x'),
72 ('lx', 'ls -F -o --color %l | grep ^-..x'),
74 ]
73 ]
75 else:
74 else:
76 # BSD, OSX, etc.
75 # BSD, OSX, etc.
77 ls_aliases = [('ls', 'ls -F'),
76 ls_aliases = [('ls', 'ls -F'),
78 # long ls
77 # long ls
79 ('ll', 'ls -F -l'),
78 ('ll', 'ls -F -l'),
80 # ls normal files only
79 # ls normal files only
81 ('lf', 'ls -F -l %l | grep ^-'),
80 ('lf', 'ls -F -l %l | grep ^-'),
82 # ls symbolic links
81 # ls symbolic links
83 ('lk', 'ls -F -l %l | grep ^l'),
82 ('lk', 'ls -F -l %l | grep ^l'),
84 # directories or links to directories,
83 # directories or links to directories,
85 ('ldir', 'ls -F -l %l | grep /$'),
84 ('ldir', 'ls -F -l %l | grep /$'),
86 # things which are executable
85 # things which are executable
87 ('lx', 'ls -F -l %l | grep ^-..x'),
86 ('lx', 'ls -F -l %l | grep ^-..x'),
88 ]
87 ]
89 default_aliases = default_aliases + ls_aliases
88 default_aliases = default_aliases + ls_aliases
90 elif os.name in ['nt', 'dos']:
89 elif os.name in ['nt', 'dos']:
91 default_aliases = [('ls', 'dir /on'),
90 default_aliases = [('ls', 'dir /on'),
92 ('ddir', 'dir /ad /on'), ('ldir', 'dir /ad /on'),
91 ('ddir', 'dir /ad /on'), ('ldir', 'dir /ad /on'),
93 ('mkdir', 'mkdir'), ('rmdir', 'rmdir'),
92 ('mkdir', 'mkdir'), ('rmdir', 'rmdir'),
94 ('echo', 'echo'), ('ren', 'ren'), ('copy', 'copy'),
93 ('echo', 'echo'), ('ren', 'ren'), ('copy', 'copy'),
95 ]
94 ]
96 else:
95 else:
97 default_aliases = []
96 default_aliases = []
98
97
99 return default_aliases
98 return default_aliases
100
99
101
100
102 class AliasError(Exception):
101 class AliasError(Exception):
103 pass
102 pass
104
103
105
104
106 class InvalidAliasError(AliasError):
105 class InvalidAliasError(AliasError):
107 pass
106 pass
108
107
109 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
110 # Main AliasManager class
109 # Main AliasManager class
111 #-----------------------------------------------------------------------------
110 #-----------------------------------------------------------------------------
112
111
113 class AliasManager(Configurable):
112 class AliasManager(Configurable):
114
113
115 default_aliases = List(default_aliases(), config=True)
114 default_aliases = List(default_aliases(), config=True)
116 user_aliases = List(default_value=[], config=True)
115 user_aliases = List(default_value=[], config=True)
117 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
116 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
118
117
119 def __init__(self, shell=None, config=None):
118 def __init__(self, shell=None, config=None):
120 super(AliasManager, self).__init__(shell=shell, config=config)
119 super(AliasManager, self).__init__(shell=shell, config=config)
121 self.alias_table = {}
120 self.alias_table = {}
122 self.exclude_aliases()
121 self.exclude_aliases()
123 self.init_aliases()
122 self.init_aliases()
124
123
125 def __contains__(self, name):
124 def __contains__(self, name):
126 return name in self.alias_table
125 return name in self.alias_table
127
126
128 @property
127 @property
129 def aliases(self):
128 def aliases(self):
130 return [(item[0], item[1][1]) for item in self.alias_table.iteritems()]
129 return [(item[0], item[1][1]) for item in self.alias_table.iteritems()]
131
130
132 def exclude_aliases(self):
131 def exclude_aliases(self):
133 # set of things NOT to alias (keywords, builtins and some magics)
132 # set of things NOT to alias (keywords, builtins and some magics)
134 no_alias = set(['cd','popd','pushd','dhist','alias','unalias'])
133 no_alias = set(['cd','popd','pushd','dhist','alias','unalias'])
135 no_alias.update(set(keyword.kwlist))
134 no_alias.update(set(keyword.kwlist))
136 no_alias.update(set(__builtin__.__dict__.keys()))
135 no_alias.update(set(__builtin__.__dict__.keys()))
137 self.no_alias = no_alias
136 self.no_alias = no_alias
138
137
139 def init_aliases(self):
138 def init_aliases(self):
140 # Load default aliases
139 # Load default aliases
141 for name, cmd in self.default_aliases:
140 for name, cmd in self.default_aliases:
142 self.soft_define_alias(name, cmd)
141 self.soft_define_alias(name, cmd)
143
142
144 # Load user aliases
143 # Load user aliases
145 for name, cmd in self.user_aliases:
144 for name, cmd in self.user_aliases:
146 self.soft_define_alias(name, cmd)
145 self.soft_define_alias(name, cmd)
147
146
148 def clear_aliases(self):
147 def clear_aliases(self):
149 self.alias_table.clear()
148 self.alias_table.clear()
150
149
151 def soft_define_alias(self, name, cmd):
150 def soft_define_alias(self, name, cmd):
152 """Define an alias, but don't raise on an AliasError."""
151 """Define an alias, but don't raise on an AliasError."""
153 try:
152 try:
154 self.define_alias(name, cmd)
153 self.define_alias(name, cmd)
155 except AliasError, e:
154 except AliasError, e:
156 error("Invalid alias: %s" % e)
155 error("Invalid alias: %s" % e)
157
156
158 def define_alias(self, name, cmd):
157 def define_alias(self, name, cmd):
159 """Define a new alias after validating it.
158 """Define a new alias after validating it.
160
159
161 This will raise an :exc:`AliasError` if there are validation
160 This will raise an :exc:`AliasError` if there are validation
162 problems.
161 problems.
163 """
162 """
164 nargs = self.validate_alias(name, cmd)
163 nargs = self.validate_alias(name, cmd)
165 self.alias_table[name] = (nargs, cmd)
164 self.alias_table[name] = (nargs, cmd)
166
165
167 def undefine_alias(self, name):
166 def undefine_alias(self, name):
168 if self.alias_table.has_key(name):
167 if self.alias_table.has_key(name):
169 del self.alias_table[name]
168 del self.alias_table[name]
170
169
171 def validate_alias(self, name, cmd):
170 def validate_alias(self, name, cmd):
172 """Validate an alias and return the its number of arguments."""
171 """Validate an alias and return the its number of arguments."""
173 if name in self.no_alias:
172 if name in self.no_alias:
174 raise InvalidAliasError("The name %s can't be aliased "
173 raise InvalidAliasError("The name %s can't be aliased "
175 "because it is a keyword or builtin." % name)
174 "because it is a keyword or builtin." % name)
176 if not (isinstance(cmd, basestring)):
175 if not (isinstance(cmd, basestring)):
177 raise InvalidAliasError("An alias command must be a string, "
176 raise InvalidAliasError("An alias command must be a string, "
178 "got: %r" % name)
177 "got: %r" % name)
179 nargs = cmd.count('%s')
178 nargs = cmd.count('%s')
180 if nargs>0 and cmd.find('%l')>=0:
179 if nargs>0 and cmd.find('%l')>=0:
181 raise InvalidAliasError('The %s and %l specifiers are mutually '
180 raise InvalidAliasError('The %s and %l specifiers are mutually '
182 'exclusive in alias definitions.')
181 'exclusive in alias definitions.')
183 return nargs
182 return nargs
184
183
185 def call_alias(self, alias, rest=''):
184 def call_alias(self, alias, rest=''):
186 """Call an alias given its name and the rest of the line."""
185 """Call an alias given its name and the rest of the line."""
187 cmd = self.transform_alias(alias, rest)
186 cmd = self.transform_alias(alias, rest)
188 try:
187 try:
189 self.shell.system(cmd)
188 self.shell.system(cmd)
190 except:
189 except:
191 self.shell.showtraceback()
190 self.shell.showtraceback()
192
191
193 def transform_alias(self, alias,rest=''):
192 def transform_alias(self, alias,rest=''):
194 """Transform alias to system command string."""
193 """Transform alias to system command string."""
195 nargs, cmd = self.alias_table[alias]
194 nargs, cmd = self.alias_table[alias]
196
195
197 if ' ' in cmd and os.path.isfile(cmd):
196 if ' ' in cmd and os.path.isfile(cmd):
198 cmd = '"%s"' % cmd
197 cmd = '"%s"' % cmd
199
198
200 # Expand the %l special to be the user's input line
199 # Expand the %l special to be the user's input line
201 if cmd.find('%l') >= 0:
200 if cmd.find('%l') >= 0:
202 cmd = cmd.replace('%l', rest)
201 cmd = cmd.replace('%l', rest)
203 rest = ''
202 rest = ''
204 if nargs==0:
203 if nargs==0:
205 # Simple, argument-less aliases
204 # Simple, argument-less aliases
206 cmd = '%s %s' % (cmd, rest)
205 cmd = '%s %s' % (cmd, rest)
207 else:
206 else:
208 # Handle aliases with positional arguments
207 # Handle aliases with positional arguments
209 args = rest.split(None, nargs)
208 args = rest.split(None, nargs)
210 if len(args) < nargs:
209 if len(args) < nargs:
211 raise AliasError('Alias <%s> requires %s arguments, %s given.' %
210 raise AliasError('Alias <%s> requires %s arguments, %s given.' %
212 (alias, nargs, len(args)))
211 (alias, nargs, len(args)))
213 cmd = '%s %s' % (cmd % tuple(args[:nargs]),' '.join(args[nargs:]))
212 cmd = '%s %s' % (cmd % tuple(args[:nargs]),' '.join(args[nargs:]))
214 return cmd
213 return cmd
215
214
216 def expand_alias(self, line):
215 def expand_alias(self, line):
217 """ Expand an alias in the command line
216 """ Expand an alias in the command line
218
217
219 Returns the provided command line, possibly with the first word
218 Returns the provided command line, possibly with the first word
220 (command) translated according to alias expansion rules.
219 (command) translated according to alias expansion rules.
221
220
222 [ipython]|16> _ip.expand_aliases("np myfile.txt")
221 [ipython]|16> _ip.expand_aliases("np myfile.txt")
223 <16> 'q:/opt/np/notepad++.exe myfile.txt'
222 <16> 'q:/opt/np/notepad++.exe myfile.txt'
224 """
223 """
225
224
226 pre,fn,rest = split_user_input(line)
225 pre,fn,rest = split_user_input(line)
227 res = pre + self.expand_aliases(fn, rest)
226 res = pre + self.expand_aliases(fn, rest)
228 return res
227 return res
229
228
230 def expand_aliases(self, fn, rest):
229 def expand_aliases(self, fn, rest):
231 """Expand multiple levels of aliases:
230 """Expand multiple levels of aliases:
232
231
233 if:
232 if:
234
233
235 alias foo bar /tmp
234 alias foo bar /tmp
236 alias baz foo
235 alias baz foo
237
236
238 then:
237 then:
239
238
240 baz huhhahhei -> bar /tmp huhhahhei
239 baz huhhahhei -> bar /tmp huhhahhei
241 """
240 """
242 line = fn + " " + rest
241 line = fn + " " + rest
243
242
244 done = set()
243 done = set()
245 while 1:
244 while 1:
246 pre,fn,rest = split_user_input(line, shell_line_split)
245 pre,fn,rest = split_user_input(line, shell_line_split)
247 if fn in self.alias_table:
246 if fn in self.alias_table:
248 if fn in done:
247 if fn in done:
249 warn("Cyclic alias definition, repeated '%s'" % fn)
248 warn("Cyclic alias definition, repeated '%s'" % fn)
250 return ""
249 return ""
251 done.add(fn)
250 done.add(fn)
252
251
253 l2 = self.transform_alias(fn, rest)
252 l2 = self.transform_alias(fn, rest)
254 if l2 == line:
253 if l2 == line:
255 break
254 break
256 # ls -> ls -F should not recurse forever
255 # ls -> ls -F should not recurse forever
257 if l2.split(None,1)[0] == line.split(None,1)[0]:
256 if l2.split(None,1)[0] == line.split(None,1)[0]:
258 line = l2
257 line = l2
259 break
258 break
260 line=l2
259 line=l2
261 else:
260 else:
262 break
261 break
263
262
264 return line
263 return line
@@ -1,71 +1,70 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Autocall capabilities for IPython.core.
3 Autocall capabilities for IPython.core.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10 * Thomas Kluyver
9 * Thomas Kluyver
11
10
12 Notes
11 Notes
13 -----
12 -----
14 """
13 """
15
14
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17 # Copyright (C) 2008-2009 The IPython Development Team
16 # Copyright (C) 2008-2009 The IPython Development Team
18 #
17 #
19 # Distributed under the terms of the BSD License. The full license is in
18 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
19 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
24 # Imports
23 # Imports
25 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
26
25
27
26
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29 # Code
28 # Code
30 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
31
30
32 class IPyAutocall(object):
31 class IPyAutocall(object):
33 """ Instances of this class are always autocalled
32 """ Instances of this class are always autocalled
34
33
35 This happens regardless of 'autocall' variable state. Use this to
34 This happens regardless of 'autocall' variable state. Use this to
36 develop macro-like mechanisms.
35 develop macro-like mechanisms.
37 """
36 """
38 _ip = None
37 _ip = None
39 rewrite = True
38 rewrite = True
40 def __init__(self, ip=None):
39 def __init__(self, ip=None):
41 self._ip = ip
40 self._ip = ip
42
41
43 def set_ip(self, ip):
42 def set_ip(self, ip):
44 """ Will be used to set _ip point to current ipython instance b/f call
43 """ Will be used to set _ip point to current ipython instance b/f call
45
44
46 Override this method if you don't want this to happen.
45 Override this method if you don't want this to happen.
47
46
48 """
47 """
49 self._ip = ip
48 self._ip = ip
50
49
51
50
52 class ExitAutocall(IPyAutocall):
51 class ExitAutocall(IPyAutocall):
53 """An autocallable object which will be added to the user namespace so that
52 """An autocallable object which will be added to the user namespace so that
54 exit, exit(), quit or quit() are all valid ways to close the shell."""
53 exit, exit(), quit or quit() are all valid ways to close the shell."""
55 rewrite = False
54 rewrite = False
56
55
57 def __call__(self):
56 def __call__(self):
58 self._ip.ask_exit()
57 self._ip.ask_exit()
59
58
60 class ZMQExitAutocall(ExitAutocall):
59 class ZMQExitAutocall(ExitAutocall):
61 """Exit IPython. Autocallable, so it needn't be explicitly called.
60 """Exit IPython. Autocallable, so it needn't be explicitly called.
62
61
63 Parameters
62 Parameters
64 ----------
63 ----------
65 keep_kernel : bool
64 keep_kernel : bool
66 If True, leave the kernel alive. Otherwise, tell the kernel to exit too
65 If True, leave the kernel alive. Otherwise, tell the kernel to exit too
67 (default).
66 (default).
68 """
67 """
69 def __call__(self, keep_kernel=False):
68 def __call__(self, keep_kernel=False):
70 self._ip.keepkernel_on_exit = keep_kernel
69 self._ip.keepkernel_on_exit = keep_kernel
71 self._ip.ask_exit()
70 self._ip.ask_exit()
@@ -1,71 +1,70 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 A context manager for handling sys.displayhook.
3 A context manager for handling sys.displayhook.
5
4
6 Authors:
5 Authors:
7
6
8 * Robert Kern
7 * Robert Kern
9 * Brian Granger
8 * Brian Granger
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2009 The IPython Development Team
12 # Copyright (C) 2008-2009 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 import sys
22 import sys
24
23
25 from IPython.config.configurable import Configurable
24 from IPython.config.configurable import Configurable
26 from IPython.utils.traitlets import Any
25 from IPython.utils.traitlets import Any
27
26
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29 # Classes and functions
28 # Classes and functions
30 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
31
30
32
31
33 class DisplayTrap(Configurable):
32 class DisplayTrap(Configurable):
34 """Object to manage sys.displayhook.
33 """Object to manage sys.displayhook.
35
34
36 This came from IPython.core.kernel.display_hook, but is simplified
35 This came from IPython.core.kernel.display_hook, but is simplified
37 (no callbacks or formatters) until more of the core is refactored.
36 (no callbacks or formatters) until more of the core is refactored.
38 """
37 """
39
38
40 hook = Any
39 hook = Any
41
40
42 def __init__(self, hook=None):
41 def __init__(self, hook=None):
43 super(DisplayTrap, self).__init__(hook=hook, config=None)
42 super(DisplayTrap, self).__init__(hook=hook, config=None)
44 self.old_hook = None
43 self.old_hook = None
45 # We define this to track if a single BuiltinTrap is nested.
44 # We define this to track if a single BuiltinTrap is nested.
46 # Only turn off the trap when the outermost call to __exit__ is made.
45 # Only turn off the trap when the outermost call to __exit__ is made.
47 self._nested_level = 0
46 self._nested_level = 0
48
47
49 def __enter__(self):
48 def __enter__(self):
50 if self._nested_level == 0:
49 if self._nested_level == 0:
51 self.set()
50 self.set()
52 self._nested_level += 1
51 self._nested_level += 1
53 return self
52 return self
54
53
55 def __exit__(self, type, value, traceback):
54 def __exit__(self, type, value, traceback):
56 if self._nested_level == 1:
55 if self._nested_level == 1:
57 self.unset()
56 self.unset()
58 self._nested_level -= 1
57 self._nested_level -= 1
59 # Returning False will cause exceptions to propagate
58 # Returning False will cause exceptions to propagate
60 return False
59 return False
61
60
62 def set(self):
61 def set(self):
63 """Set the hook."""
62 """Set the hook."""
64 if sys.displayhook is not self.hook:
63 if sys.displayhook is not self.hook:
65 self.old_hook = sys.displayhook
64 self.old_hook = sys.displayhook
66 sys.displayhook = self.hook
65 sys.displayhook = self.hook
67
66
68 def unset(self):
67 def unset(self):
69 """Unset the hook."""
68 """Unset the hook."""
70 sys.displayhook = self.old_hook
69 sys.displayhook = self.old_hook
71
70
@@ -1,52 +1,51 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Global exception classes for IPython.core.
3 Global exception classes for IPython.core.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10
9
11 Notes
10 Notes
12 -----
11 -----
13 """
12 """
14
13
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16 # Copyright (C) 2008-2009 The IPython Development Team
15 # Copyright (C) 2008-2009 The IPython Development Team
17 #
16 #
18 # Distributed under the terms of the BSD License. The full license is in
17 # Distributed under the terms of the BSD License. The full license is in
19 # the file COPYING, distributed as part of this software.
18 # the file COPYING, distributed as part of this software.
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21
20
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23 # Imports
22 # Imports
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25
24
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27 # Exception classes
26 # Exception classes
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29
28
30 class IPythonCoreError(Exception):
29 class IPythonCoreError(Exception):
31 pass
30 pass
32
31
33
32
34 class TryNext(IPythonCoreError):
33 class TryNext(IPythonCoreError):
35 """Try next hook exception.
34 """Try next hook exception.
36
35
37 Raise this in your hook function to indicate that the next hook handler
36 Raise this in your hook function to indicate that the next hook handler
38 should be used to handle the operation. If you pass arguments to the
37 should be used to handle the operation. If you pass arguments to the
39 constructor those arguments will be used by the next hook instead of the
38 constructor those arguments will be used by the next hook instead of the
40 original ones.
39 original ones.
41 """
40 """
42
41
43 def __init__(self, *args, **kwargs):
42 def __init__(self, *args, **kwargs):
44 self.args = args
43 self.args = args
45 self.kwargs = kwargs
44 self.kwargs = kwargs
46
45
47 class UsageError(IPythonCoreError):
46 class UsageError(IPythonCoreError):
48 """Error in magic function arguments, etc.
47 """Error in magic function arguments, etc.
49
48
50 Something that probably won't warrant a full traceback, but should
49 Something that probably won't warrant a full traceback, but should
51 nevertheless interrupt a macro / batch file.
50 nevertheless interrupt a macro / batch file.
52 """ No newline at end of file
51 """
@@ -1,30 +1,29 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 This module is *completely* deprecated and should no longer be used for
3 This module is *completely* deprecated and should no longer be used for
5 any purpose. Currently, we have a few parts of the core that have
4 any purpose. Currently, we have a few parts of the core that have
6 not been componentized and thus, still rely on this module. When everything
5 not been componentized and thus, still rely on this module. When everything
7 has been made into a component, this module will be sent to deathrow.
6 has been made into a component, this module will be sent to deathrow.
8 """
7 """
9
8
10 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
11 # Copyright (C) 2008-2009 The IPython Development Team
10 # Copyright (C) 2008-2009 The IPython Development Team
12 #
11 #
13 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Imports
17 # Imports
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20
19
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22 # Classes and functions
21 # Classes and functions
23 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
24
23
25
24
26 def get():
25 def get():
27 """Get the global InteractiveShell instance."""
26 """Get the global InteractiveShell instance."""
28 from IPython.core.interactiveshell import InteractiveShell
27 from IPython.core.interactiveshell import InteractiveShell
29 return InteractiveShell.instance()
28 return InteractiveShell.instance()
30
29
@@ -1,327 +1,326 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Paging capabilities for IPython.core
3 Paging capabilities for IPython.core
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10
9
11 Notes
10 Notes
12 -----
11 -----
13
12
14 For now this uses ipapi, so it can't be in IPython.utils. If we can get
13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
15 rid of that dependency, we could move it there.
14 rid of that dependency, we could move it there.
16 -----
15 -----
17 """
16 """
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Copyright (C) 2008-2009 The IPython Development Team
19 # Copyright (C) 2008-2009 The IPython Development Team
21 #
20 #
22 # Distributed under the terms of the BSD License. The full license is in
21 # Distributed under the terms of the BSD License. The full license is in
23 # the file COPYING, distributed as part of this software.
22 # the file COPYING, distributed as part of this software.
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25
24
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27 # Imports
26 # Imports
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29
28
30 import os
29 import os
31 import re
30 import re
32 import sys
31 import sys
33 import tempfile
32 import tempfile
34
33
35 from IPython.core import ipapi
34 from IPython.core import ipapi
36 from IPython.core.error import TryNext
35 from IPython.core.error import TryNext
37 from IPython.utils.cursesimport import use_curses
36 from IPython.utils.cursesimport import use_curses
38 from IPython.utils.data import chop
37 from IPython.utils.data import chop
39 from IPython.utils import io
38 from IPython.utils import io
40 from IPython.utils.process import system
39 from IPython.utils.process import system
41 from IPython.utils.terminal import get_terminal_size
40 from IPython.utils.terminal import get_terminal_size
42
41
43
42
44 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
45 # Classes and functions
44 # Classes and functions
46 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
47
46
48 esc_re = re.compile(r"(\x1b[^m]+m)")
47 esc_re = re.compile(r"(\x1b[^m]+m)")
49
48
50 def page_dumb(strng, start=0, screen_lines=25):
49 def page_dumb(strng, start=0, screen_lines=25):
51 """Very dumb 'pager' in Python, for when nothing else works.
50 """Very dumb 'pager' in Python, for when nothing else works.
52
51
53 Only moves forward, same interface as page(), except for pager_cmd and
52 Only moves forward, same interface as page(), except for pager_cmd and
54 mode."""
53 mode."""
55
54
56 out_ln = strng.splitlines()[start:]
55 out_ln = strng.splitlines()[start:]
57 screens = chop(out_ln,screen_lines-1)
56 screens = chop(out_ln,screen_lines-1)
58 if len(screens) == 1:
57 if len(screens) == 1:
59 print >>io.stdout, os.linesep.join(screens[0])
58 print >>io.stdout, os.linesep.join(screens[0])
60 else:
59 else:
61 last_escape = ""
60 last_escape = ""
62 for scr in screens[0:-1]:
61 for scr in screens[0:-1]:
63 hunk = os.linesep.join(scr)
62 hunk = os.linesep.join(scr)
64 print >>io.stdout, last_escape + hunk
63 print >>io.stdout, last_escape + hunk
65 if not page_more():
64 if not page_more():
66 return
65 return
67 esc_list = esc_re.findall(hunk)
66 esc_list = esc_re.findall(hunk)
68 if len(esc_list) > 0:
67 if len(esc_list) > 0:
69 last_escape = esc_list[-1]
68 last_escape = esc_list[-1]
70 print >>io.stdout, last_escape + os.linesep.join(screens[-1])
69 print >>io.stdout, last_escape + os.linesep.join(screens[-1])
71
70
72
71
73 def page(strng, start=0, screen_lines=0, pager_cmd=None):
72 def page(strng, start=0, screen_lines=0, pager_cmd=None):
74 """Print a string, piping through a pager after a certain length.
73 """Print a string, piping through a pager after a certain length.
75
74
76 The screen_lines parameter specifies the number of *usable* lines of your
75 The screen_lines parameter specifies the number of *usable* lines of your
77 terminal screen (total lines minus lines you need to reserve to show other
76 terminal screen (total lines minus lines you need to reserve to show other
78 information).
77 information).
79
78
80 If you set screen_lines to a number <=0, page() will try to auto-determine
79 If you set screen_lines to a number <=0, page() will try to auto-determine
81 your screen size and will only use up to (screen_size+screen_lines) for
80 your screen size and will only use up to (screen_size+screen_lines) for
82 printing, paging after that. That is, if you want auto-detection but need
81 printing, paging after that. That is, if you want auto-detection but need
83 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
82 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
84 auto-detection without any lines reserved simply use screen_lines = 0.
83 auto-detection without any lines reserved simply use screen_lines = 0.
85
84
86 If a string won't fit in the allowed lines, it is sent through the
85 If a string won't fit in the allowed lines, it is sent through the
87 specified pager command. If none given, look for PAGER in the environment,
86 specified pager command. If none given, look for PAGER in the environment,
88 and ultimately default to less.
87 and ultimately default to less.
89
88
90 If no system pager works, the string is sent through a 'dumb pager'
89 If no system pager works, the string is sent through a 'dumb pager'
91 written in python, very simplistic.
90 written in python, very simplistic.
92 """
91 """
93
92
94 # Some routines may auto-compute start offsets incorrectly and pass a
93 # Some routines may auto-compute start offsets incorrectly and pass a
95 # negative value. Offset to 0 for robustness.
94 # negative value. Offset to 0 for robustness.
96 start = max(0, start)
95 start = max(0, start)
97
96
98 # first, try the hook
97 # first, try the hook
99 ip = ipapi.get()
98 ip = ipapi.get()
100 if ip:
99 if ip:
101 try:
100 try:
102 ip.hooks.show_in_pager(strng)
101 ip.hooks.show_in_pager(strng)
103 return
102 return
104 except TryNext:
103 except TryNext:
105 pass
104 pass
106
105
107 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
106 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
108 TERM = os.environ.get('TERM','dumb')
107 TERM = os.environ.get('TERM','dumb')
109 if TERM in ['dumb','emacs'] and os.name != 'nt':
108 if TERM in ['dumb','emacs'] and os.name != 'nt':
110 print strng
109 print strng
111 return
110 return
112 # chop off the topmost part of the string we don't want to see
111 # chop off the topmost part of the string we don't want to see
113 str_lines = strng.splitlines()[start:]
112 str_lines = strng.splitlines()[start:]
114 str_toprint = os.linesep.join(str_lines)
113 str_toprint = os.linesep.join(str_lines)
115 num_newlines = len(str_lines)
114 num_newlines = len(str_lines)
116 len_str = len(str_toprint)
115 len_str = len(str_toprint)
117
116
118 # Dumb heuristics to guesstimate number of on-screen lines the string
117 # Dumb heuristics to guesstimate number of on-screen lines the string
119 # takes. Very basic, but good enough for docstrings in reasonable
118 # takes. Very basic, but good enough for docstrings in reasonable
120 # terminals. If someone later feels like refining it, it's not hard.
119 # terminals. If someone later feels like refining it, it's not hard.
121 numlines = max(num_newlines,int(len_str/80)+1)
120 numlines = max(num_newlines,int(len_str/80)+1)
122
121
123 screen_lines_def = get_terminal_size()[1]
122 screen_lines_def = get_terminal_size()[1]
124
123
125 # auto-determine screen size
124 # auto-determine screen size
126 if screen_lines <= 0:
125 if screen_lines <= 0:
127 if (TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5':
126 if (TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5':
128 local_use_curses = use_curses
127 local_use_curses = use_curses
129 else:
128 else:
130 # curses causes problems on many terminals other than xterm, and
129 # curses causes problems on many terminals other than xterm, and
131 # some termios calls lock up on Sun OS5.
130 # some termios calls lock up on Sun OS5.
132 local_use_curses = False
131 local_use_curses = False
133 if local_use_curses:
132 if local_use_curses:
134 import termios
133 import termios
135 import curses
134 import curses
136 # There is a bug in curses, where *sometimes* it fails to properly
135 # There is a bug in curses, where *sometimes* it fails to properly
137 # initialize, and then after the endwin() call is made, the
136 # initialize, and then after the endwin() call is made, the
138 # terminal is left in an unusable state. Rather than trying to
137 # terminal is left in an unusable state. Rather than trying to
139 # check everytime for this (by requesting and comparing termios
138 # check everytime for this (by requesting and comparing termios
140 # flags each time), we just save the initial terminal state and
139 # flags each time), we just save the initial terminal state and
141 # unconditionally reset it every time. It's cheaper than making
140 # unconditionally reset it every time. It's cheaper than making
142 # the checks.
141 # the checks.
143 term_flags = termios.tcgetattr(sys.stdout)
142 term_flags = termios.tcgetattr(sys.stdout)
144
143
145 # Curses modifies the stdout buffer size by default, which messes
144 # Curses modifies the stdout buffer size by default, which messes
146 # up Python's normal stdout buffering. This would manifest itself
145 # up Python's normal stdout buffering. This would manifest itself
147 # to IPython users as delayed printing on stdout after having used
146 # to IPython users as delayed printing on stdout after having used
148 # the pager.
147 # the pager.
149 #
148 #
150 # We can prevent this by manually setting the NCURSES_NO_SETBUF
149 # We can prevent this by manually setting the NCURSES_NO_SETBUF
151 # environment variable. For more details, see:
150 # environment variable. For more details, see:
152 # http://bugs.python.org/issue10144
151 # http://bugs.python.org/issue10144
153 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
152 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
154 os.environ['NCURSES_NO_SETBUF'] = ''
153 os.environ['NCURSES_NO_SETBUF'] = ''
155
154
156 # Proceed with curses initialization
155 # Proceed with curses initialization
157 scr = curses.initscr()
156 scr = curses.initscr()
158 screen_lines_real,screen_cols = scr.getmaxyx()
157 screen_lines_real,screen_cols = scr.getmaxyx()
159 curses.endwin()
158 curses.endwin()
160
159
161 # Restore environment
160 # Restore environment
162 if NCURSES_NO_SETBUF is None:
161 if NCURSES_NO_SETBUF is None:
163 del os.environ['NCURSES_NO_SETBUF']
162 del os.environ['NCURSES_NO_SETBUF']
164 else:
163 else:
165 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
164 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
166
165
167 # Restore terminal state in case endwin() didn't.
166 # Restore terminal state in case endwin() didn't.
168 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
167 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
169 # Now we have what we needed: the screen size in rows/columns
168 # Now we have what we needed: the screen size in rows/columns
170 screen_lines += screen_lines_real
169 screen_lines += screen_lines_real
171 #print '***Screen size:',screen_lines_real,'lines x',\
170 #print '***Screen size:',screen_lines_real,'lines x',\
172 #screen_cols,'columns.' # dbg
171 #screen_cols,'columns.' # dbg
173 else:
172 else:
174 screen_lines += screen_lines_def
173 screen_lines += screen_lines_def
175
174
176 #print 'numlines',numlines,'screenlines',screen_lines # dbg
175 #print 'numlines',numlines,'screenlines',screen_lines # dbg
177 if numlines <= screen_lines :
176 if numlines <= screen_lines :
178 #print '*** normal print' # dbg
177 #print '*** normal print' # dbg
179 print >>io.stdout, str_toprint
178 print >>io.stdout, str_toprint
180 else:
179 else:
181 # Try to open pager and default to internal one if that fails.
180 # Try to open pager and default to internal one if that fails.
182 # All failure modes are tagged as 'retval=1', to match the return
181 # All failure modes are tagged as 'retval=1', to match the return
183 # value of a failed system command. If any intermediate attempt
182 # value of a failed system command. If any intermediate attempt
184 # sets retval to 1, at the end we resort to our own page_dumb() pager.
183 # sets retval to 1, at the end we resort to our own page_dumb() pager.
185 pager_cmd = get_pager_cmd(pager_cmd)
184 pager_cmd = get_pager_cmd(pager_cmd)
186 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
185 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
187 if os.name == 'nt':
186 if os.name == 'nt':
188 if pager_cmd.startswith('type'):
187 if pager_cmd.startswith('type'):
189 # The default WinXP 'type' command is failing on complex strings.
188 # The default WinXP 'type' command is failing on complex strings.
190 retval = 1
189 retval = 1
191 else:
190 else:
192 tmpname = tempfile.mktemp('.txt')
191 tmpname = tempfile.mktemp('.txt')
193 tmpfile = file(tmpname,'wt')
192 tmpfile = file(tmpname,'wt')
194 tmpfile.write(strng)
193 tmpfile.write(strng)
195 tmpfile.close()
194 tmpfile.close()
196 cmd = "%s < %s" % (pager_cmd,tmpname)
195 cmd = "%s < %s" % (pager_cmd,tmpname)
197 if os.system(cmd):
196 if os.system(cmd):
198 retval = 1
197 retval = 1
199 else:
198 else:
200 retval = None
199 retval = None
201 os.remove(tmpname)
200 os.remove(tmpname)
202 else:
201 else:
203 try:
202 try:
204 retval = None
203 retval = None
205 # if I use popen4, things hang. No idea why.
204 # if I use popen4, things hang. No idea why.
206 #pager,shell_out = os.popen4(pager_cmd)
205 #pager,shell_out = os.popen4(pager_cmd)
207 pager = os.popen(pager_cmd,'w')
206 pager = os.popen(pager_cmd,'w')
208 pager.write(strng)
207 pager.write(strng)
209 pager.close()
208 pager.close()
210 retval = pager.close() # success returns None
209 retval = pager.close() # success returns None
211 except IOError,msg: # broken pipe when user quits
210 except IOError,msg: # broken pipe when user quits
212 if msg.args == (32,'Broken pipe'):
211 if msg.args == (32,'Broken pipe'):
213 retval = None
212 retval = None
214 else:
213 else:
215 retval = 1
214 retval = 1
216 except OSError:
215 except OSError:
217 # Other strange problems, sometimes seen in Win2k/cygwin
216 # Other strange problems, sometimes seen in Win2k/cygwin
218 retval = 1
217 retval = 1
219 if retval is not None:
218 if retval is not None:
220 page_dumb(strng,screen_lines=screen_lines)
219 page_dumb(strng,screen_lines=screen_lines)
221
220
222
221
223 def page_file(fname, start=0, pager_cmd=None):
222 def page_file(fname, start=0, pager_cmd=None):
224 """Page a file, using an optional pager command and starting line.
223 """Page a file, using an optional pager command and starting line.
225 """
224 """
226
225
227 pager_cmd = get_pager_cmd(pager_cmd)
226 pager_cmd = get_pager_cmd(pager_cmd)
228 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
227 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
229
228
230 try:
229 try:
231 if os.environ['TERM'] in ['emacs','dumb']:
230 if os.environ['TERM'] in ['emacs','dumb']:
232 raise EnvironmentError
231 raise EnvironmentError
233 system(pager_cmd + ' ' + fname)
232 system(pager_cmd + ' ' + fname)
234 except:
233 except:
235 try:
234 try:
236 if start > 0:
235 if start > 0:
237 start -= 1
236 start -= 1
238 page(open(fname).read(),start)
237 page(open(fname).read(),start)
239 except:
238 except:
240 print 'Unable to show file',`fname`
239 print 'Unable to show file',`fname`
241
240
242
241
243 def get_pager_cmd(pager_cmd=None):
242 def get_pager_cmd(pager_cmd=None):
244 """Return a pager command.
243 """Return a pager command.
245
244
246 Makes some attempts at finding an OS-correct one.
245 Makes some attempts at finding an OS-correct one.
247 """
246 """
248 if os.name == 'posix':
247 if os.name == 'posix':
249 default_pager_cmd = 'less -r' # -r for color control sequences
248 default_pager_cmd = 'less -r' # -r for color control sequences
250 elif os.name in ['nt','dos']:
249 elif os.name in ['nt','dos']:
251 default_pager_cmd = 'type'
250 default_pager_cmd = 'type'
252
251
253 if pager_cmd is None:
252 if pager_cmd is None:
254 try:
253 try:
255 pager_cmd = os.environ['PAGER']
254 pager_cmd = os.environ['PAGER']
256 except:
255 except:
257 pager_cmd = default_pager_cmd
256 pager_cmd = default_pager_cmd
258 return pager_cmd
257 return pager_cmd
259
258
260
259
261 def get_pager_start(pager, start):
260 def get_pager_start(pager, start):
262 """Return the string for paging files with an offset.
261 """Return the string for paging files with an offset.
263
262
264 This is the '+N' argument which less and more (under Unix) accept.
263 This is the '+N' argument which less and more (under Unix) accept.
265 """
264 """
266
265
267 if pager in ['less','more']:
266 if pager in ['less','more']:
268 if start:
267 if start:
269 start_string = '+' + str(start)
268 start_string = '+' + str(start)
270 else:
269 else:
271 start_string = ''
270 start_string = ''
272 else:
271 else:
273 start_string = ''
272 start_string = ''
274 return start_string
273 return start_string
275
274
276
275
277 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
276 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
278 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
277 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
279 import msvcrt
278 import msvcrt
280 def page_more():
279 def page_more():
281 """ Smart pausing between pages
280 """ Smart pausing between pages
282
281
283 @return: True if need print more lines, False if quit
282 @return: True if need print more lines, False if quit
284 """
283 """
285 io.stdout.write('---Return to continue, q to quit--- ')
284 io.stdout.write('---Return to continue, q to quit--- ')
286 ans = msvcrt.getch()
285 ans = msvcrt.getch()
287 if ans in ("q", "Q"):
286 if ans in ("q", "Q"):
288 result = False
287 result = False
289 else:
288 else:
290 result = True
289 result = True
291 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
290 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
292 return result
291 return result
293 else:
292 else:
294 def page_more():
293 def page_more():
295 ans = raw_input('---Return to continue, q to quit--- ')
294 ans = raw_input('---Return to continue, q to quit--- ')
296 if ans.lower().startswith('q'):
295 if ans.lower().startswith('q'):
297 return False
296 return False
298 else:
297 else:
299 return True
298 return True
300
299
301
300
302 def snip_print(str,width = 75,print_full = 0,header = ''):
301 def snip_print(str,width = 75,print_full = 0,header = ''):
303 """Print a string snipping the midsection to fit in width.
302 """Print a string snipping the midsection to fit in width.
304
303
305 print_full: mode control:
304 print_full: mode control:
306 - 0: only snip long strings
305 - 0: only snip long strings
307 - 1: send to page() directly.
306 - 1: send to page() directly.
308 - 2: snip long strings and ask for full length viewing with page()
307 - 2: snip long strings and ask for full length viewing with page()
309 Return 1 if snipping was necessary, 0 otherwise."""
308 Return 1 if snipping was necessary, 0 otherwise."""
310
309
311 if print_full == 1:
310 if print_full == 1:
312 page(header+str)
311 page(header+str)
313 return 0
312 return 0
314
313
315 print header,
314 print header,
316 if len(str) < width:
315 if len(str) < width:
317 print str
316 print str
318 snip = 0
317 snip = 0
319 else:
318 else:
320 whalf = int((width -5)/2)
319 whalf = int((width -5)/2)
321 print str[:whalf] + ' <...> ' + str[-whalf:]
320 print str[:whalf] + ' <...> ' + str[-whalf:]
322 snip = 1
321 snip = 1
323 if snip and print_full == 2:
322 if snip and print_full == 2:
324 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
323 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
325 page(str)
324 page(str)
326 return snip
325 return snip
327
326
@@ -1,97 +1,96 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 A payload based version of page.
3 A payload based version of page.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2010 The IPython Development Team
12 # Copyright (C) 2008-2010 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 # Third-party
22 # Third-party
24 try:
23 try:
25 from docutils.core import publish_string
24 from docutils.core import publish_string
26 except ImportError:
25 except ImportError:
27 # html paging won't be available, but we don't raise any errors. It's a
26 # html paging won't be available, but we don't raise any errors. It's a
28 # purely optional feature.
27 # purely optional feature.
29 pass
28 pass
30
29
31 # Our own
30 # Our own
32 from IPython.core.interactiveshell import InteractiveShell
31 from IPython.core.interactiveshell import InteractiveShell
33
32
34 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
35 # Classes and functions
34 # Classes and functions
36 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
37
36
38 def page(strng, start=0, screen_lines=0, pager_cmd=None,
37 def page(strng, start=0, screen_lines=0, pager_cmd=None,
39 html=None, auto_html=False):
38 html=None, auto_html=False):
40 """Print a string, piping through a pager.
39 """Print a string, piping through a pager.
41
40
42 This version ignores the screen_lines and pager_cmd arguments and uses
41 This version ignores the screen_lines and pager_cmd arguments and uses
43 IPython's payload system instead.
42 IPython's payload system instead.
44
43
45 Parameters
44 Parameters
46 ----------
45 ----------
47 strng : str
46 strng : str
48 Text to page.
47 Text to page.
49
48
50 start : int
49 start : int
51 Starting line at which to place the display.
50 Starting line at which to place the display.
52
51
53 html : str, optional
52 html : str, optional
54 If given, an html string to send as well.
53 If given, an html string to send as well.
55
54
56 auto_html : bool, optional
55 auto_html : bool, optional
57 If true, the input string is assumed to be valid reStructuredText and is
56 If true, the input string is assumed to be valid reStructuredText and is
58 converted to HTML with docutils. Note that if docutils is not found,
57 converted to HTML with docutils. Note that if docutils is not found,
59 this option is silently ignored.
58 this option is silently ignored.
60
59
61 Note
60 Note
62 ----
61 ----
63
62
64 Only one of the ``html`` and ``auto_html`` options can be given, not
63 Only one of the ``html`` and ``auto_html`` options can be given, not
65 both.
64 both.
66 """
65 """
67
66
68 # Some routines may auto-compute start offsets incorrectly and pass a
67 # Some routines may auto-compute start offsets incorrectly and pass a
69 # negative value. Offset to 0 for robustness.
68 # negative value. Offset to 0 for robustness.
70 start = max(0, start)
69 start = max(0, start)
71 shell = InteractiveShell.instance()
70 shell = InteractiveShell.instance()
72
71
73 if auto_html:
72 if auto_html:
74 try:
73 try:
75 # These defaults ensure user configuration variables for docutils
74 # These defaults ensure user configuration variables for docutils
76 # are not loaded, only our config is used here.
75 # are not loaded, only our config is used here.
77 defaults = {'file_insertion_enabled': 0,
76 defaults = {'file_insertion_enabled': 0,
78 'raw_enabled': 0,
77 'raw_enabled': 0,
79 '_disable_config': 1}
78 '_disable_config': 1}
80 html = publish_string(strng, writer_name='html',
79 html = publish_string(strng, writer_name='html',
81 settings_overrides=defaults)
80 settings_overrides=defaults)
82 except:
81 except:
83 pass
82 pass
84
83
85 payload = dict(
84 payload = dict(
86 source='IPython.zmq.page.page',
85 source='IPython.zmq.page.page',
87 text=strng,
86 text=strng,
88 html=html,
87 html=html,
89 start_line_number=start
88 start_line_number=start
90 )
89 )
91 shell.payload_manager.write_payload(payload)
90 shell.payload_manager.write_payload(payload)
92
91
93
92
94 def install_payload_page():
93 def install_payload_page():
95 """Install this version of page as IPython.core.page.page."""
94 """Install this version of page as IPython.core.page.page."""
96 from IPython.core import page as corepage
95 from IPython.core import page as corepage
97 corepage.page = page
96 corepage.page = page
@@ -1,1013 +1,1012 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Prefiltering components.
3 Prefiltering components.
5
4
6 Prefilters transform user input before it is exec'd by Python. These
5 Prefilters transform user input before it is exec'd by Python. These
7 transforms are used to implement additional syntax such as !ls and %magic.
6 transforms are used to implement additional syntax such as !ls and %magic.
8
7
9 Authors:
8 Authors:
10
9
11 * Brian Granger
10 * Brian Granger
12 * Fernando Perez
11 * Fernando Perez
13 * Dan Milstein
12 * Dan Milstein
14 * Ville Vainio
13 * Ville Vainio
15 """
14 """
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Copyright (C) 2008-2009 The IPython Development Team
17 # Copyright (C) 2008-2009 The IPython Development Team
19 #
18 #
20 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
21 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23
22
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25 # Imports
24 # Imports
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27
26
28 import __builtin__
27 import __builtin__
29 import codeop
28 import codeop
30 import re
29 import re
31
30
32 from IPython.core.alias import AliasManager
31 from IPython.core.alias import AliasManager
33 from IPython.core.autocall import IPyAutocall
32 from IPython.core.autocall import IPyAutocall
34 from IPython.config.configurable import Configurable
33 from IPython.config.configurable import Configurable
35 from IPython.core.macro import Macro
34 from IPython.core.macro import Macro
36 from IPython.core.splitinput import split_user_input
35 from IPython.core.splitinput import split_user_input
37 from IPython.core import page
36 from IPython.core import page
38
37
39 from IPython.utils.traitlets import List, Int, Any, Unicode, CBool, Bool, Instance
38 from IPython.utils.traitlets import List, Int, Any, Unicode, CBool, Bool, Instance
40 from IPython.utils.text import make_quoted_expr
39 from IPython.utils.text import make_quoted_expr
41 from IPython.utils.autoattr import auto_attr
40 from IPython.utils.autoattr import auto_attr
42
41
43 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
44 # Global utilities, errors and constants
43 # Global utilities, errors and constants
45 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
46
45
47 # Warning, these cannot be changed unless various regular expressions
46 # Warning, these cannot be changed unless various regular expressions
48 # are updated in a number of places. Not great, but at least we told you.
47 # are updated in a number of places. Not great, but at least we told you.
49 ESC_SHELL = '!'
48 ESC_SHELL = '!'
50 ESC_SH_CAP = '!!'
49 ESC_SH_CAP = '!!'
51 ESC_HELP = '?'
50 ESC_HELP = '?'
52 ESC_MAGIC = '%'
51 ESC_MAGIC = '%'
53 ESC_QUOTE = ','
52 ESC_QUOTE = ','
54 ESC_QUOTE2 = ';'
53 ESC_QUOTE2 = ';'
55 ESC_PAREN = '/'
54 ESC_PAREN = '/'
56
55
57
56
58 class PrefilterError(Exception):
57 class PrefilterError(Exception):
59 pass
58 pass
60
59
61
60
62 # RegExp to identify potential function names
61 # RegExp to identify potential function names
63 re_fun_name = re.compile(r'[a-zA-Z_]([a-zA-Z0-9_.]*) *$')
62 re_fun_name = re.compile(r'[a-zA-Z_]([a-zA-Z0-9_.]*) *$')
64
63
65 # RegExp to exclude strings with this start from autocalling. In
64 # RegExp to exclude strings with this start from autocalling. In
66 # particular, all binary operators should be excluded, so that if foo is
65 # particular, all binary operators should be excluded, so that if foo is
67 # callable, foo OP bar doesn't become foo(OP bar), which is invalid. The
66 # callable, foo OP bar doesn't become foo(OP bar), which is invalid. The
68 # characters '!=()' don't need to be checked for, as the checkPythonChars
67 # characters '!=()' don't need to be checked for, as the checkPythonChars
69 # routine explicitely does so, to catch direct calls and rebindings of
68 # routine explicitely does so, to catch direct calls and rebindings of
70 # existing names.
69 # existing names.
71
70
72 # Warning: the '-' HAS TO BE AT THE END of the first group, otherwise
71 # Warning: the '-' HAS TO BE AT THE END of the first group, otherwise
73 # it affects the rest of the group in square brackets.
72 # it affects the rest of the group in square brackets.
74 re_exclude_auto = re.compile(r'^[,&^\|\*/\+-]'
73 re_exclude_auto = re.compile(r'^[,&^\|\*/\+-]'
75 r'|^is |^not |^in |^and |^or ')
74 r'|^is |^not |^in |^and |^or ')
76
75
77 # try to catch also methods for stuff in lists/tuples/dicts: off
76 # try to catch also methods for stuff in lists/tuples/dicts: off
78 # (experimental). For this to work, the line_split regexp would need
77 # (experimental). For this to work, the line_split regexp would need
79 # to be modified so it wouldn't break things at '['. That line is
78 # to be modified so it wouldn't break things at '['. That line is
80 # nasty enough that I shouldn't change it until I can test it _well_.
79 # nasty enough that I shouldn't change it until I can test it _well_.
81 #self.re_fun_name = re.compile (r'[a-zA-Z_]([a-zA-Z0-9_.\[\]]*) ?$')
80 #self.re_fun_name = re.compile (r'[a-zA-Z_]([a-zA-Z0-9_.\[\]]*) ?$')
82
81
83
82
84 # Handler Check Utilities
83 # Handler Check Utilities
85 def is_shadowed(identifier, ip):
84 def is_shadowed(identifier, ip):
86 """Is the given identifier defined in one of the namespaces which shadow
85 """Is the given identifier defined in one of the namespaces which shadow
87 the alias and magic namespaces? Note that an identifier is different
86 the alias and magic namespaces? Note that an identifier is different
88 than ifun, because it can not contain a '.' character."""
87 than ifun, because it can not contain a '.' character."""
89 # This is much safer than calling ofind, which can change state
88 # This is much safer than calling ofind, which can change state
90 return (identifier in ip.user_ns \
89 return (identifier in ip.user_ns \
91 or identifier in ip.internal_ns \
90 or identifier in ip.internal_ns \
92 or identifier in ip.ns_table['builtin'])
91 or identifier in ip.ns_table['builtin'])
93
92
94
93
95 #-----------------------------------------------------------------------------
94 #-----------------------------------------------------------------------------
96 # The LineInfo class used throughout
95 # The LineInfo class used throughout
97 #-----------------------------------------------------------------------------
96 #-----------------------------------------------------------------------------
98
97
99
98
100 class LineInfo(object):
99 class LineInfo(object):
101 """A single line of input and associated info.
100 """A single line of input and associated info.
102
101
103 Includes the following as properties:
102 Includes the following as properties:
104
103
105 line
104 line
106 The original, raw line
105 The original, raw line
107
106
108 continue_prompt
107 continue_prompt
109 Is this line a continuation in a sequence of multiline input?
108 Is this line a continuation in a sequence of multiline input?
110
109
111 pre
110 pre
112 The initial esc character or whitespace.
111 The initial esc character or whitespace.
113
112
114 pre_char
113 pre_char
115 The escape character(s) in pre or the empty string if there isn't one.
114 The escape character(s) in pre or the empty string if there isn't one.
116 Note that '!!' is a possible value for pre_char. Otherwise it will
115 Note that '!!' is a possible value for pre_char. Otherwise it will
117 always be a single character.
116 always be a single character.
118
117
119 pre_whitespace
118 pre_whitespace
120 The leading whitespace from pre if it exists. If there is a pre_char,
119 The leading whitespace from pre if it exists. If there is a pre_char,
121 this is just ''.
120 this is just ''.
122
121
123 ifun
122 ifun
124 The 'function part', which is basically the maximal initial sequence
123 The 'function part', which is basically the maximal initial sequence
125 of valid python identifiers and the '.' character. This is what is
124 of valid python identifiers and the '.' character. This is what is
126 checked for alias and magic transformations, used for auto-calling,
125 checked for alias and magic transformations, used for auto-calling,
127 etc.
126 etc.
128
127
129 the_rest
128 the_rest
130 Everything else on the line.
129 Everything else on the line.
131 """
130 """
132 def __init__(self, line, continue_prompt):
131 def __init__(self, line, continue_prompt):
133 self.line = line
132 self.line = line
134 self.continue_prompt = continue_prompt
133 self.continue_prompt = continue_prompt
135 self.pre, self.ifun, self.the_rest = split_user_input(line)
134 self.pre, self.ifun, self.the_rest = split_user_input(line)
136
135
137 self.pre_char = self.pre.strip()
136 self.pre_char = self.pre.strip()
138 if self.pre_char:
137 if self.pre_char:
139 self.pre_whitespace = '' # No whitespace allowd before esc chars
138 self.pre_whitespace = '' # No whitespace allowd before esc chars
140 else:
139 else:
141 self.pre_whitespace = self.pre
140 self.pre_whitespace = self.pre
142
141
143 self._oinfo = None
142 self._oinfo = None
144
143
145 def ofind(self, ip):
144 def ofind(self, ip):
146 """Do a full, attribute-walking lookup of the ifun in the various
145 """Do a full, attribute-walking lookup of the ifun in the various
147 namespaces for the given IPython InteractiveShell instance.
146 namespaces for the given IPython InteractiveShell instance.
148
147
149 Return a dict with keys: found,obj,ospace,ismagic
148 Return a dict with keys: found,obj,ospace,ismagic
150
149
151 Note: can cause state changes because of calling getattr, but should
150 Note: can cause state changes because of calling getattr, but should
152 only be run if autocall is on and if the line hasn't matched any
151 only be run if autocall is on and if the line hasn't matched any
153 other, less dangerous handlers.
152 other, less dangerous handlers.
154
153
155 Does cache the results of the call, so can be called multiple times
154 Does cache the results of the call, so can be called multiple times
156 without worrying about *further* damaging state.
155 without worrying about *further* damaging state.
157 """
156 """
158 if not self._oinfo:
157 if not self._oinfo:
159 # ip.shell._ofind is actually on the Magic class!
158 # ip.shell._ofind is actually on the Magic class!
160 self._oinfo = ip.shell._ofind(self.ifun)
159 self._oinfo = ip.shell._ofind(self.ifun)
161 return self._oinfo
160 return self._oinfo
162
161
163 def __str__(self):
162 def __str__(self):
164 return "Lineinfo [%s|%s|%s]" %(self.pre, self.ifun, self.the_rest)
163 return "Lineinfo [%s|%s|%s]" %(self.pre, self.ifun, self.the_rest)
165
164
166
165
167 #-----------------------------------------------------------------------------
166 #-----------------------------------------------------------------------------
168 # Main Prefilter manager
167 # Main Prefilter manager
169 #-----------------------------------------------------------------------------
168 #-----------------------------------------------------------------------------
170
169
171
170
172 class PrefilterManager(Configurable):
171 class PrefilterManager(Configurable):
173 """Main prefilter component.
172 """Main prefilter component.
174
173
175 The IPython prefilter is run on all user input before it is run. The
174 The IPython prefilter is run on all user input before it is run. The
176 prefilter consumes lines of input and produces transformed lines of
175 prefilter consumes lines of input and produces transformed lines of
177 input.
176 input.
178
177
179 The iplementation consists of two phases:
178 The iplementation consists of two phases:
180
179
181 1. Transformers
180 1. Transformers
182 2. Checkers and handlers
181 2. Checkers and handlers
183
182
184 Over time, we plan on deprecating the checkers and handlers and doing
183 Over time, we plan on deprecating the checkers and handlers and doing
185 everything in the transformers.
184 everything in the transformers.
186
185
187 The transformers are instances of :class:`PrefilterTransformer` and have
186 The transformers are instances of :class:`PrefilterTransformer` and have
188 a single method :meth:`transform` that takes a line and returns a
187 a single method :meth:`transform` that takes a line and returns a
189 transformed line. The transformation can be accomplished using any
188 transformed line. The transformation can be accomplished using any
190 tool, but our current ones use regular expressions for speed. We also
189 tool, but our current ones use regular expressions for speed. We also
191 ship :mod:`pyparsing` in :mod:`IPython.external` for use in transformers.
190 ship :mod:`pyparsing` in :mod:`IPython.external` for use in transformers.
192
191
193 After all the transformers have been run, the line is fed to the checkers,
192 After all the transformers have been run, the line is fed to the checkers,
194 which are instances of :class:`PrefilterChecker`. The line is passed to
193 which are instances of :class:`PrefilterChecker`. The line is passed to
195 the :meth:`check` method, which either returns `None` or a
194 the :meth:`check` method, which either returns `None` or a
196 :class:`PrefilterHandler` instance. If `None` is returned, the other
195 :class:`PrefilterHandler` instance. If `None` is returned, the other
197 checkers are tried. If an :class:`PrefilterHandler` instance is returned,
196 checkers are tried. If an :class:`PrefilterHandler` instance is returned,
198 the line is passed to the :meth:`handle` method of the returned
197 the line is passed to the :meth:`handle` method of the returned
199 handler and no further checkers are tried.
198 handler and no further checkers are tried.
200
199
201 Both transformers and checkers have a `priority` attribute, that determines
200 Both transformers and checkers have a `priority` attribute, that determines
202 the order in which they are called. Smaller priorities are tried first.
201 the order in which they are called. Smaller priorities are tried first.
203
202
204 Both transformers and checkers also have `enabled` attribute, which is
203 Both transformers and checkers also have `enabled` attribute, which is
205 a boolean that determines if the instance is used.
204 a boolean that determines if the instance is used.
206
205
207 Users or developers can change the priority or enabled attribute of
206 Users or developers can change the priority or enabled attribute of
208 transformers or checkers, but they must call the :meth:`sort_checkers`
207 transformers or checkers, but they must call the :meth:`sort_checkers`
209 or :meth:`sort_transformers` method after changing the priority.
208 or :meth:`sort_transformers` method after changing the priority.
210 """
209 """
211
210
212 multi_line_specials = CBool(True, config=True)
211 multi_line_specials = CBool(True, config=True)
213 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
212 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
214
213
215 def __init__(self, shell=None, config=None):
214 def __init__(self, shell=None, config=None):
216 super(PrefilterManager, self).__init__(shell=shell, config=config)
215 super(PrefilterManager, self).__init__(shell=shell, config=config)
217 self.shell = shell
216 self.shell = shell
218 self.init_transformers()
217 self.init_transformers()
219 self.init_handlers()
218 self.init_handlers()
220 self.init_checkers()
219 self.init_checkers()
221
220
222 #-------------------------------------------------------------------------
221 #-------------------------------------------------------------------------
223 # API for managing transformers
222 # API for managing transformers
224 #-------------------------------------------------------------------------
223 #-------------------------------------------------------------------------
225
224
226 def init_transformers(self):
225 def init_transformers(self):
227 """Create the default transformers."""
226 """Create the default transformers."""
228 self._transformers = []
227 self._transformers = []
229 for transformer_cls in _default_transformers:
228 for transformer_cls in _default_transformers:
230 transformer_cls(
229 transformer_cls(
231 shell=self.shell, prefilter_manager=self, config=self.config
230 shell=self.shell, prefilter_manager=self, config=self.config
232 )
231 )
233
232
234 def sort_transformers(self):
233 def sort_transformers(self):
235 """Sort the transformers by priority.
234 """Sort the transformers by priority.
236
235
237 This must be called after the priority of a transformer is changed.
236 This must be called after the priority of a transformer is changed.
238 The :meth:`register_transformer` method calls this automatically.
237 The :meth:`register_transformer` method calls this automatically.
239 """
238 """
240 self._transformers.sort(key=lambda x: x.priority)
239 self._transformers.sort(key=lambda x: x.priority)
241
240
242 @property
241 @property
243 def transformers(self):
242 def transformers(self):
244 """Return a list of checkers, sorted by priority."""
243 """Return a list of checkers, sorted by priority."""
245 return self._transformers
244 return self._transformers
246
245
247 def register_transformer(self, transformer):
246 def register_transformer(self, transformer):
248 """Register a transformer instance."""
247 """Register a transformer instance."""
249 if transformer not in self._transformers:
248 if transformer not in self._transformers:
250 self._transformers.append(transformer)
249 self._transformers.append(transformer)
251 self.sort_transformers()
250 self.sort_transformers()
252
251
253 def unregister_transformer(self, transformer):
252 def unregister_transformer(self, transformer):
254 """Unregister a transformer instance."""
253 """Unregister a transformer instance."""
255 if transformer in self._transformers:
254 if transformer in self._transformers:
256 self._transformers.remove(transformer)
255 self._transformers.remove(transformer)
257
256
258 #-------------------------------------------------------------------------
257 #-------------------------------------------------------------------------
259 # API for managing checkers
258 # API for managing checkers
260 #-------------------------------------------------------------------------
259 #-------------------------------------------------------------------------
261
260
262 def init_checkers(self):
261 def init_checkers(self):
263 """Create the default checkers."""
262 """Create the default checkers."""
264 self._checkers = []
263 self._checkers = []
265 for checker in _default_checkers:
264 for checker in _default_checkers:
266 checker(
265 checker(
267 shell=self.shell, prefilter_manager=self, config=self.config
266 shell=self.shell, prefilter_manager=self, config=self.config
268 )
267 )
269
268
270 def sort_checkers(self):
269 def sort_checkers(self):
271 """Sort the checkers by priority.
270 """Sort the checkers by priority.
272
271
273 This must be called after the priority of a checker is changed.
272 This must be called after the priority of a checker is changed.
274 The :meth:`register_checker` method calls this automatically.
273 The :meth:`register_checker` method calls this automatically.
275 """
274 """
276 self._checkers.sort(key=lambda x: x.priority)
275 self._checkers.sort(key=lambda x: x.priority)
277
276
278 @property
277 @property
279 def checkers(self):
278 def checkers(self):
280 """Return a list of checkers, sorted by priority."""
279 """Return a list of checkers, sorted by priority."""
281 return self._checkers
280 return self._checkers
282
281
283 def register_checker(self, checker):
282 def register_checker(self, checker):
284 """Register a checker instance."""
283 """Register a checker instance."""
285 if checker not in self._checkers:
284 if checker not in self._checkers:
286 self._checkers.append(checker)
285 self._checkers.append(checker)
287 self.sort_checkers()
286 self.sort_checkers()
288
287
289 def unregister_checker(self, checker):
288 def unregister_checker(self, checker):
290 """Unregister a checker instance."""
289 """Unregister a checker instance."""
291 if checker in self._checkers:
290 if checker in self._checkers:
292 self._checkers.remove(checker)
291 self._checkers.remove(checker)
293
292
294 #-------------------------------------------------------------------------
293 #-------------------------------------------------------------------------
295 # API for managing checkers
294 # API for managing checkers
296 #-------------------------------------------------------------------------
295 #-------------------------------------------------------------------------
297
296
298 def init_handlers(self):
297 def init_handlers(self):
299 """Create the default handlers."""
298 """Create the default handlers."""
300 self._handlers = {}
299 self._handlers = {}
301 self._esc_handlers = {}
300 self._esc_handlers = {}
302 for handler in _default_handlers:
301 for handler in _default_handlers:
303 handler(
302 handler(
304 shell=self.shell, prefilter_manager=self, config=self.config
303 shell=self.shell, prefilter_manager=self, config=self.config
305 )
304 )
306
305
307 @property
306 @property
308 def handlers(self):
307 def handlers(self):
309 """Return a dict of all the handlers."""
308 """Return a dict of all the handlers."""
310 return self._handlers
309 return self._handlers
311
310
312 def register_handler(self, name, handler, esc_strings):
311 def register_handler(self, name, handler, esc_strings):
313 """Register a handler instance by name with esc_strings."""
312 """Register a handler instance by name with esc_strings."""
314 self._handlers[name] = handler
313 self._handlers[name] = handler
315 for esc_str in esc_strings:
314 for esc_str in esc_strings:
316 self._esc_handlers[esc_str] = handler
315 self._esc_handlers[esc_str] = handler
317
316
318 def unregister_handler(self, name, handler, esc_strings):
317 def unregister_handler(self, name, handler, esc_strings):
319 """Unregister a handler instance by name with esc_strings."""
318 """Unregister a handler instance by name with esc_strings."""
320 try:
319 try:
321 del self._handlers[name]
320 del self._handlers[name]
322 except KeyError:
321 except KeyError:
323 pass
322 pass
324 for esc_str in esc_strings:
323 for esc_str in esc_strings:
325 h = self._esc_handlers.get(esc_str)
324 h = self._esc_handlers.get(esc_str)
326 if h is handler:
325 if h is handler:
327 del self._esc_handlers[esc_str]
326 del self._esc_handlers[esc_str]
328
327
329 def get_handler_by_name(self, name):
328 def get_handler_by_name(self, name):
330 """Get a handler by its name."""
329 """Get a handler by its name."""
331 return self._handlers.get(name)
330 return self._handlers.get(name)
332
331
333 def get_handler_by_esc(self, esc_str):
332 def get_handler_by_esc(self, esc_str):
334 """Get a handler by its escape string."""
333 """Get a handler by its escape string."""
335 return self._esc_handlers.get(esc_str)
334 return self._esc_handlers.get(esc_str)
336
335
337 #-------------------------------------------------------------------------
336 #-------------------------------------------------------------------------
338 # Main prefiltering API
337 # Main prefiltering API
339 #-------------------------------------------------------------------------
338 #-------------------------------------------------------------------------
340
339
341 def prefilter_line_info(self, line_info):
340 def prefilter_line_info(self, line_info):
342 """Prefilter a line that has been converted to a LineInfo object.
341 """Prefilter a line that has been converted to a LineInfo object.
343
342
344 This implements the checker/handler part of the prefilter pipe.
343 This implements the checker/handler part of the prefilter pipe.
345 """
344 """
346 # print "prefilter_line_info: ", line_info
345 # print "prefilter_line_info: ", line_info
347 handler = self.find_handler(line_info)
346 handler = self.find_handler(line_info)
348 return handler.handle(line_info)
347 return handler.handle(line_info)
349
348
350 def find_handler(self, line_info):
349 def find_handler(self, line_info):
351 """Find a handler for the line_info by trying checkers."""
350 """Find a handler for the line_info by trying checkers."""
352 for checker in self.checkers:
351 for checker in self.checkers:
353 if checker.enabled:
352 if checker.enabled:
354 handler = checker.check(line_info)
353 handler = checker.check(line_info)
355 if handler:
354 if handler:
356 return handler
355 return handler
357 return self.get_handler_by_name('normal')
356 return self.get_handler_by_name('normal')
358
357
359 def transform_line(self, line, continue_prompt):
358 def transform_line(self, line, continue_prompt):
360 """Calls the enabled transformers in order of increasing priority."""
359 """Calls the enabled transformers in order of increasing priority."""
361 for transformer in self.transformers:
360 for transformer in self.transformers:
362 if transformer.enabled:
361 if transformer.enabled:
363 line = transformer.transform(line, continue_prompt)
362 line = transformer.transform(line, continue_prompt)
364 return line
363 return line
365
364
366 def prefilter_line(self, line, continue_prompt=False):
365 def prefilter_line(self, line, continue_prompt=False):
367 """Prefilter a single input line as text.
366 """Prefilter a single input line as text.
368
367
369 This method prefilters a single line of text by calling the
368 This method prefilters a single line of text by calling the
370 transformers and then the checkers/handlers.
369 transformers and then the checkers/handlers.
371 """
370 """
372
371
373 # print "prefilter_line: ", line, continue_prompt
372 # print "prefilter_line: ", line, continue_prompt
374 # All handlers *must* return a value, even if it's blank ('').
373 # All handlers *must* return a value, even if it's blank ('').
375
374
376 # save the line away in case we crash, so the post-mortem handler can
375 # save the line away in case we crash, so the post-mortem handler can
377 # record it
376 # record it
378 self.shell._last_input_line = line
377 self.shell._last_input_line = line
379
378
380 if not line:
379 if not line:
381 # Return immediately on purely empty lines, so that if the user
380 # Return immediately on purely empty lines, so that if the user
382 # previously typed some whitespace that started a continuation
381 # previously typed some whitespace that started a continuation
383 # prompt, he can break out of that loop with just an empty line.
382 # prompt, he can break out of that loop with just an empty line.
384 # This is how the default python prompt works.
383 # This is how the default python prompt works.
385 return ''
384 return ''
386
385
387 # At this point, we invoke our transformers.
386 # At this point, we invoke our transformers.
388 if not continue_prompt or (continue_prompt and self.multi_line_specials):
387 if not continue_prompt or (continue_prompt and self.multi_line_specials):
389 line = self.transform_line(line, continue_prompt)
388 line = self.transform_line(line, continue_prompt)
390
389
391 # Now we compute line_info for the checkers and handlers
390 # Now we compute line_info for the checkers and handlers
392 line_info = LineInfo(line, continue_prompt)
391 line_info = LineInfo(line, continue_prompt)
393
392
394 # the input history needs to track even empty lines
393 # the input history needs to track even empty lines
395 stripped = line.strip()
394 stripped = line.strip()
396
395
397 normal_handler = self.get_handler_by_name('normal')
396 normal_handler = self.get_handler_by_name('normal')
398 if not stripped:
397 if not stripped:
399 if not continue_prompt:
398 if not continue_prompt:
400 self.shell.displayhook.prompt_count -= 1
399 self.shell.displayhook.prompt_count -= 1
401
400
402 return normal_handler.handle(line_info)
401 return normal_handler.handle(line_info)
403
402
404 # special handlers are only allowed for single line statements
403 # special handlers are only allowed for single line statements
405 if continue_prompt and not self.multi_line_specials:
404 if continue_prompt and not self.multi_line_specials:
406 return normal_handler.handle(line_info)
405 return normal_handler.handle(line_info)
407
406
408 prefiltered = self.prefilter_line_info(line_info)
407 prefiltered = self.prefilter_line_info(line_info)
409 # print "prefiltered line: %r" % prefiltered
408 # print "prefiltered line: %r" % prefiltered
410 return prefiltered
409 return prefiltered
411
410
412 def prefilter_lines(self, lines, continue_prompt=False):
411 def prefilter_lines(self, lines, continue_prompt=False):
413 """Prefilter multiple input lines of text.
412 """Prefilter multiple input lines of text.
414
413
415 This is the main entry point for prefiltering multiple lines of
414 This is the main entry point for prefiltering multiple lines of
416 input. This simply calls :meth:`prefilter_line` for each line of
415 input. This simply calls :meth:`prefilter_line` for each line of
417 input.
416 input.
418
417
419 This covers cases where there are multiple lines in the user entry,
418 This covers cases where there are multiple lines in the user entry,
420 which is the case when the user goes back to a multiline history
419 which is the case when the user goes back to a multiline history
421 entry and presses enter.
420 entry and presses enter.
422 """
421 """
423 llines = lines.rstrip('\n').split('\n')
422 llines = lines.rstrip('\n').split('\n')
424 # We can get multiple lines in one shot, where multiline input 'blends'
423 # We can get multiple lines in one shot, where multiline input 'blends'
425 # into one line, in cases like recalling from the readline history
424 # into one line, in cases like recalling from the readline history
426 # buffer. We need to make sure that in such cases, we correctly
425 # buffer. We need to make sure that in such cases, we correctly
427 # communicate downstream which line is first and which are continuation
426 # communicate downstream which line is first and which are continuation
428 # ones.
427 # ones.
429 if len(llines) > 1:
428 if len(llines) > 1:
430 out = '\n'.join([self.prefilter_line(line, lnum>0)
429 out = '\n'.join([self.prefilter_line(line, lnum>0)
431 for lnum, line in enumerate(llines) ])
430 for lnum, line in enumerate(llines) ])
432 else:
431 else:
433 out = self.prefilter_line(llines[0], continue_prompt)
432 out = self.prefilter_line(llines[0], continue_prompt)
434
433
435 return out
434 return out
436
435
437 #-----------------------------------------------------------------------------
436 #-----------------------------------------------------------------------------
438 # Prefilter transformers
437 # Prefilter transformers
439 #-----------------------------------------------------------------------------
438 #-----------------------------------------------------------------------------
440
439
441
440
442 class PrefilterTransformer(Configurable):
441 class PrefilterTransformer(Configurable):
443 """Transform a line of user input."""
442 """Transform a line of user input."""
444
443
445 priority = Int(100, config=True)
444 priority = Int(100, config=True)
446 # Transformers don't currently use shell or prefilter_manager, but as we
445 # Transformers don't currently use shell or prefilter_manager, but as we
447 # move away from checkers and handlers, they will need them.
446 # move away from checkers and handlers, they will need them.
448 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
447 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
449 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
448 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
450 enabled = Bool(True, config=True)
449 enabled = Bool(True, config=True)
451
450
452 def __init__(self, shell=None, prefilter_manager=None, config=None):
451 def __init__(self, shell=None, prefilter_manager=None, config=None):
453 super(PrefilterTransformer, self).__init__(
452 super(PrefilterTransformer, self).__init__(
454 shell=shell, prefilter_manager=prefilter_manager, config=config
453 shell=shell, prefilter_manager=prefilter_manager, config=config
455 )
454 )
456 self.prefilter_manager.register_transformer(self)
455 self.prefilter_manager.register_transformer(self)
457
456
458 def transform(self, line, continue_prompt):
457 def transform(self, line, continue_prompt):
459 """Transform a line, returning the new one."""
458 """Transform a line, returning the new one."""
460 return None
459 return None
461
460
462 def __repr__(self):
461 def __repr__(self):
463 return "<%s(priority=%r, enabled=%r)>" % (
462 return "<%s(priority=%r, enabled=%r)>" % (
464 self.__class__.__name__, self.priority, self.enabled)
463 self.__class__.__name__, self.priority, self.enabled)
465
464
466
465
467 _assign_system_re = re.compile(r'(?P<lhs>(\s*)([\w\.]+)((\s*,\s*[\w\.]+)*))'
466 _assign_system_re = re.compile(r'(?P<lhs>(\s*)([\w\.]+)((\s*,\s*[\w\.]+)*))'
468 r'\s*=\s*!(?P<cmd>.*)')
467 r'\s*=\s*!(?P<cmd>.*)')
469
468
470
469
471 class AssignSystemTransformer(PrefilterTransformer):
470 class AssignSystemTransformer(PrefilterTransformer):
472 """Handle the `files = !ls` syntax."""
471 """Handle the `files = !ls` syntax."""
473
472
474 priority = Int(100, config=True)
473 priority = Int(100, config=True)
475
474
476 def transform(self, line, continue_prompt):
475 def transform(self, line, continue_prompt):
477 m = _assign_system_re.match(line)
476 m = _assign_system_re.match(line)
478 if m is not None:
477 if m is not None:
479 cmd = m.group('cmd')
478 cmd = m.group('cmd')
480 lhs = m.group('lhs')
479 lhs = m.group('lhs')
481 expr = make_quoted_expr("sc =%s" % cmd)
480 expr = make_quoted_expr("sc =%s" % cmd)
482 new_line = '%s = get_ipython().magic(%s)' % (lhs, expr)
481 new_line = '%s = get_ipython().magic(%s)' % (lhs, expr)
483 return new_line
482 return new_line
484 return line
483 return line
485
484
486
485
487 _assign_magic_re = re.compile(r'(?P<lhs>(\s*)([\w\.]+)((\s*,\s*[\w\.]+)*))'
486 _assign_magic_re = re.compile(r'(?P<lhs>(\s*)([\w\.]+)((\s*,\s*[\w\.]+)*))'
488 r'\s*=\s*%(?P<cmd>.*)')
487 r'\s*=\s*%(?P<cmd>.*)')
489
488
490 class AssignMagicTransformer(PrefilterTransformer):
489 class AssignMagicTransformer(PrefilterTransformer):
491 """Handle the `a = %who` syntax."""
490 """Handle the `a = %who` syntax."""
492
491
493 priority = Int(200, config=True)
492 priority = Int(200, config=True)
494
493
495 def transform(self, line, continue_prompt):
494 def transform(self, line, continue_prompt):
496 m = _assign_magic_re.match(line)
495 m = _assign_magic_re.match(line)
497 if m is not None:
496 if m is not None:
498 cmd = m.group('cmd')
497 cmd = m.group('cmd')
499 lhs = m.group('lhs')
498 lhs = m.group('lhs')
500 expr = make_quoted_expr(cmd)
499 expr = make_quoted_expr(cmd)
501 new_line = '%s = get_ipython().magic(%s)' % (lhs, expr)
500 new_line = '%s = get_ipython().magic(%s)' % (lhs, expr)
502 return new_line
501 return new_line
503 return line
502 return line
504
503
505
504
506 _classic_prompt_re = re.compile(r'(^[ \t]*>>> |^[ \t]*\.\.\. )')
505 _classic_prompt_re = re.compile(r'(^[ \t]*>>> |^[ \t]*\.\.\. )')
507
506
508 class PyPromptTransformer(PrefilterTransformer):
507 class PyPromptTransformer(PrefilterTransformer):
509 """Handle inputs that start with '>>> ' syntax."""
508 """Handle inputs that start with '>>> ' syntax."""
510
509
511 priority = Int(50, config=True)
510 priority = Int(50, config=True)
512
511
513 def transform(self, line, continue_prompt):
512 def transform(self, line, continue_prompt):
514
513
515 if not line or line.isspace() or line.strip() == '...':
514 if not line or line.isspace() or line.strip() == '...':
516 # This allows us to recognize multiple input prompts separated by
515 # This allows us to recognize multiple input prompts separated by
517 # blank lines and pasted in a single chunk, very common when
516 # blank lines and pasted in a single chunk, very common when
518 # pasting doctests or long tutorial passages.
517 # pasting doctests or long tutorial passages.
519 return ''
518 return ''
520 m = _classic_prompt_re.match(line)
519 m = _classic_prompt_re.match(line)
521 if m:
520 if m:
522 return line[len(m.group(0)):]
521 return line[len(m.group(0)):]
523 else:
522 else:
524 return line
523 return line
525
524
526
525
527 _ipy_prompt_re = re.compile(r'(^[ \t]*In \[\d+\]: |^[ \t]*\ \ \ \.\.\.+: )')
526 _ipy_prompt_re = re.compile(r'(^[ \t]*In \[\d+\]: |^[ \t]*\ \ \ \.\.\.+: )')
528
527
529 class IPyPromptTransformer(PrefilterTransformer):
528 class IPyPromptTransformer(PrefilterTransformer):
530 """Handle inputs that start classic IPython prompt syntax."""
529 """Handle inputs that start classic IPython prompt syntax."""
531
530
532 priority = Int(50, config=True)
531 priority = Int(50, config=True)
533
532
534 def transform(self, line, continue_prompt):
533 def transform(self, line, continue_prompt):
535
534
536 if not line or line.isspace() or line.strip() == '...':
535 if not line or line.isspace() or line.strip() == '...':
537 # This allows us to recognize multiple input prompts separated by
536 # This allows us to recognize multiple input prompts separated by
538 # blank lines and pasted in a single chunk, very common when
537 # blank lines and pasted in a single chunk, very common when
539 # pasting doctests or long tutorial passages.
538 # pasting doctests or long tutorial passages.
540 return ''
539 return ''
541 m = _ipy_prompt_re.match(line)
540 m = _ipy_prompt_re.match(line)
542 if m:
541 if m:
543 return line[len(m.group(0)):]
542 return line[len(m.group(0)):]
544 else:
543 else:
545 return line
544 return line
546
545
547 #-----------------------------------------------------------------------------
546 #-----------------------------------------------------------------------------
548 # Prefilter checkers
547 # Prefilter checkers
549 #-----------------------------------------------------------------------------
548 #-----------------------------------------------------------------------------
550
549
551
550
552 class PrefilterChecker(Configurable):
551 class PrefilterChecker(Configurable):
553 """Inspect an input line and return a handler for that line."""
552 """Inspect an input line and return a handler for that line."""
554
553
555 priority = Int(100, config=True)
554 priority = Int(100, config=True)
556 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
555 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
557 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
556 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
558 enabled = Bool(True, config=True)
557 enabled = Bool(True, config=True)
559
558
560 def __init__(self, shell=None, prefilter_manager=None, config=None):
559 def __init__(self, shell=None, prefilter_manager=None, config=None):
561 super(PrefilterChecker, self).__init__(
560 super(PrefilterChecker, self).__init__(
562 shell=shell, prefilter_manager=prefilter_manager, config=config
561 shell=shell, prefilter_manager=prefilter_manager, config=config
563 )
562 )
564 self.prefilter_manager.register_checker(self)
563 self.prefilter_manager.register_checker(self)
565
564
566 def check(self, line_info):
565 def check(self, line_info):
567 """Inspect line_info and return a handler instance or None."""
566 """Inspect line_info and return a handler instance or None."""
568 return None
567 return None
569
568
570 def __repr__(self):
569 def __repr__(self):
571 return "<%s(priority=%r, enabled=%r)>" % (
570 return "<%s(priority=%r, enabled=%r)>" % (
572 self.__class__.__name__, self.priority, self.enabled)
571 self.__class__.__name__, self.priority, self.enabled)
573
572
574
573
575 class EmacsChecker(PrefilterChecker):
574 class EmacsChecker(PrefilterChecker):
576
575
577 priority = Int(100, config=True)
576 priority = Int(100, config=True)
578 enabled = Bool(False, config=True)
577 enabled = Bool(False, config=True)
579
578
580 def check(self, line_info):
579 def check(self, line_info):
581 "Emacs ipython-mode tags certain input lines."
580 "Emacs ipython-mode tags certain input lines."
582 if line_info.line.endswith('# PYTHON-MODE'):
581 if line_info.line.endswith('# PYTHON-MODE'):
583 return self.prefilter_manager.get_handler_by_name('emacs')
582 return self.prefilter_manager.get_handler_by_name('emacs')
584 else:
583 else:
585 return None
584 return None
586
585
587
586
588 class ShellEscapeChecker(PrefilterChecker):
587 class ShellEscapeChecker(PrefilterChecker):
589
588
590 priority = Int(200, config=True)
589 priority = Int(200, config=True)
591
590
592 def check(self, line_info):
591 def check(self, line_info):
593 if line_info.line.lstrip().startswith(ESC_SHELL):
592 if line_info.line.lstrip().startswith(ESC_SHELL):
594 return self.prefilter_manager.get_handler_by_name('shell')
593 return self.prefilter_manager.get_handler_by_name('shell')
595
594
596
595
597 class MacroChecker(PrefilterChecker):
596 class MacroChecker(PrefilterChecker):
598
597
599 priority = Int(250, config=True)
598 priority = Int(250, config=True)
600
599
601 def check(self, line_info):
600 def check(self, line_info):
602 obj = self.shell.user_ns.get(line_info.ifun)
601 obj = self.shell.user_ns.get(line_info.ifun)
603 if isinstance(obj, Macro):
602 if isinstance(obj, Macro):
604 return self.prefilter_manager.get_handler_by_name('macro')
603 return self.prefilter_manager.get_handler_by_name('macro')
605 else:
604 else:
606 return None
605 return None
607
606
608
607
609 class IPyAutocallChecker(PrefilterChecker):
608 class IPyAutocallChecker(PrefilterChecker):
610
609
611 priority = Int(300, config=True)
610 priority = Int(300, config=True)
612
611
613 def check(self, line_info):
612 def check(self, line_info):
614 "Instances of IPyAutocall in user_ns get autocalled immediately"
613 "Instances of IPyAutocall in user_ns get autocalled immediately"
615 obj = self.shell.user_ns.get(line_info.ifun, None)
614 obj = self.shell.user_ns.get(line_info.ifun, None)
616 if isinstance(obj, IPyAutocall):
615 if isinstance(obj, IPyAutocall):
617 obj.set_ip(self.shell)
616 obj.set_ip(self.shell)
618 return self.prefilter_manager.get_handler_by_name('auto')
617 return self.prefilter_manager.get_handler_by_name('auto')
619 else:
618 else:
620 return None
619 return None
621
620
622
621
623 class MultiLineMagicChecker(PrefilterChecker):
622 class MultiLineMagicChecker(PrefilterChecker):
624
623
625 priority = Int(400, config=True)
624 priority = Int(400, config=True)
626
625
627 def check(self, line_info):
626 def check(self, line_info):
628 "Allow ! and !! in multi-line statements if multi_line_specials is on"
627 "Allow ! and !! in multi-line statements if multi_line_specials is on"
629 # Note that this one of the only places we check the first character of
628 # Note that this one of the only places we check the first character of
630 # ifun and *not* the pre_char. Also note that the below test matches
629 # ifun and *not* the pre_char. Also note that the below test matches
631 # both ! and !!.
630 # both ! and !!.
632 if line_info.continue_prompt \
631 if line_info.continue_prompt \
633 and self.prefilter_manager.multi_line_specials:
632 and self.prefilter_manager.multi_line_specials:
634 if line_info.ifun.startswith(ESC_MAGIC):
633 if line_info.ifun.startswith(ESC_MAGIC):
635 return self.prefilter_manager.get_handler_by_name('magic')
634 return self.prefilter_manager.get_handler_by_name('magic')
636 else:
635 else:
637 return None
636 return None
638
637
639
638
640 class EscCharsChecker(PrefilterChecker):
639 class EscCharsChecker(PrefilterChecker):
641
640
642 priority = Int(500, config=True)
641 priority = Int(500, config=True)
643
642
644 def check(self, line_info):
643 def check(self, line_info):
645 """Check for escape character and return either a handler to handle it,
644 """Check for escape character and return either a handler to handle it,
646 or None if there is no escape char."""
645 or None if there is no escape char."""
647 if line_info.line[-1] == ESC_HELP \
646 if line_info.line[-1] == ESC_HELP \
648 and line_info.pre_char != ESC_SHELL \
647 and line_info.pre_char != ESC_SHELL \
649 and line_info.pre_char != ESC_SH_CAP:
648 and line_info.pre_char != ESC_SH_CAP:
650 # the ? can be at the end, but *not* for either kind of shell escape,
649 # the ? can be at the end, but *not* for either kind of shell escape,
651 # because a ? can be a vaild final char in a shell cmd
650 # because a ? can be a vaild final char in a shell cmd
652 return self.prefilter_manager.get_handler_by_name('help')
651 return self.prefilter_manager.get_handler_by_name('help')
653 else:
652 else:
654 # This returns None like it should if no handler exists
653 # This returns None like it should if no handler exists
655 return self.prefilter_manager.get_handler_by_esc(line_info.pre_char)
654 return self.prefilter_manager.get_handler_by_esc(line_info.pre_char)
656
655
657
656
658 class AssignmentChecker(PrefilterChecker):
657 class AssignmentChecker(PrefilterChecker):
659
658
660 priority = Int(600, config=True)
659 priority = Int(600, config=True)
661
660
662 def check(self, line_info):
661 def check(self, line_info):
663 """Check to see if user is assigning to a var for the first time, in
662 """Check to see if user is assigning to a var for the first time, in
664 which case we want to avoid any sort of automagic / autocall games.
663 which case we want to avoid any sort of automagic / autocall games.
665
664
666 This allows users to assign to either alias or magic names true python
665 This allows users to assign to either alias or magic names true python
667 variables (the magic/alias systems always take second seat to true
666 variables (the magic/alias systems always take second seat to true
668 python code). E.g. ls='hi', or ls,that=1,2"""
667 python code). E.g. ls='hi', or ls,that=1,2"""
669 if line_info.the_rest:
668 if line_info.the_rest:
670 if line_info.the_rest[0] in '=,':
669 if line_info.the_rest[0] in '=,':
671 return self.prefilter_manager.get_handler_by_name('normal')
670 return self.prefilter_manager.get_handler_by_name('normal')
672 else:
671 else:
673 return None
672 return None
674
673
675
674
676 class AutoMagicChecker(PrefilterChecker):
675 class AutoMagicChecker(PrefilterChecker):
677
676
678 priority = Int(700, config=True)
677 priority = Int(700, config=True)
679
678
680 def check(self, line_info):
679 def check(self, line_info):
681 """If the ifun is magic, and automagic is on, run it. Note: normal,
680 """If the ifun is magic, and automagic is on, run it. Note: normal,
682 non-auto magic would already have been triggered via '%' in
681 non-auto magic would already have been triggered via '%' in
683 check_esc_chars. This just checks for automagic. Also, before
682 check_esc_chars. This just checks for automagic. Also, before
684 triggering the magic handler, make sure that there is nothing in the
683 triggering the magic handler, make sure that there is nothing in the
685 user namespace which could shadow it."""
684 user namespace which could shadow it."""
686 if not self.shell.automagic or not hasattr(self.shell,'magic_'+line_info.ifun):
685 if not self.shell.automagic or not hasattr(self.shell,'magic_'+line_info.ifun):
687 return None
686 return None
688
687
689 # We have a likely magic method. Make sure we should actually call it.
688 # We have a likely magic method. Make sure we should actually call it.
690 if line_info.continue_prompt and not self.prefilter_manager.multi_line_specials:
689 if line_info.continue_prompt and not self.prefilter_manager.multi_line_specials:
691 return None
690 return None
692
691
693 head = line_info.ifun.split('.',1)[0]
692 head = line_info.ifun.split('.',1)[0]
694 if is_shadowed(head, self.shell):
693 if is_shadowed(head, self.shell):
695 return None
694 return None
696
695
697 return self.prefilter_manager.get_handler_by_name('magic')
696 return self.prefilter_manager.get_handler_by_name('magic')
698
697
699
698
700 class AliasChecker(PrefilterChecker):
699 class AliasChecker(PrefilterChecker):
701
700
702 priority = Int(800, config=True)
701 priority = Int(800, config=True)
703
702
704 def check(self, line_info):
703 def check(self, line_info):
705 "Check if the initital identifier on the line is an alias."
704 "Check if the initital identifier on the line is an alias."
706 # Note: aliases can not contain '.'
705 # Note: aliases can not contain '.'
707 head = line_info.ifun.split('.',1)[0]
706 head = line_info.ifun.split('.',1)[0]
708 if line_info.ifun not in self.shell.alias_manager \
707 if line_info.ifun not in self.shell.alias_manager \
709 or head not in self.shell.alias_manager \
708 or head not in self.shell.alias_manager \
710 or is_shadowed(head, self.shell):
709 or is_shadowed(head, self.shell):
711 return None
710 return None
712
711
713 return self.prefilter_manager.get_handler_by_name('alias')
712 return self.prefilter_manager.get_handler_by_name('alias')
714
713
715
714
716 class PythonOpsChecker(PrefilterChecker):
715 class PythonOpsChecker(PrefilterChecker):
717
716
718 priority = Int(900, config=True)
717 priority = Int(900, config=True)
719
718
720 def check(self, line_info):
719 def check(self, line_info):
721 """If the 'rest' of the line begins with a function call or pretty much
720 """If the 'rest' of the line begins with a function call or pretty much
722 any python operator, we should simply execute the line (regardless of
721 any python operator, we should simply execute the line (regardless of
723 whether or not there's a possible autocall expansion). This avoids
722 whether or not there's a possible autocall expansion). This avoids
724 spurious (and very confusing) geattr() accesses."""
723 spurious (and very confusing) geattr() accesses."""
725 if line_info.the_rest and line_info.the_rest[0] in '!=()<>,+*/%^&|':
724 if line_info.the_rest and line_info.the_rest[0] in '!=()<>,+*/%^&|':
726 return self.prefilter_manager.get_handler_by_name('normal')
725 return self.prefilter_manager.get_handler_by_name('normal')
727 else:
726 else:
728 return None
727 return None
729
728
730
729
731 class AutocallChecker(PrefilterChecker):
730 class AutocallChecker(PrefilterChecker):
732
731
733 priority = Int(1000, config=True)
732 priority = Int(1000, config=True)
734
733
735 def check(self, line_info):
734 def check(self, line_info):
736 "Check if the initial word/function is callable and autocall is on."
735 "Check if the initial word/function is callable and autocall is on."
737 if not self.shell.autocall:
736 if not self.shell.autocall:
738 return None
737 return None
739
738
740 oinfo = line_info.ofind(self.shell) # This can mutate state via getattr
739 oinfo = line_info.ofind(self.shell) # This can mutate state via getattr
741 if not oinfo['found']:
740 if not oinfo['found']:
742 return None
741 return None
743
742
744 if callable(oinfo['obj']) \
743 if callable(oinfo['obj']) \
745 and (not re_exclude_auto.match(line_info.the_rest)) \
744 and (not re_exclude_auto.match(line_info.the_rest)) \
746 and re_fun_name.match(line_info.ifun):
745 and re_fun_name.match(line_info.ifun):
747 return self.prefilter_manager.get_handler_by_name('auto')
746 return self.prefilter_manager.get_handler_by_name('auto')
748 else:
747 else:
749 return None
748 return None
750
749
751
750
752 #-----------------------------------------------------------------------------
751 #-----------------------------------------------------------------------------
753 # Prefilter handlers
752 # Prefilter handlers
754 #-----------------------------------------------------------------------------
753 #-----------------------------------------------------------------------------
755
754
756
755
757 class PrefilterHandler(Configurable):
756 class PrefilterHandler(Configurable):
758
757
759 handler_name = Unicode('normal')
758 handler_name = Unicode('normal')
760 esc_strings = List([])
759 esc_strings = List([])
761 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
760 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
762 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
761 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
763
762
764 def __init__(self, shell=None, prefilter_manager=None, config=None):
763 def __init__(self, shell=None, prefilter_manager=None, config=None):
765 super(PrefilterHandler, self).__init__(
764 super(PrefilterHandler, self).__init__(
766 shell=shell, prefilter_manager=prefilter_manager, config=config
765 shell=shell, prefilter_manager=prefilter_manager, config=config
767 )
766 )
768 self.prefilter_manager.register_handler(
767 self.prefilter_manager.register_handler(
769 self.handler_name,
768 self.handler_name,
770 self,
769 self,
771 self.esc_strings
770 self.esc_strings
772 )
771 )
773
772
774 def handle(self, line_info):
773 def handle(self, line_info):
775 # print "normal: ", line_info
774 # print "normal: ", line_info
776 """Handle normal input lines. Use as a template for handlers."""
775 """Handle normal input lines. Use as a template for handlers."""
777
776
778 # With autoindent on, we need some way to exit the input loop, and I
777 # With autoindent on, we need some way to exit the input loop, and I
779 # don't want to force the user to have to backspace all the way to
778 # don't want to force the user to have to backspace all the way to
780 # clear the line. The rule will be in this case, that either two
779 # clear the line. The rule will be in this case, that either two
781 # lines of pure whitespace in a row, or a line of pure whitespace but
780 # lines of pure whitespace in a row, or a line of pure whitespace but
782 # of a size different to the indent level, will exit the input loop.
781 # of a size different to the indent level, will exit the input loop.
783 line = line_info.line
782 line = line_info.line
784 continue_prompt = line_info.continue_prompt
783 continue_prompt = line_info.continue_prompt
785
784
786 if (continue_prompt and
785 if (continue_prompt and
787 self.shell.autoindent and
786 self.shell.autoindent and
788 line.isspace() and
787 line.isspace() and
789 0 < abs(len(line) - self.shell.indent_current_nsp) <= 2):
788 0 < abs(len(line) - self.shell.indent_current_nsp) <= 2):
790 line = ''
789 line = ''
791
790
792 return line
791 return line
793
792
794 def __str__(self):
793 def __str__(self):
795 return "<%s(name=%s)>" % (self.__class__.__name__, self.handler_name)
794 return "<%s(name=%s)>" % (self.__class__.__name__, self.handler_name)
796
795
797
796
798 class AliasHandler(PrefilterHandler):
797 class AliasHandler(PrefilterHandler):
799
798
800 handler_name = Unicode('alias')
799 handler_name = Unicode('alias')
801
800
802 def handle(self, line_info):
801 def handle(self, line_info):
803 """Handle alias input lines. """
802 """Handle alias input lines. """
804 transformed = self.shell.alias_manager.expand_aliases(line_info.ifun,line_info.the_rest)
803 transformed = self.shell.alias_manager.expand_aliases(line_info.ifun,line_info.the_rest)
805 # pre is needed, because it carries the leading whitespace. Otherwise
804 # pre is needed, because it carries the leading whitespace. Otherwise
806 # aliases won't work in indented sections.
805 # aliases won't work in indented sections.
807 line_out = '%sget_ipython().system(%s)' % (line_info.pre_whitespace,
806 line_out = '%sget_ipython().system(%s)' % (line_info.pre_whitespace,
808 make_quoted_expr(transformed))
807 make_quoted_expr(transformed))
809
808
810 return line_out
809 return line_out
811
810
812
811
813 class ShellEscapeHandler(PrefilterHandler):
812 class ShellEscapeHandler(PrefilterHandler):
814
813
815 handler_name = Unicode('shell')
814 handler_name = Unicode('shell')
816 esc_strings = List([ESC_SHELL, ESC_SH_CAP])
815 esc_strings = List([ESC_SHELL, ESC_SH_CAP])
817
816
818 def handle(self, line_info):
817 def handle(self, line_info):
819 """Execute the line in a shell, empty return value"""
818 """Execute the line in a shell, empty return value"""
820 magic_handler = self.prefilter_manager.get_handler_by_name('magic')
819 magic_handler = self.prefilter_manager.get_handler_by_name('magic')
821
820
822 line = line_info.line
821 line = line_info.line
823 if line.lstrip().startswith(ESC_SH_CAP):
822 if line.lstrip().startswith(ESC_SH_CAP):
824 # rewrite LineInfo's line, ifun and the_rest to properly hold the
823 # rewrite LineInfo's line, ifun and the_rest to properly hold the
825 # call to %sx and the actual command to be executed, so
824 # call to %sx and the actual command to be executed, so
826 # handle_magic can work correctly. Note that this works even if
825 # handle_magic can work correctly. Note that this works even if
827 # the line is indented, so it handles multi_line_specials
826 # the line is indented, so it handles multi_line_specials
828 # properly.
827 # properly.
829 new_rest = line.lstrip()[2:]
828 new_rest = line.lstrip()[2:]
830 line_info.line = '%ssx %s' % (ESC_MAGIC, new_rest)
829 line_info.line = '%ssx %s' % (ESC_MAGIC, new_rest)
831 line_info.ifun = 'sx'
830 line_info.ifun = 'sx'
832 line_info.the_rest = new_rest
831 line_info.the_rest = new_rest
833 return magic_handler.handle(line_info)
832 return magic_handler.handle(line_info)
834 else:
833 else:
835 cmd = line.lstrip().lstrip(ESC_SHELL)
834 cmd = line.lstrip().lstrip(ESC_SHELL)
836 line_out = '%sget_ipython().system(%s)' % (line_info.pre_whitespace,
835 line_out = '%sget_ipython().system(%s)' % (line_info.pre_whitespace,
837 make_quoted_expr(cmd))
836 make_quoted_expr(cmd))
838 return line_out
837 return line_out
839
838
840
839
841 class MacroHandler(PrefilterHandler):
840 class MacroHandler(PrefilterHandler):
842 handler_name = Unicode("macro")
841 handler_name = Unicode("macro")
843
842
844 def handle(self, line_info):
843 def handle(self, line_info):
845 obj = self.shell.user_ns.get(line_info.ifun)
844 obj = self.shell.user_ns.get(line_info.ifun)
846 pre_space = line_info.pre_whitespace
845 pre_space = line_info.pre_whitespace
847 line_sep = "\n" + pre_space
846 line_sep = "\n" + pre_space
848 return pre_space + line_sep.join(obj.value.splitlines())
847 return pre_space + line_sep.join(obj.value.splitlines())
849
848
850
849
851 class MagicHandler(PrefilterHandler):
850 class MagicHandler(PrefilterHandler):
852
851
853 handler_name = Unicode('magic')
852 handler_name = Unicode('magic')
854 esc_strings = List([ESC_MAGIC])
853 esc_strings = List([ESC_MAGIC])
855
854
856 def handle(self, line_info):
855 def handle(self, line_info):
857 """Execute magic functions."""
856 """Execute magic functions."""
858 ifun = line_info.ifun
857 ifun = line_info.ifun
859 the_rest = line_info.the_rest
858 the_rest = line_info.the_rest
860 cmd = '%sget_ipython().magic(%s)' % (line_info.pre_whitespace,
859 cmd = '%sget_ipython().magic(%s)' % (line_info.pre_whitespace,
861 make_quoted_expr(ifun + " " + the_rest))
860 make_quoted_expr(ifun + " " + the_rest))
862 return cmd
861 return cmd
863
862
864
863
865 class AutoHandler(PrefilterHandler):
864 class AutoHandler(PrefilterHandler):
866
865
867 handler_name = Unicode('auto')
866 handler_name = Unicode('auto')
868 esc_strings = List([ESC_PAREN, ESC_QUOTE, ESC_QUOTE2])
867 esc_strings = List([ESC_PAREN, ESC_QUOTE, ESC_QUOTE2])
869
868
870 def handle(self, line_info):
869 def handle(self, line_info):
871 """Handle lines which can be auto-executed, quoting if requested."""
870 """Handle lines which can be auto-executed, quoting if requested."""
872 line = line_info.line
871 line = line_info.line
873 ifun = line_info.ifun
872 ifun = line_info.ifun
874 the_rest = line_info.the_rest
873 the_rest = line_info.the_rest
875 pre = line_info.pre
874 pre = line_info.pre
876 continue_prompt = line_info.continue_prompt
875 continue_prompt = line_info.continue_prompt
877 obj = line_info.ofind(self)['obj']
876 obj = line_info.ofind(self)['obj']
878 #print 'pre <%s> ifun <%s> rest <%s>' % (pre,ifun,the_rest) # dbg
877 #print 'pre <%s> ifun <%s> rest <%s>' % (pre,ifun,the_rest) # dbg
879
878
880 # This should only be active for single-line input!
879 # This should only be active for single-line input!
881 if continue_prompt:
880 if continue_prompt:
882 return line
881 return line
883
882
884 force_auto = isinstance(obj, IPyAutocall)
883 force_auto = isinstance(obj, IPyAutocall)
885 auto_rewrite = getattr(obj, 'rewrite', True)
884 auto_rewrite = getattr(obj, 'rewrite', True)
886
885
887 if pre == ESC_QUOTE:
886 if pre == ESC_QUOTE:
888 # Auto-quote splitting on whitespace
887 # Auto-quote splitting on whitespace
889 newcmd = '%s("%s")' % (ifun,'", "'.join(the_rest.split()) )
888 newcmd = '%s("%s")' % (ifun,'", "'.join(the_rest.split()) )
890 elif pre == ESC_QUOTE2:
889 elif pre == ESC_QUOTE2:
891 # Auto-quote whole string
890 # Auto-quote whole string
892 newcmd = '%s("%s")' % (ifun,the_rest)
891 newcmd = '%s("%s")' % (ifun,the_rest)
893 elif pre == ESC_PAREN:
892 elif pre == ESC_PAREN:
894 newcmd = '%s(%s)' % (ifun,",".join(the_rest.split()))
893 newcmd = '%s(%s)' % (ifun,",".join(the_rest.split()))
895 else:
894 else:
896 # Auto-paren.
895 # Auto-paren.
897 # We only apply it to argument-less calls if the autocall
896 # We only apply it to argument-less calls if the autocall
898 # parameter is set to 2. We only need to check that autocall is <
897 # parameter is set to 2. We only need to check that autocall is <
899 # 2, since this function isn't called unless it's at least 1.
898 # 2, since this function isn't called unless it's at least 1.
900 if not the_rest and (self.shell.autocall < 2) and not force_auto:
899 if not the_rest and (self.shell.autocall < 2) and not force_auto:
901 newcmd = '%s %s' % (ifun,the_rest)
900 newcmd = '%s %s' % (ifun,the_rest)
902 auto_rewrite = False
901 auto_rewrite = False
903 else:
902 else:
904 if not force_auto and the_rest.startswith('['):
903 if not force_auto and the_rest.startswith('['):
905 if hasattr(obj,'__getitem__'):
904 if hasattr(obj,'__getitem__'):
906 # Don't autocall in this case: item access for an object
905 # Don't autocall in this case: item access for an object
907 # which is BOTH callable and implements __getitem__.
906 # which is BOTH callable and implements __getitem__.
908 newcmd = '%s %s' % (ifun,the_rest)
907 newcmd = '%s %s' % (ifun,the_rest)
909 auto_rewrite = False
908 auto_rewrite = False
910 else:
909 else:
911 # if the object doesn't support [] access, go ahead and
910 # if the object doesn't support [] access, go ahead and
912 # autocall
911 # autocall
913 newcmd = '%s(%s)' % (ifun.rstrip(),the_rest)
912 newcmd = '%s(%s)' % (ifun.rstrip(),the_rest)
914 elif the_rest.endswith(';'):
913 elif the_rest.endswith(';'):
915 newcmd = '%s(%s);' % (ifun.rstrip(),the_rest[:-1])
914 newcmd = '%s(%s);' % (ifun.rstrip(),the_rest[:-1])
916 else:
915 else:
917 newcmd = '%s(%s)' % (ifun.rstrip(), the_rest)
916 newcmd = '%s(%s)' % (ifun.rstrip(), the_rest)
918
917
919 if auto_rewrite:
918 if auto_rewrite:
920 self.shell.auto_rewrite_input(newcmd)
919 self.shell.auto_rewrite_input(newcmd)
921
920
922 return newcmd
921 return newcmd
923
922
924
923
925 class HelpHandler(PrefilterHandler):
924 class HelpHandler(PrefilterHandler):
926
925
927 handler_name = Unicode('help')
926 handler_name = Unicode('help')
928 esc_strings = List([ESC_HELP])
927 esc_strings = List([ESC_HELP])
929
928
930 def handle(self, line_info):
929 def handle(self, line_info):
931 """Try to get some help for the object.
930 """Try to get some help for the object.
932
931
933 obj? or ?obj -> basic information.
932 obj? or ?obj -> basic information.
934 obj?? or ??obj -> more details.
933 obj?? or ??obj -> more details.
935 """
934 """
936 normal_handler = self.prefilter_manager.get_handler_by_name('normal')
935 normal_handler = self.prefilter_manager.get_handler_by_name('normal')
937 line = line_info.line
936 line = line_info.line
938 # We need to make sure that we don't process lines which would be
937 # We need to make sure that we don't process lines which would be
939 # otherwise valid python, such as "x=1 # what?"
938 # otherwise valid python, such as "x=1 # what?"
940 try:
939 try:
941 codeop.compile_command(line)
940 codeop.compile_command(line)
942 except SyntaxError:
941 except SyntaxError:
943 # We should only handle as help stuff which is NOT valid syntax
942 # We should only handle as help stuff which is NOT valid syntax
944 if line[0]==ESC_HELP:
943 if line[0]==ESC_HELP:
945 line = line[1:]
944 line = line[1:]
946 elif line[-1]==ESC_HELP:
945 elif line[-1]==ESC_HELP:
947 line = line[:-1]
946 line = line[:-1]
948 if line:
947 if line:
949 #print 'line:<%r>' % line # dbg
948 #print 'line:<%r>' % line # dbg
950 self.shell.magic_pinfo(line)
949 self.shell.magic_pinfo(line)
951 else:
950 else:
952 self.shell.show_usage()
951 self.shell.show_usage()
953 return '' # Empty string is needed here!
952 return '' # Empty string is needed here!
954 except:
953 except:
955 raise
954 raise
956 # Pass any other exceptions through to the normal handler
955 # Pass any other exceptions through to the normal handler
957 return normal_handler.handle(line_info)
956 return normal_handler.handle(line_info)
958 else:
957 else:
959 # If the code compiles ok, we should handle it normally
958 # If the code compiles ok, we should handle it normally
960 return normal_handler.handle(line_info)
959 return normal_handler.handle(line_info)
961
960
962
961
963 class EmacsHandler(PrefilterHandler):
962 class EmacsHandler(PrefilterHandler):
964
963
965 handler_name = Unicode('emacs')
964 handler_name = Unicode('emacs')
966 esc_strings = List([])
965 esc_strings = List([])
967
966
968 def handle(self, line_info):
967 def handle(self, line_info):
969 """Handle input lines marked by python-mode."""
968 """Handle input lines marked by python-mode."""
970
969
971 # Currently, nothing is done. Later more functionality can be added
970 # Currently, nothing is done. Later more functionality can be added
972 # here if needed.
971 # here if needed.
973
972
974 # The input cache shouldn't be updated
973 # The input cache shouldn't be updated
975 return line_info.line
974 return line_info.line
976
975
977
976
978 #-----------------------------------------------------------------------------
977 #-----------------------------------------------------------------------------
979 # Defaults
978 # Defaults
980 #-----------------------------------------------------------------------------
979 #-----------------------------------------------------------------------------
981
980
982
981
983 _default_transformers = [
982 _default_transformers = [
984 AssignSystemTransformer,
983 AssignSystemTransformer,
985 AssignMagicTransformer,
984 AssignMagicTransformer,
986 PyPromptTransformer,
985 PyPromptTransformer,
987 IPyPromptTransformer,
986 IPyPromptTransformer,
988 ]
987 ]
989
988
990 _default_checkers = [
989 _default_checkers = [
991 EmacsChecker,
990 EmacsChecker,
992 ShellEscapeChecker,
991 ShellEscapeChecker,
993 MacroChecker,
992 MacroChecker,
994 IPyAutocallChecker,
993 IPyAutocallChecker,
995 MultiLineMagicChecker,
994 MultiLineMagicChecker,
996 EscCharsChecker,
995 EscCharsChecker,
997 AssignmentChecker,
996 AssignmentChecker,
998 AutoMagicChecker,
997 AutoMagicChecker,
999 AliasChecker,
998 AliasChecker,
1000 PythonOpsChecker,
999 PythonOpsChecker,
1001 AutocallChecker
1000 AutocallChecker
1002 ]
1001 ]
1003
1002
1004 _default_handlers = [
1003 _default_handlers = [
1005 PrefilterHandler,
1004 PrefilterHandler,
1006 AliasHandler,
1005 AliasHandler,
1007 ShellEscapeHandler,
1006 ShellEscapeHandler,
1008 MacroHandler,
1007 MacroHandler,
1009 MagicHandler,
1008 MagicHandler,
1010 AutoHandler,
1009 AutoHandler,
1011 HelpHandler,
1010 HelpHandler,
1012 EmacsHandler
1011 EmacsHandler
1013 ]
1012 ]
@@ -1,253 +1,252 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 A mixin for :class:`~IPython.core.application.Application` classes that
3 A mixin for :class:`~IPython.core.application.Application` classes that
5 launch InteractiveShell instances, load extensions, etc.
4 launch InteractiveShell instances, load extensions, etc.
6
5
7 Authors
6 Authors
8 -------
7 -------
9
8
10 * Min Ragan-Kelley
9 * Min Ragan-Kelley
11 """
10 """
12
11
13 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
15 #
14 #
16 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19
18
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21 # Imports
20 # Imports
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23
22
24 from __future__ import absolute_import
23 from __future__ import absolute_import
25
24
26 import os
25 import os
27 import sys
26 import sys
28
27
29 from IPython.config.application import boolean_flag
28 from IPython.config.application import boolean_flag
30 from IPython.config.configurable import Configurable
29 from IPython.config.configurable import Configurable
31 from IPython.config.loader import Config
30 from IPython.config.loader import Config
32 from IPython.utils.path import filefind
31 from IPython.utils.path import filefind
33 from IPython.utils.traitlets import Unicode, Instance, List
32 from IPython.utils.traitlets import Unicode, Instance, List
34
33
35 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
36 # Aliases and Flags
35 # Aliases and Flags
37 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
38
37
39 shell_flags = {}
38 shell_flags = {}
40
39
41 addflag = lambda *args: shell_flags.update(boolean_flag(*args))
40 addflag = lambda *args: shell_flags.update(boolean_flag(*args))
42 addflag('autoindent', 'InteractiveShell.autoindent',
41 addflag('autoindent', 'InteractiveShell.autoindent',
43 'Turn on autoindenting.', 'Turn off autoindenting.'
42 'Turn on autoindenting.', 'Turn off autoindenting.'
44 )
43 )
45 addflag('automagic', 'InteractiveShell.automagic',
44 addflag('automagic', 'InteractiveShell.automagic',
46 """Turn on the auto calling of magic commands. Type %%magic at the
45 """Turn on the auto calling of magic commands. Type %%magic at the
47 IPython prompt for more information.""",
46 IPython prompt for more information.""",
48 'Turn off the auto calling of magic commands.'
47 'Turn off the auto calling of magic commands.'
49 )
48 )
50 addflag('pdb', 'InteractiveShell.pdb',
49 addflag('pdb', 'InteractiveShell.pdb',
51 "Enable auto calling the pdb debugger after every exception.",
50 "Enable auto calling the pdb debugger after every exception.",
52 "Disable auto calling the pdb debugger after every exception."
51 "Disable auto calling the pdb debugger after every exception."
53 )
52 )
54 addflag('pprint', 'PlainTextFormatter.pprint',
53 addflag('pprint', 'PlainTextFormatter.pprint',
55 "Enable auto pretty printing of results.",
54 "Enable auto pretty printing of results.",
56 "Disable auto auto pretty printing of results."
55 "Disable auto auto pretty printing of results."
57 )
56 )
58 addflag('color-info', 'InteractiveShell.color_info',
57 addflag('color-info', 'InteractiveShell.color_info',
59 """IPython can display information about objects via a set of func-
58 """IPython can display information about objects via a set of func-
60 tions, and optionally can use colors for this, syntax highlighting
59 tions, and optionally can use colors for this, syntax highlighting
61 source code and various other elements. However, because this
60 source code and various other elements. However, because this
62 information is passed through a pager (like 'less') and many pagers get
61 information is passed through a pager (like 'less') and many pagers get
63 confused with color codes, this option is off by default. You can test
62 confused with color codes, this option is off by default. You can test
64 it and turn it on permanently in your ipython_config.py file if it
63 it and turn it on permanently in your ipython_config.py file if it
65 works for you. Test it and turn it on permanently if it works with
64 works for you. Test it and turn it on permanently if it works with
66 your system. The magic function %%color_info allows you to toggle this
65 your system. The magic function %%color_info allows you to toggle this
67 interactively for testing.""",
66 interactively for testing.""",
68 "Disable using colors for info related things."
67 "Disable using colors for info related things."
69 )
68 )
70 addflag('deep-reload', 'InteractiveShell.deep_reload',
69 addflag('deep-reload', 'InteractiveShell.deep_reload',
71 """Enable deep (recursive) reloading by default. IPython can use the
70 """Enable deep (recursive) reloading by default. IPython can use the
72 deep_reload module which reloads changes in modules recursively (it
71 deep_reload module which reloads changes in modules recursively (it
73 replaces the reload() function, so you don't need to change anything to
72 replaces the reload() function, so you don't need to change anything to
74 use it). deep_reload() forces a full reload of modules whose code may
73 use it). deep_reload() forces a full reload of modules whose code may
75 have changed, which the default reload() function does not. When
74 have changed, which the default reload() function does not. When
76 deep_reload is off, IPython will use the normal reload(), but
75 deep_reload is off, IPython will use the normal reload(), but
77 deep_reload will still be available as dreload(). This feature is off
76 deep_reload will still be available as dreload(). This feature is off
78 by default [which means that you have both normal reload() and
77 by default [which means that you have both normal reload() and
79 dreload()].""",
78 dreload()].""",
80 "Disable deep (recursive) reloading by default."
79 "Disable deep (recursive) reloading by default."
81 )
80 )
82 nosep_config = Config()
81 nosep_config = Config()
83 nosep_config.InteractiveShell.separate_in = ''
82 nosep_config.InteractiveShell.separate_in = ''
84 nosep_config.InteractiveShell.separate_out = ''
83 nosep_config.InteractiveShell.separate_out = ''
85 nosep_config.InteractiveShell.separate_out2 = ''
84 nosep_config.InteractiveShell.separate_out2 = ''
86
85
87 shell_flags['nosep']=(nosep_config, "Eliminate all spacing between prompts.")
86 shell_flags['nosep']=(nosep_config, "Eliminate all spacing between prompts.")
88
87
89
88
90 # it's possible we don't want short aliases for *all* of these:
89 # it's possible we don't want short aliases for *all* of these:
91 shell_aliases = dict(
90 shell_aliases = dict(
92 autocall='InteractiveShell.autocall',
91 autocall='InteractiveShell.autocall',
93 colors='InteractiveShell.colors',
92 colors='InteractiveShell.colors',
94 logfile='InteractiveShell.logfile',
93 logfile='InteractiveShell.logfile',
95 logappend='InteractiveShell.logappend',
94 logappend='InteractiveShell.logappend',
96 c='InteractiveShellApp.code_to_run',
95 c='InteractiveShellApp.code_to_run',
97 ext='InteractiveShellApp.extra_extension',
96 ext='InteractiveShellApp.extra_extension',
98 )
97 )
99 shell_aliases['cache-size'] = 'InteractiveShell.cache_size'
98 shell_aliases['cache-size'] = 'InteractiveShell.cache_size'
100
99
101 #-----------------------------------------------------------------------------
100 #-----------------------------------------------------------------------------
102 # Main classes and functions
101 # Main classes and functions
103 #-----------------------------------------------------------------------------
102 #-----------------------------------------------------------------------------
104
103
105 class InteractiveShellApp(Configurable):
104 class InteractiveShellApp(Configurable):
106 """A Mixin for applications that start InteractiveShell instances.
105 """A Mixin for applications that start InteractiveShell instances.
107
106
108 Provides configurables for loading extensions and executing files
107 Provides configurables for loading extensions and executing files
109 as part of configuring a Shell environment.
108 as part of configuring a Shell environment.
110
109
111 Provides init_extensions() and init_code() methods, to be called
110 Provides init_extensions() and init_code() methods, to be called
112 after init_shell(), which must be implemented by subclasses.
111 after init_shell(), which must be implemented by subclasses.
113 """
112 """
114 extensions = List(Unicode, config=True,
113 extensions = List(Unicode, config=True,
115 help="A list of dotted module names of IPython extensions to load."
114 help="A list of dotted module names of IPython extensions to load."
116 )
115 )
117 extra_extension = Unicode('', config=True,
116 extra_extension = Unicode('', config=True,
118 help="dotted module name of an IPython extension to load."
117 help="dotted module name of an IPython extension to load."
119 )
118 )
120 def _extra_extension_changed(self, name, old, new):
119 def _extra_extension_changed(self, name, old, new):
121 if new:
120 if new:
122 # add to self.extensions
121 # add to self.extensions
123 self.extensions.append(new)
122 self.extensions.append(new)
124
123
125 exec_files = List(Unicode, config=True,
124 exec_files = List(Unicode, config=True,
126 help="""List of files to run at IPython startup."""
125 help="""List of files to run at IPython startup."""
127 )
126 )
128 file_to_run = Unicode('', config=True,
127 file_to_run = Unicode('', config=True,
129 help="""A file to be run""")
128 help="""A file to be run""")
130
129
131 exec_lines = List(Unicode, config=True,
130 exec_lines = List(Unicode, config=True,
132 help="""lines of code to run at IPython startup."""
131 help="""lines of code to run at IPython startup."""
133 )
132 )
134 code_to_run = Unicode('', config=True,
133 code_to_run = Unicode('', config=True,
135 help="Execute the given command string."
134 help="Execute the given command string."
136 )
135 )
137 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
136 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
138
137
139 def init_shell(self):
138 def init_shell(self):
140 raise NotImplementedError("Override in subclasses")
139 raise NotImplementedError("Override in subclasses")
141
140
142 def init_extensions(self):
141 def init_extensions(self):
143 """Load all IPython extensions in IPythonApp.extensions.
142 """Load all IPython extensions in IPythonApp.extensions.
144
143
145 This uses the :meth:`ExtensionManager.load_extensions` to load all
144 This uses the :meth:`ExtensionManager.load_extensions` to load all
146 the extensions listed in ``self.extensions``.
145 the extensions listed in ``self.extensions``.
147 """
146 """
148 if not self.extensions:
147 if not self.extensions:
149 return
148 return
150 try:
149 try:
151 self.log.debug("Loading IPython extensions...")
150 self.log.debug("Loading IPython extensions...")
152 extensions = self.extensions
151 extensions = self.extensions
153 for ext in extensions:
152 for ext in extensions:
154 try:
153 try:
155 self.log.info("Loading IPython extension: %s" % ext)
154 self.log.info("Loading IPython extension: %s" % ext)
156 self.shell.extension_manager.load_extension(ext)
155 self.shell.extension_manager.load_extension(ext)
157 except:
156 except:
158 self.log.warn("Error in loading extension: %s" % ext)
157 self.log.warn("Error in loading extension: %s" % ext)
159 self.shell.showtraceback()
158 self.shell.showtraceback()
160 except:
159 except:
161 self.log.warn("Unknown error in loading extensions:")
160 self.log.warn("Unknown error in loading extensions:")
162 self.shell.showtraceback()
161 self.shell.showtraceback()
163
162
164 def init_code(self):
163 def init_code(self):
165 """run the pre-flight code, specified via exec_lines"""
164 """run the pre-flight code, specified via exec_lines"""
166 self._run_exec_lines()
165 self._run_exec_lines()
167 self._run_exec_files()
166 self._run_exec_files()
168 self._run_cmd_line_code()
167 self._run_cmd_line_code()
169
168
170 def _run_exec_lines(self):
169 def _run_exec_lines(self):
171 """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
170 """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
172 if not self.exec_lines:
171 if not self.exec_lines:
173 return
172 return
174 try:
173 try:
175 self.log.debug("Running code from IPythonApp.exec_lines...")
174 self.log.debug("Running code from IPythonApp.exec_lines...")
176 for line in self.exec_lines:
175 for line in self.exec_lines:
177 try:
176 try:
178 self.log.info("Running code in user namespace: %s" %
177 self.log.info("Running code in user namespace: %s" %
179 line)
178 line)
180 self.shell.run_cell(line, store_history=False)
179 self.shell.run_cell(line, store_history=False)
181 except:
180 except:
182 self.log.warn("Error in executing line in user "
181 self.log.warn("Error in executing line in user "
183 "namespace: %s" % line)
182 "namespace: %s" % line)
184 self.shell.showtraceback()
183 self.shell.showtraceback()
185 except:
184 except:
186 self.log.warn("Unknown error in handling IPythonApp.exec_lines:")
185 self.log.warn("Unknown error in handling IPythonApp.exec_lines:")
187 self.shell.showtraceback()
186 self.shell.showtraceback()
188
187
189 def _exec_file(self, fname):
188 def _exec_file(self, fname):
190 try:
189 try:
191 full_filename = filefind(fname, [u'.', self.ipython_dir])
190 full_filename = filefind(fname, [u'.', self.ipython_dir])
192 except IOError as e:
191 except IOError as e:
193 self.log.warn("File not found: %r"%fname)
192 self.log.warn("File not found: %r"%fname)
194 return
193 return
195 # Make sure that the running script gets a proper sys.argv as if it
194 # Make sure that the running script gets a proper sys.argv as if it
196 # were run from a system shell.
195 # were run from a system shell.
197 save_argv = sys.argv
196 save_argv = sys.argv
198 sys.argv = [full_filename] + self.extra_args[1:]
197 sys.argv = [full_filename] + self.extra_args[1:]
199 try:
198 try:
200 if os.path.isfile(full_filename):
199 if os.path.isfile(full_filename):
201 if full_filename.endswith('.ipy'):
200 if full_filename.endswith('.ipy'):
202 self.log.info("Running file in user namespace: %s" %
201 self.log.info("Running file in user namespace: %s" %
203 full_filename)
202 full_filename)
204 self.shell.safe_execfile_ipy(full_filename)
203 self.shell.safe_execfile_ipy(full_filename)
205 else:
204 else:
206 # default to python, even without extension
205 # default to python, even without extension
207 self.log.info("Running file in user namespace: %s" %
206 self.log.info("Running file in user namespace: %s" %
208 full_filename)
207 full_filename)
209 # Ensure that __file__ is always defined to match Python behavior
208 # Ensure that __file__ is always defined to match Python behavior
210 self.shell.user_ns['__file__'] = fname
209 self.shell.user_ns['__file__'] = fname
211 try:
210 try:
212 self.shell.safe_execfile(full_filename, self.shell.user_ns)
211 self.shell.safe_execfile(full_filename, self.shell.user_ns)
213 finally:
212 finally:
214 del self.shell.user_ns['__file__']
213 del self.shell.user_ns['__file__']
215 finally:
214 finally:
216 sys.argv = save_argv
215 sys.argv = save_argv
217
216
218 def _run_exec_files(self):
217 def _run_exec_files(self):
219 """Run files from IPythonApp.exec_files"""
218 """Run files from IPythonApp.exec_files"""
220 if not self.exec_files:
219 if not self.exec_files:
221 return
220 return
222
221
223 self.log.debug("Running files in IPythonApp.exec_files...")
222 self.log.debug("Running files in IPythonApp.exec_files...")
224 try:
223 try:
225 for fname in self.exec_files:
224 for fname in self.exec_files:
226 self._exec_file(fname)
225 self._exec_file(fname)
227 except:
226 except:
228 self.log.warn("Unknown error in handling IPythonApp.exec_files:")
227 self.log.warn("Unknown error in handling IPythonApp.exec_files:")
229 self.shell.showtraceback()
228 self.shell.showtraceback()
230
229
231 def _run_cmd_line_code(self):
230 def _run_cmd_line_code(self):
232 """Run code or file specified at the command-line"""
231 """Run code or file specified at the command-line"""
233 if self.code_to_run:
232 if self.code_to_run:
234 line = self.code_to_run
233 line = self.code_to_run
235 try:
234 try:
236 self.log.info("Running code given at command line (c=): %s" %
235 self.log.info("Running code given at command line (c=): %s" %
237 line)
236 line)
238 self.shell.run_cell(line, store_history=False)
237 self.shell.run_cell(line, store_history=False)
239 except:
238 except:
240 self.log.warn("Error in executing line in user namespace: %s" %
239 self.log.warn("Error in executing line in user namespace: %s" %
241 line)
240 line)
242 self.shell.showtraceback()
241 self.shell.showtraceback()
243
242
244 # Like Python itself, ignore the second if the first of these is present
243 # Like Python itself, ignore the second if the first of these is present
245 elif self.file_to_run:
244 elif self.file_to_run:
246 fname = self.file_to_run
245 fname = self.file_to_run
247 try:
246 try:
248 self._exec_file(fname)
247 self._exec_file(fname)
249 except:
248 except:
250 self.log.warn("Error in executing file in user namespace: %s" %
249 self.log.warn("Error in executing file in user namespace: %s" %
251 fname)
250 fname)
252 self.shell.showtraceback()
251 self.shell.showtraceback()
253
252
@@ -1,91 +1,90 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Simple utility for splitting user input.
3 Simple utility for splitting user input.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2009 The IPython Development Team
12 # Copyright (C) 2008-2009 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 import re
22 import re
24 import sys
23 import sys
25
24
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27 # Main function
26 # Main function
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29
28
30
29
31 # RegExp for splitting line contents into pre-char//first word-method//rest.
30 # RegExp for splitting line contents into pre-char//first word-method//rest.
32 # For clarity, each group in on one line.
31 # For clarity, each group in on one line.
33
32
34 # WARNING: update the regexp if the escapes in interactiveshell are changed, as they
33 # WARNING: update the regexp if the escapes in interactiveshell are changed, as they
35 # are hardwired in.
34 # are hardwired in.
36
35
37 # Although it's not solely driven by the regex, note that:
36 # Although it's not solely driven by the regex, note that:
38 # ,;/% only trigger if they are the first character on the line
37 # ,;/% only trigger if they are the first character on the line
39 # ! and !! trigger if they are first char(s) *or* follow an indent
38 # ! and !! trigger if they are first char(s) *or* follow an indent
40 # ? triggers as first or last char.
39 # ? triggers as first or last char.
41
40
42 # The three parts of the regex are:
41 # The three parts of the regex are:
43 # 1) pre: pre_char *or* initial whitespace
42 # 1) pre: pre_char *or* initial whitespace
44 # 2) ifun: first word/method (mix of \w and '.')
43 # 2) ifun: first word/method (mix of \w and '.')
45 # 3) the_rest: rest of line (separated from ifun by space if non-empty)
44 # 3) the_rest: rest of line (separated from ifun by space if non-empty)
46 line_split = re.compile(r'^([,;/%?]|!!?|\s*)'
45 line_split = re.compile(r'^([,;/%?]|!!?|\s*)'
47 r'\s*([\w\.]+)'
46 r'\s*([\w\.]+)'
48 r'(\s+.*$|$)')
47 r'(\s+.*$|$)')
49
48
50 # r'[\w\.]+'
49 # r'[\w\.]+'
51 # r'\s*=\s*%.*'
50 # r'\s*=\s*%.*'
52
51
53 def split_user_input(line, pattern=None):
52 def split_user_input(line, pattern=None):
54 """Split user input into pre-char/whitespace, function part and rest.
53 """Split user input into pre-char/whitespace, function part and rest.
55
54
56 This is currently handles lines with '=' in them in a very inconsistent
55 This is currently handles lines with '=' in them in a very inconsistent
57 manner.
56 manner.
58 """
57 """
59 # We need to ensure that the rest of this routine deals only with unicode
58 # We need to ensure that the rest of this routine deals only with unicode
60 if type(line)==str:
59 if type(line)==str:
61 codec = sys.stdin.encoding
60 codec = sys.stdin.encoding
62 if codec is None:
61 if codec is None:
63 codec = 'utf-8'
62 codec = 'utf-8'
64 line = line.decode(codec)
63 line = line.decode(codec)
65
64
66 if pattern is None:
65 if pattern is None:
67 pattern = line_split
66 pattern = line_split
68 match = pattern.match(line)
67 match = pattern.match(line)
69 if not match:
68 if not match:
70 # print "match failed for line '%s'" % line
69 # print "match failed for line '%s'" % line
71 try:
70 try:
72 ifun, the_rest = line.split(None,1)
71 ifun, the_rest = line.split(None,1)
73 except ValueError:
72 except ValueError:
74 # print "split failed for line '%s'" % line
73 # print "split failed for line '%s'" % line
75 ifun, the_rest = line, u''
74 ifun, the_rest = line, u''
76 pre = re.match('^(\s*)(.*)',line).groups()[0]
75 pre = re.match('^(\s*)(.*)',line).groups()[0]
77 else:
76 else:
78 pre,ifun,the_rest = match.groups()
77 pre,ifun,the_rest = match.groups()
79
78
80 # ifun has to be a valid python identifier, so it better encode into
79 # ifun has to be a valid python identifier, so it better encode into
81 # ascii. We do still make it a unicode string so that we consistently
80 # ascii. We do still make it a unicode string so that we consistently
82 # return unicode, but it will be one that is guaranteed to be pure ascii
81 # return unicode, but it will be one that is guaranteed to be pure ascii
83 try:
82 try:
84 ifun = unicode(ifun.encode('ascii'))
83 ifun = unicode(ifun.encode('ascii'))
85 except UnicodeEncodeError:
84 except UnicodeEncodeError:
86 the_rest = ifun + u' ' + the_rest
85 the_rest = ifun + u' ' + the_rest
87 ifun = u''
86 ifun = u''
88
87
89 #print 'line:<%s>' % line # dbg
88 #print 'line:<%s>' % line # dbg
90 #print 'pre <%s> ifun <%s> rest <%s>' % (pre,ifun.strip(),the_rest) # dbg
89 #print 'pre <%s> ifun <%s> rest <%s>' % (pre,ifun.strip(),the_rest) # dbg
91 return pre, ifun.strip(), the_rest.lstrip()
90 return pre, ifun.strip(), the_rest.lstrip()
@@ -1,59 +1,58 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3
2
4 def test_import_completer():
3 def test_import_completer():
5 from IPython.core import completer
4 from IPython.core import completer
6
5
7 def test_import_crashhandler():
6 def test_import_crashhandler():
8 from IPython.core import crashhandler
7 from IPython.core import crashhandler
9
8
10 def test_import_debugger():
9 def test_import_debugger():
11 from IPython.core import debugger
10 from IPython.core import debugger
12
11
13 def test_import_fakemodule():
12 def test_import_fakemodule():
14 from IPython.core import fakemodule
13 from IPython.core import fakemodule
15
14
16 def test_import_excolors():
15 def test_import_excolors():
17 from IPython.core import excolors
16 from IPython.core import excolors
18
17
19 def test_import_history():
18 def test_import_history():
20 from IPython.core import history
19 from IPython.core import history
21
20
22 def test_import_hooks():
21 def test_import_hooks():
23 from IPython.core import hooks
22 from IPython.core import hooks
24
23
25 def test_import_ipapi():
24 def test_import_ipapi():
26 from IPython.core import ipapi
25 from IPython.core import ipapi
27
26
28 def test_import_interactiveshell():
27 def test_import_interactiveshell():
29 from IPython.core import interactiveshell
28 from IPython.core import interactiveshell
30
29
31 def test_import_logger():
30 def test_import_logger():
32 from IPython.core import logger
31 from IPython.core import logger
33
32
34 def test_import_macro():
33 def test_import_macro():
35 from IPython.core import macro
34 from IPython.core import macro
36
35
37 def test_import_magic():
36 def test_import_magic():
38 from IPython.core import magic
37 from IPython.core import magic
39
38
40 def test_import_oinspect():
39 def test_import_oinspect():
41 from IPython.core import oinspect
40 from IPython.core import oinspect
42
41
43 def test_import_prefilter():
42 def test_import_prefilter():
44 from IPython.core import prefilter
43 from IPython.core import prefilter
45
44
46 def test_import_prompts():
45 def test_import_prompts():
47 from IPython.core import prompts
46 from IPython.core import prompts
48
47
49 def test_import_release():
48 def test_import_release():
50 from IPython.core import release
49 from IPython.core import release
51
50
52 def test_import_shadowns():
51 def test_import_shadowns():
53 from IPython.core import shadowns
52 from IPython.core import shadowns
54
53
55 def test_import_ultratb():
54 def test_import_ultratb():
56 from IPython.core import ultratb
55 from IPython.core import ultratb
57
56
58 def test_import_usage():
57 def test_import_usage():
59 from IPython.core import usage
58 from IPython.core import usage
@@ -1,300 +1,299 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3
2
4 """Magic command interface for interactive parallel work."""
3 """Magic command interface for interactive parallel work."""
5
4
6 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
7 # Copyright (C) 2008-2009 The IPython Development Team
6 # Copyright (C) 2008-2009 The IPython Development Team
8 #
7 #
9 # Distributed under the terms of the BSD License. The full license is in
8 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
9 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
12
11
13 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
14 # Imports
13 # Imports
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16
15
17 import ast
16 import ast
18 import re
17 import re
19
18
20 from IPython.core.plugin import Plugin
19 from IPython.core.plugin import Plugin
21 from IPython.utils.traitlets import Bool, Any, Instance
20 from IPython.utils.traitlets import Bool, Any, Instance
22 from IPython.testing.skipdoctest import skip_doctest
21 from IPython.testing.skipdoctest import skip_doctest
23
22
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25 # Definitions of magic functions for use with IPython
24 # Definitions of magic functions for use with IPython
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27
26
28
27
29 NO_ACTIVE_VIEW = """
28 NO_ACTIVE_VIEW = """
30 Use activate() on a DirectView object to activate it for magics.
29 Use activate() on a DirectView object to activate it for magics.
31 """
30 """
32
31
33
32
34 class ParalleMagic(Plugin):
33 class ParalleMagic(Plugin):
35 """A component to manage the %result, %px and %autopx magics."""
34 """A component to manage the %result, %px and %autopx magics."""
36
35
37 active_view = Instance('IPython.parallel.client.view.DirectView')
36 active_view = Instance('IPython.parallel.client.view.DirectView')
38 verbose = Bool(False, config=True)
37 verbose = Bool(False, config=True)
39 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
38 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
40
39
41 def __init__(self, shell=None, config=None):
40 def __init__(self, shell=None, config=None):
42 super(ParalleMagic, self).__init__(shell=shell, config=config)
41 super(ParalleMagic, self).__init__(shell=shell, config=config)
43 self._define_magics()
42 self._define_magics()
44 # A flag showing if autopx is activated or not
43 # A flag showing if autopx is activated or not
45 self.autopx = False
44 self.autopx = False
46
45
47 def _define_magics(self):
46 def _define_magics(self):
48 """Define the magic functions."""
47 """Define the magic functions."""
49 self.shell.define_magic('result', self.magic_result)
48 self.shell.define_magic('result', self.magic_result)
50 self.shell.define_magic('px', self.magic_px)
49 self.shell.define_magic('px', self.magic_px)
51 self.shell.define_magic('autopx', self.magic_autopx)
50 self.shell.define_magic('autopx', self.magic_autopx)
52
51
53 @skip_doctest
52 @skip_doctest
54 def magic_result(self, ipself, parameter_s=''):
53 def magic_result(self, ipself, parameter_s=''):
55 """Print the result of command i on all engines..
54 """Print the result of command i on all engines..
56
55
57 To use this a :class:`DirectView` instance must be created
56 To use this a :class:`DirectView` instance must be created
58 and then activated by calling its :meth:`activate` method.
57 and then activated by calling its :meth:`activate` method.
59
58
60 Then you can do the following::
59 Then you can do the following::
61
60
62 In [23]: %result
61 In [23]: %result
63 Out[23]:
62 Out[23]:
64 <Results List>
63 <Results List>
65 [0] In [6]: a = 10
64 [0] In [6]: a = 10
66 [1] In [6]: a = 10
65 [1] In [6]: a = 10
67
66
68 In [22]: %result 6
67 In [22]: %result 6
69 Out[22]:
68 Out[22]:
70 <Results List>
69 <Results List>
71 [0] In [6]: a = 10
70 [0] In [6]: a = 10
72 [1] In [6]: a = 10
71 [1] In [6]: a = 10
73 """
72 """
74 if self.active_view is None:
73 if self.active_view is None:
75 print NO_ACTIVE_VIEW
74 print NO_ACTIVE_VIEW
76 return
75 return
77
76
78 try:
77 try:
79 index = int(parameter_s)
78 index = int(parameter_s)
80 except:
79 except:
81 index = None
80 index = None
82 result = self.active_view.get_result(index)
81 result = self.active_view.get_result(index)
83 return result
82 return result
84
83
85 @skip_doctest
84 @skip_doctest
86 def magic_px(self, ipself, parameter_s=''):
85 def magic_px(self, ipself, parameter_s=''):
87 """Executes the given python command in parallel.
86 """Executes the given python command in parallel.
88
87
89 To use this a :class:`DirectView` instance must be created
88 To use this a :class:`DirectView` instance must be created
90 and then activated by calling its :meth:`activate` method.
89 and then activated by calling its :meth:`activate` method.
91
90
92 Then you can do the following::
91 Then you can do the following::
93
92
94 In [24]: %px a = 5
93 In [24]: %px a = 5
95 Parallel execution on engine(s): all
94 Parallel execution on engine(s): all
96 Out[24]:
95 Out[24]:
97 <Results List>
96 <Results List>
98 [0] In [7]: a = 5
97 [0] In [7]: a = 5
99 [1] In [7]: a = 5
98 [1] In [7]: a = 5
100 """
99 """
101
100
102 if self.active_view is None:
101 if self.active_view is None:
103 print NO_ACTIVE_VIEW
102 print NO_ACTIVE_VIEW
104 return
103 return
105 print "Parallel execution on engine(s): %s" % self.active_view.targets
104 print "Parallel execution on engine(s): %s" % self.active_view.targets
106 result = self.active_view.execute(parameter_s, block=False)
105 result = self.active_view.execute(parameter_s, block=False)
107 if self.active_view.block:
106 if self.active_view.block:
108 result.get()
107 result.get()
109 self._maybe_display_output(result)
108 self._maybe_display_output(result)
110
109
111 @skip_doctest
110 @skip_doctest
112 def magic_autopx(self, ipself, parameter_s=''):
111 def magic_autopx(self, ipself, parameter_s=''):
113 """Toggles auto parallel mode.
112 """Toggles auto parallel mode.
114
113
115 To use this a :class:`DirectView` instance must be created
114 To use this a :class:`DirectView` instance must be created
116 and then activated by calling its :meth:`activate` method. Once this
115 and then activated by calling its :meth:`activate` method. Once this
117 is called, all commands typed at the command line are send to
116 is called, all commands typed at the command line are send to
118 the engines to be executed in parallel. To control which engine
117 the engines to be executed in parallel. To control which engine
119 are used, set the ``targets`` attributed of the multiengine client
118 are used, set the ``targets`` attributed of the multiengine client
120 before entering ``%autopx`` mode.
119 before entering ``%autopx`` mode.
121
120
122 Then you can do the following::
121 Then you can do the following::
123
122
124 In [25]: %autopx
123 In [25]: %autopx
125 %autopx to enabled
124 %autopx to enabled
126
125
127 In [26]: a = 10
126 In [26]: a = 10
128 Parallel execution on engine(s): [0,1,2,3]
127 Parallel execution on engine(s): [0,1,2,3]
129 In [27]: print a
128 In [27]: print a
130 Parallel execution on engine(s): [0,1,2,3]
129 Parallel execution on engine(s): [0,1,2,3]
131 [stdout:0] 10
130 [stdout:0] 10
132 [stdout:1] 10
131 [stdout:1] 10
133 [stdout:2] 10
132 [stdout:2] 10
134 [stdout:3] 10
133 [stdout:3] 10
135
134
136
135
137 In [27]: %autopx
136 In [27]: %autopx
138 %autopx disabled
137 %autopx disabled
139 """
138 """
140 if self.autopx:
139 if self.autopx:
141 self._disable_autopx()
140 self._disable_autopx()
142 else:
141 else:
143 self._enable_autopx()
142 self._enable_autopx()
144
143
145 def _enable_autopx(self):
144 def _enable_autopx(self):
146 """Enable %autopx mode by saving the original run_cell and installing
145 """Enable %autopx mode by saving the original run_cell and installing
147 pxrun_cell.
146 pxrun_cell.
148 """
147 """
149 if self.active_view is None:
148 if self.active_view is None:
150 print NO_ACTIVE_VIEW
149 print NO_ACTIVE_VIEW
151 return
150 return
152
151
153 # override run_cell and run_code
152 # override run_cell and run_code
154 self._original_run_cell = self.shell.run_cell
153 self._original_run_cell = self.shell.run_cell
155 self.shell.run_cell = self.pxrun_cell
154 self.shell.run_cell = self.pxrun_cell
156 self._original_run_code = self.shell.run_code
155 self._original_run_code = self.shell.run_code
157 self.shell.run_code = self.pxrun_code
156 self.shell.run_code = self.pxrun_code
158
157
159 self.autopx = True
158 self.autopx = True
160 print "%autopx enabled"
159 print "%autopx enabled"
161
160
162 def _disable_autopx(self):
161 def _disable_autopx(self):
163 """Disable %autopx by restoring the original InteractiveShell.run_cell.
162 """Disable %autopx by restoring the original InteractiveShell.run_cell.
164 """
163 """
165 if self.autopx:
164 if self.autopx:
166 self.shell.run_cell = self._original_run_cell
165 self.shell.run_cell = self._original_run_cell
167 self.shell.run_code = self._original_run_code
166 self.shell.run_code = self._original_run_code
168 self.autopx = False
167 self.autopx = False
169 print "%autopx disabled"
168 print "%autopx disabled"
170
169
171 def _maybe_display_output(self, result):
170 def _maybe_display_output(self, result):
172 """Maybe display the output of a parallel result.
171 """Maybe display the output of a parallel result.
173
172
174 If self.active_view.block is True, wait for the result
173 If self.active_view.block is True, wait for the result
175 and display the result. Otherwise, this is a noop.
174 and display the result. Otherwise, this is a noop.
176 """
175 """
177 if isinstance(result.stdout, basestring):
176 if isinstance(result.stdout, basestring):
178 # single result
177 # single result
179 stdouts = [result.stdout.rstrip()]
178 stdouts = [result.stdout.rstrip()]
180 else:
179 else:
181 stdouts = [s.rstrip() for s in result.stdout]
180 stdouts = [s.rstrip() for s in result.stdout]
182
181
183 targets = self.active_view.targets
182 targets = self.active_view.targets
184 if isinstance(targets, int):
183 if isinstance(targets, int):
185 targets = [targets]
184 targets = [targets]
186 elif targets == 'all':
185 elif targets == 'all':
187 targets = self.active_view.client.ids
186 targets = self.active_view.client.ids
188
187
189 if any(stdouts):
188 if any(stdouts):
190 for eid,stdout in zip(targets, stdouts):
189 for eid,stdout in zip(targets, stdouts):
191 print '[stdout:%i]'%eid, stdout
190 print '[stdout:%i]'%eid, stdout
192
191
193
192
194 def pxrun_cell(self, raw_cell, store_history=True):
193 def pxrun_cell(self, raw_cell, store_history=True):
195 """drop-in replacement for InteractiveShell.run_cell.
194 """drop-in replacement for InteractiveShell.run_cell.
196
195
197 This executes code remotely, instead of in the local namespace.
196 This executes code remotely, instead of in the local namespace.
198
197
199 See InteractiveShell.run_cell for details.
198 See InteractiveShell.run_cell for details.
200 """
199 """
201
200
202 if (not raw_cell) or raw_cell.isspace():
201 if (not raw_cell) or raw_cell.isspace():
203 return
202 return
204
203
205 ipself = self.shell
204 ipself = self.shell
206
205
207 with ipself.builtin_trap:
206 with ipself.builtin_trap:
208 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
207 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
209
208
210 # Store raw and processed history
209 # Store raw and processed history
211 if store_history:
210 if store_history:
212 ipself.history_manager.store_inputs(ipself.execution_count,
211 ipself.history_manager.store_inputs(ipself.execution_count,
213 cell, raw_cell)
212 cell, raw_cell)
214
213
215 # ipself.logger.log(cell, raw_cell)
214 # ipself.logger.log(cell, raw_cell)
216
215
217 cell_name = ipself.compile.cache(cell, ipself.execution_count)
216 cell_name = ipself.compile.cache(cell, ipself.execution_count)
218
217
219 try:
218 try:
220 code_ast = ast.parse(cell, filename=cell_name)
219 code_ast = ast.parse(cell, filename=cell_name)
221 except (OverflowError, SyntaxError, ValueError, TypeError, MemoryError):
220 except (OverflowError, SyntaxError, ValueError, TypeError, MemoryError):
222 # Case 1
221 # Case 1
223 ipself.showsyntaxerror()
222 ipself.showsyntaxerror()
224 ipself.execution_count += 1
223 ipself.execution_count += 1
225 return None
224 return None
226 except NameError:
225 except NameError:
227 # ignore name errors, because we don't know the remote keys
226 # ignore name errors, because we don't know the remote keys
228 pass
227 pass
229
228
230 if store_history:
229 if store_history:
231 # Write output to the database. Does nothing unless
230 # Write output to the database. Does nothing unless
232 # history output logging is enabled.
231 # history output logging is enabled.
233 ipself.history_manager.store_output(ipself.execution_count)
232 ipself.history_manager.store_output(ipself.execution_count)
234 # Each cell is a *single* input, regardless of how many lines it has
233 # Each cell is a *single* input, regardless of how many lines it has
235 ipself.execution_count += 1
234 ipself.execution_count += 1
236
235
237 if re.search(r'get_ipython\(\)\.magic\(u?"%?autopx', cell):
236 if re.search(r'get_ipython\(\)\.magic\(u?"%?autopx', cell):
238 self._disable_autopx()
237 self._disable_autopx()
239 return False
238 return False
240 else:
239 else:
241 try:
240 try:
242 result = self.active_view.execute(cell, block=False)
241 result = self.active_view.execute(cell, block=False)
243 except:
242 except:
244 ipself.showtraceback()
243 ipself.showtraceback()
245 return True
244 return True
246 else:
245 else:
247 if self.active_view.block:
246 if self.active_view.block:
248 try:
247 try:
249 result.get()
248 result.get()
250 except:
249 except:
251 self.shell.showtraceback()
250 self.shell.showtraceback()
252 return True
251 return True
253 else:
252 else:
254 self._maybe_display_output(result)
253 self._maybe_display_output(result)
255 return False
254 return False
256
255
257 def pxrun_code(self, code_obj):
256 def pxrun_code(self, code_obj):
258 """drop-in replacement for InteractiveShell.run_code.
257 """drop-in replacement for InteractiveShell.run_code.
259
258
260 This executes code remotely, instead of in the local namespace.
259 This executes code remotely, instead of in the local namespace.
261
260
262 See InteractiveShell.run_code for details.
261 See InteractiveShell.run_code for details.
263 """
262 """
264 ipself = self.shell
263 ipself = self.shell
265 # check code object for the autopx magic
264 # check code object for the autopx magic
266 if 'get_ipython' in code_obj.co_names and 'magic' in code_obj.co_names and \
265 if 'get_ipython' in code_obj.co_names and 'magic' in code_obj.co_names and \
267 any( [ isinstance(c, basestring) and 'autopx' in c for c in code_obj.co_consts ]):
266 any( [ isinstance(c, basestring) and 'autopx' in c for c in code_obj.co_consts ]):
268 self._disable_autopx()
267 self._disable_autopx()
269 return False
268 return False
270 else:
269 else:
271 try:
270 try:
272 result = self.active_view.execute(code_obj, block=False)
271 result = self.active_view.execute(code_obj, block=False)
273 except:
272 except:
274 ipself.showtraceback()
273 ipself.showtraceback()
275 return True
274 return True
276 else:
275 else:
277 if self.active_view.block:
276 if self.active_view.block:
278 try:
277 try:
279 result.get()
278 result.get()
280 except:
279 except:
281 self.shell.showtraceback()
280 self.shell.showtraceback()
282 return True
281 return True
283 else:
282 else:
284 self._maybe_display_output(result)
283 self._maybe_display_output(result)
285 return False
284 return False
286
285
287
286
288
287
289
288
290 _loaded = False
289 _loaded = False
291
290
292
291
293 def load_ipython_extension(ip):
292 def load_ipython_extension(ip):
294 """Load the extension in IPython."""
293 """Load the extension in IPython."""
295 global _loaded
294 global _loaded
296 if not _loaded:
295 if not _loaded:
297 plugin = ParalleMagic(shell=ip, config=ip.config)
296 plugin = ParalleMagic(shell=ip, config=ip.config)
298 ip.plugin_manager.register_plugin('parallelmagic', plugin)
297 ip.plugin_manager.register_plugin('parallelmagic', plugin)
299 _loaded = True
298 _loaded = True
300
299
@@ -1,170 +1,169 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3
2
4 # GUID.py
3 # GUID.py
5 # Version 2.6
4 # Version 2.6
6 #
5 #
7 # Copyright (c) 2006 Conan C. Albrecht
6 # Copyright (c) 2006 Conan C. Albrecht
8 #
7 #
9 # Permission is hereby granted, free of charge, to any person obtaining a copy
8 # Permission is hereby granted, free of charge, to any person obtaining a copy
10 # of this software and associated documentation files (the "Software"), to deal
9 # of this software and associated documentation files (the "Software"), to deal
11 # in the Software without restriction, including without limitation the rights
10 # in the Software without restriction, including without limitation the rights
12 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 # copies of the Software, and to permit persons to whom the Software is furnished
12 # copies of the Software, and to permit persons to whom the Software is furnished
14 # to do so, subject to the following conditions:
13 # to do so, subject to the following conditions:
15 #
14 #
16 # The above copyright notice and this permission notice shall be included in all
15 # The above copyright notice and this permission notice shall be included in all
17 # copies or substantial portions of the Software.
16 # copies or substantial portions of the Software.
18 #
17 #
19 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
18 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
20 # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
19 # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
21 # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
20 # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
22 # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
21 # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23 # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22 # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
24 # DEALINGS IN THE SOFTWARE.
23 # DEALINGS IN THE SOFTWARE.
25
24
26
25
27
26
28 ##################################################################################################
27 ##################################################################################################
29 ### A globally-unique identifier made up of time and ip and 8 digits for a counter:
28 ### A globally-unique identifier made up of time and ip and 8 digits for a counter:
30 ### each GUID is 40 characters wide
29 ### each GUID is 40 characters wide
31 ###
30 ###
32 ### A globally unique identifier that combines ip, time, and a counter. Since the
31 ### A globally unique identifier that combines ip, time, and a counter. Since the
33 ### time is listed first, you can sort records by guid. You can also extract the time
32 ### time is listed first, you can sort records by guid. You can also extract the time
34 ### and ip if needed.
33 ### and ip if needed.
35 ###
34 ###
36 ### Since the counter has eight hex characters, you can create up to
35 ### Since the counter has eight hex characters, you can create up to
37 ### 0xffffffff (4294967295) GUIDs every millisecond. If your processor
36 ### 0xffffffff (4294967295) GUIDs every millisecond. If your processor
38 ### is somehow fast enough to create more than that in a millisecond (looking
37 ### is somehow fast enough to create more than that in a millisecond (looking
39 ### toward the future, of course), the function will wait until the next
38 ### toward the future, of course), the function will wait until the next
40 ### millisecond to return.
39 ### millisecond to return.
41 ###
40 ###
42 ### GUIDs make wonderful database keys. They require no access to the
41 ### GUIDs make wonderful database keys. They require no access to the
43 ### database (to get the max index number), they are extremely unique, and they sort
42 ### database (to get the max index number), they are extremely unique, and they sort
44 ### automatically by time. GUIDs prevent key clashes when merging
43 ### automatically by time. GUIDs prevent key clashes when merging
45 ### two databases together, combining data, or generating keys in distributed
44 ### two databases together, combining data, or generating keys in distributed
46 ### systems.
45 ### systems.
47 ###
46 ###
48 ### There is an Internet Draft for UUIDs, but this module does not implement it.
47 ### There is an Internet Draft for UUIDs, but this module does not implement it.
49 ### If the draft catches on, perhaps I'll conform the module to it.
48 ### If the draft catches on, perhaps I'll conform the module to it.
50 ###
49 ###
51
50
52
51
53 # Changelog
52 # Changelog
54 # Sometime, 1997 Created the Java version of GUID
53 # Sometime, 1997 Created the Java version of GUID
55 # Went through many versions in Java
54 # Went through many versions in Java
56 # Sometime, 2002 Created the Python version of GUID, mirroring the Java version
55 # Sometime, 2002 Created the Python version of GUID, mirroring the Java version
57 # November 24, 2003 Changed Python version to be more pythonic, took out object and made just a module
56 # November 24, 2003 Changed Python version to be more pythonic, took out object and made just a module
58 # December 2, 2003 Fixed duplicating GUIDs. Sometimes they duplicate if multiples are created
57 # December 2, 2003 Fixed duplicating GUIDs. Sometimes they duplicate if multiples are created
59 # in the same millisecond (it checks the last 100 GUIDs now and has a larger random part)
58 # in the same millisecond (it checks the last 100 GUIDs now and has a larger random part)
60 # December 9, 2003 Fixed MAX_RANDOM, which was going over sys.maxint
59 # December 9, 2003 Fixed MAX_RANDOM, which was going over sys.maxint
61 # June 12, 2004 Allowed a custom IP address to be sent in rather than always using the
60 # June 12, 2004 Allowed a custom IP address to be sent in rather than always using the
62 # local IP address.
61 # local IP address.
63 # November 4, 2005 Changed the random part to a counter variable. Now GUIDs are totally
62 # November 4, 2005 Changed the random part to a counter variable. Now GUIDs are totally
64 # unique and more efficient, as long as they are created by only
63 # unique and more efficient, as long as they are created by only
65 # on runtime on a given machine. The counter part is after the time
64 # on runtime on a given machine. The counter part is after the time
66 # part so it sorts correctly.
65 # part so it sorts correctly.
67 # November 8, 2005 The counter variable now starts at a random long now and cycles
66 # November 8, 2005 The counter variable now starts at a random long now and cycles
68 # around. This is in case two guids are created on the same
67 # around. This is in case two guids are created on the same
69 # machine at the same millisecond (by different processes). Even though
68 # machine at the same millisecond (by different processes). Even though
70 # it is possible the GUID can be created, this makes it highly unlikely
69 # it is possible the GUID can be created, this makes it highly unlikely
71 # since the counter will likely be different.
70 # since the counter will likely be different.
72 # November 11, 2005 Fixed a bug in the new IP getting algorithm. Also, use IPv6 range
71 # November 11, 2005 Fixed a bug in the new IP getting algorithm. Also, use IPv6 range
73 # for IP when we make it up (when it's no accessible)
72 # for IP when we make it up (when it's no accessible)
74 # November 21, 2005 Added better IP-finding code. It finds IP address better now.
73 # November 21, 2005 Added better IP-finding code. It finds IP address better now.
75 # January 5, 2006 Fixed a small bug caused in old versions of python (random module use)
74 # January 5, 2006 Fixed a small bug caused in old versions of python (random module use)
76
75
77 import math
76 import math
78 import socket
77 import socket
79 import random
78 import random
80 import sys
79 import sys
81 import time
80 import time
82 import threading
81 import threading
83
82
84
83
85
84
86 #############################
85 #############################
87 ### global module variables
86 ### global module variables
88
87
89 #Makes a hex IP from a decimal dot-separated ip (eg: 127.0.0.1)
88 #Makes a hex IP from a decimal dot-separated ip (eg: 127.0.0.1)
90 make_hexip = lambda ip: ''.join(["%04x" % long(i) for i in ip.split('.')]) # leave space for ip v6 (65K in each sub)
89 make_hexip = lambda ip: ''.join(["%04x" % long(i) for i in ip.split('.')]) # leave space for ip v6 (65K in each sub)
91
90
92 MAX_COUNTER = 0xfffffffe
91 MAX_COUNTER = 0xfffffffe
93 counter = 0L
92 counter = 0L
94 firstcounter = MAX_COUNTER
93 firstcounter = MAX_COUNTER
95 lasttime = 0
94 lasttime = 0
96 ip = ''
95 ip = ''
97 lock = threading.RLock()
96 lock = threading.RLock()
98 try: # only need to get the IP addresss once
97 try: # only need to get the IP addresss once
99 ip = socket.getaddrinfo(socket.gethostname(),0)[-1][-1][0]
98 ip = socket.getaddrinfo(socket.gethostname(),0)[-1][-1][0]
100 hexip = make_hexip(ip)
99 hexip = make_hexip(ip)
101 except: # if we don't have an ip, default to someting in the 10.x.x.x private range
100 except: # if we don't have an ip, default to someting in the 10.x.x.x private range
102 ip = '10'
101 ip = '10'
103 rand = random.Random()
102 rand = random.Random()
104 for i in range(3):
103 for i in range(3):
105 ip += '.' + str(rand.randrange(1, 0xffff)) # might as well use IPv6 range if we're making it up
104 ip += '.' + str(rand.randrange(1, 0xffff)) # might as well use IPv6 range if we're making it up
106 hexip = make_hexip(ip)
105 hexip = make_hexip(ip)
107
106
108
107
109 #################################
108 #################################
110 ### Public module functions
109 ### Public module functions
111
110
112 def generate(ip=None):
111 def generate(ip=None):
113 '''Generates a new guid. A guid is unique in space and time because it combines
112 '''Generates a new guid. A guid is unique in space and time because it combines
114 the machine IP with the current time in milliseconds. Be careful about sending in
113 the machine IP with the current time in milliseconds. Be careful about sending in
115 a specified IP address because the ip makes it unique in space. You could send in
114 a specified IP address because the ip makes it unique in space. You could send in
116 the same IP address that is created on another machine.
115 the same IP address that is created on another machine.
117 '''
116 '''
118 global counter, firstcounter, lasttime
117 global counter, firstcounter, lasttime
119 lock.acquire() # can't generate two guids at the same time
118 lock.acquire() # can't generate two guids at the same time
120 try:
119 try:
121 parts = []
120 parts = []
122
121
123 # do we need to wait for the next millisecond (are we out of counters?)
122 # do we need to wait for the next millisecond (are we out of counters?)
124 now = long(time.time() * 1000)
123 now = long(time.time() * 1000)
125 while lasttime == now and counter == firstcounter:
124 while lasttime == now and counter == firstcounter:
126 time.sleep(.01)
125 time.sleep(.01)
127 now = long(time.time() * 1000)
126 now = long(time.time() * 1000)
128
127
129 # time part
128 # time part
130 parts.append("%016x" % now)
129 parts.append("%016x" % now)
131
130
132 # counter part
131 # counter part
133 if lasttime != now: # time to start counter over since we have a different millisecond
132 if lasttime != now: # time to start counter over since we have a different millisecond
134 firstcounter = long(random.uniform(1, MAX_COUNTER)) # start at random position
133 firstcounter = long(random.uniform(1, MAX_COUNTER)) # start at random position
135 counter = firstcounter
134 counter = firstcounter
136 counter += 1
135 counter += 1
137 if counter > MAX_COUNTER:
136 if counter > MAX_COUNTER:
138 counter = 0
137 counter = 0
139 lasttime = now
138 lasttime = now
140 parts.append("%08x" % (counter))
139 parts.append("%08x" % (counter))
141
140
142 # ip part
141 # ip part
143 parts.append(hexip)
142 parts.append(hexip)
144
143
145 # put them all together
144 # put them all together
146 return ''.join(parts)
145 return ''.join(parts)
147 finally:
146 finally:
148 lock.release()
147 lock.release()
149
148
150
149
151 def extract_time(guid):
150 def extract_time(guid):
152 '''Extracts the time portion out of the guid and returns the
151 '''Extracts the time portion out of the guid and returns the
153 number of seconds since the epoch as a float'''
152 number of seconds since the epoch as a float'''
154 return float(long(guid[0:16], 16)) / 1000.0
153 return float(long(guid[0:16], 16)) / 1000.0
155
154
156
155
157 def extract_counter(guid):
156 def extract_counter(guid):
158 '''Extracts the counter from the guid (returns the bits in decimal)'''
157 '''Extracts the counter from the guid (returns the bits in decimal)'''
159 return int(guid[16:24], 16)
158 return int(guid[16:24], 16)
160
159
161
160
162 def extract_ip(guid):
161 def extract_ip(guid):
163 '''Extracts the ip portion out of the guid and returns it
162 '''Extracts the ip portion out of the guid and returns it
164 as a string like 10.10.10.10'''
163 as a string like 10.10.10.10'''
165 # there's probably a more elegant way to do this
164 # there's probably a more elegant way to do this
166 thisip = []
165 thisip = []
167 for index in range(24, 40, 4):
166 for index in range(24, 40, 4):
168 thisip.append(str(int(guid[index: index + 4], 16)))
167 thisip.append(str(int(guid[index: index + 4], 16)))
169 return '.'.join(thisip)
168 return '.'.join(thisip)
170
169
@@ -1,90 +1,88 b''
1 #!/usr/bin/env python
2
3 #
1 #
4 # This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
2 # This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
5 # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
3 # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
6 # Edits Copyright (C) 2010 The IPython Team
4 # Edits Copyright (C) 2010 The IPython Team
7 #
5 #
8 # Paramiko is free software; you can redistribute it and/or modify it under the
6 # Paramiko is free software; you can redistribute it and/or modify it under the
9 # terms of the GNU Lesser General Public License as published by the Free
7 # terms of the GNU Lesser General Public License as published by the Free
10 # Software Foundation; either version 2.1 of the License, or (at your option)
8 # Software Foundation; either version 2.1 of the License, or (at your option)
11 # any later version.
9 # any later version.
12 #
10 #
13 # Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
11 # Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY
14 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15 # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
13 # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
16 # details.
14 # details.
17 #
15 #
18 # You should have received a copy of the GNU Lesser General Public License
16 # You should have received a copy of the GNU Lesser General Public License
19 # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
17 # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
20 # 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
18 # 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
21
19
22 """
20 """
23 Sample script showing how to do local port forwarding over paramiko.
21 Sample script showing how to do local port forwarding over paramiko.
24
22
25 This script connects to the requested SSH server and sets up local port
23 This script connects to the requested SSH server and sets up local port
26 forwarding (the openssh -L option) from a local port through a tunneled
24 forwarding (the openssh -L option) from a local port through a tunneled
27 connection to a destination reachable from the SSH server machine.
25 connection to a destination reachable from the SSH server machine.
28 """
26 """
29
27
30 from __future__ import print_function
28 from __future__ import print_function
31
29
32 import logging
30 import logging
33 import select
31 import select
34 import SocketServer
32 import SocketServer
35
33
36 logger = logging.getLogger('ssh')
34 logger = logging.getLogger('ssh')
37
35
38 class ForwardServer (SocketServer.ThreadingTCPServer):
36 class ForwardServer (SocketServer.ThreadingTCPServer):
39 daemon_threads = True
37 daemon_threads = True
40 allow_reuse_address = True
38 allow_reuse_address = True
41
39
42
40
43 class Handler (SocketServer.BaseRequestHandler):
41 class Handler (SocketServer.BaseRequestHandler):
44
42
45 def handle(self):
43 def handle(self):
46 try:
44 try:
47 chan = self.ssh_transport.open_channel('direct-tcpip',
45 chan = self.ssh_transport.open_channel('direct-tcpip',
48 (self.chain_host, self.chain_port),
46 (self.chain_host, self.chain_port),
49 self.request.getpeername())
47 self.request.getpeername())
50 except Exception, e:
48 except Exception, e:
51 logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host,
49 logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host,
52 self.chain_port,
50 self.chain_port,
53 repr(e)))
51 repr(e)))
54 return
52 return
55 if chan is None:
53 if chan is None:
56 logger.debug('Incoming request to %s:%d was rejected by the SSH server.' %
54 logger.debug('Incoming request to %s:%d was rejected by the SSH server.' %
57 (self.chain_host, self.chain_port))
55 (self.chain_host, self.chain_port))
58 return
56 return
59
57
60 logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
58 logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
61 chan.getpeername(), (self.chain_host, self.chain_port)))
59 chan.getpeername(), (self.chain_host, self.chain_port)))
62 while True:
60 while True:
63 r, w, x = select.select([self.request, chan], [], [])
61 r, w, x = select.select([self.request, chan], [], [])
64 if self.request in r:
62 if self.request in r:
65 data = self.request.recv(1024)
63 data = self.request.recv(1024)
66 if len(data) == 0:
64 if len(data) == 0:
67 break
65 break
68 chan.send(data)
66 chan.send(data)
69 if chan in r:
67 if chan in r:
70 data = chan.recv(1024)
68 data = chan.recv(1024)
71 if len(data) == 0:
69 if len(data) == 0:
72 break
70 break
73 self.request.send(data)
71 self.request.send(data)
74 chan.close()
72 chan.close()
75 self.request.close()
73 self.request.close()
76 logger.debug('Tunnel closed ')
74 logger.debug('Tunnel closed ')
77
75
78
76
79 def forward_tunnel(local_port, remote_host, remote_port, transport):
77 def forward_tunnel(local_port, remote_host, remote_port, transport):
80 # this is a little convoluted, but lets me configure things for the Handler
78 # this is a little convoluted, but lets me configure things for the Handler
81 # object. (SocketServer doesn't give Handlers any way to access the outer
79 # object. (SocketServer doesn't give Handlers any way to access the outer
82 # server normally.)
80 # server normally.)
83 class SubHander (Handler):
81 class SubHander (Handler):
84 chain_host = remote_host
82 chain_host = remote_host
85 chain_port = remote_port
83 chain_port = remote_port
86 ssh_transport = transport
84 ssh_transport = transport
87 ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
85 ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
88
86
89
87
90 __all__ = ['forward_tunnel']
88 __all__ = ['forward_tunnel']
@@ -1,29 +1,28 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Extra capabilities for IPython
3 Extra capabilities for IPython
5 """
4 """
6
5
7 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
9 #
8 #
10 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13
12
14 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
15 # Imports
14 # Imports
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17
16
18 from IPython.lib.inputhook import (
17 from IPython.lib.inputhook import (
19 enable_wx, disable_wx,
18 enable_wx, disable_wx,
20 enable_gtk, disable_gtk,
19 enable_gtk, disable_gtk,
21 enable_qt4, disable_qt4,
20 enable_qt4, disable_qt4,
22 enable_tk, disable_tk,
21 enable_tk, disable_tk,
23 set_inputhook, clear_inputhook,
22 set_inputhook, clear_inputhook,
24 current_gui
23 current_gui
25 )
24 )
26
25
27 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
28 # Code
27 # Code
29 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
@@ -1,147 +1,146 b''
1 #!/usr/bin/env python
2 # coding: utf-8
1 # coding: utf-8
3 """
2 """
4 Support for creating GUI apps and starting event loops.
3 Support for creating GUI apps and starting event loops.
5
4
6 IPython's GUI integration allows interative plotting and GUI usage in IPython
5 IPython's GUI integration allows interative plotting and GUI usage in IPython
7 session. IPython has two different types of GUI integration:
6 session. IPython has two different types of GUI integration:
8
7
9 1. The terminal based IPython supports GUI event loops through Python's
8 1. The terminal based IPython supports GUI event loops through Python's
10 PyOS_InputHook. PyOS_InputHook is a hook that Python calls periodically
9 PyOS_InputHook. PyOS_InputHook is a hook that Python calls periodically
11 whenever raw_input is waiting for a user to type code. We implement GUI
10 whenever raw_input is waiting for a user to type code. We implement GUI
12 support in the terminal by setting PyOS_InputHook to a function that
11 support in the terminal by setting PyOS_InputHook to a function that
13 iterates the event loop for a short while. It is important to note that
12 iterates the event loop for a short while. It is important to note that
14 in this situation, the real GUI event loop is NOT run in the normal
13 in this situation, the real GUI event loop is NOT run in the normal
15 manner, so you can't use the normal means to detect that it is running.
14 manner, so you can't use the normal means to detect that it is running.
16 2. In the two process IPython kernel/frontend, the GUI event loop is run in
15 2. In the two process IPython kernel/frontend, the GUI event loop is run in
17 the kernel. In this case, the event loop is run in the normal manner by
16 the kernel. In this case, the event loop is run in the normal manner by
18 calling the function or method of the GUI toolkit that starts the event
17 calling the function or method of the GUI toolkit that starts the event
19 loop.
18 loop.
20
19
21 In addition to starting the GUI event loops in one of these two ways, IPython
20 In addition to starting the GUI event loops in one of these two ways, IPython
22 will *always* create an appropriate GUI application object when GUi
21 will *always* create an appropriate GUI application object when GUi
23 integration is enabled.
22 integration is enabled.
24
23
25 If you want your GUI apps to run in IPython you need to do two things:
24 If you want your GUI apps to run in IPython you need to do two things:
26
25
27 1. Test to see if there is already an existing main application object. If
26 1. Test to see if there is already an existing main application object. If
28 there is, you should use it. If there is not an existing application object
27 there is, you should use it. If there is not an existing application object
29 you should create one.
28 you should create one.
30 2. Test to see if the GUI event loop is running. If it is, you should not
29 2. Test to see if the GUI event loop is running. If it is, you should not
31 start it. If the event loop is not running you may start it.
30 start it. If the event loop is not running you may start it.
32
31
33 This module contains functions for each toolkit that perform these things
32 This module contains functions for each toolkit that perform these things
34 in a consistent manner. Because of how PyOS_InputHook runs the event loop
33 in a consistent manner. Because of how PyOS_InputHook runs the event loop
35 you cannot detect if the event loop is running using the traditional calls
34 you cannot detect if the event loop is running using the traditional calls
36 (such as ``wx.GetApp.IsMainLoopRunning()`` in wxPython). If PyOS_InputHook is
35 (such as ``wx.GetApp.IsMainLoopRunning()`` in wxPython). If PyOS_InputHook is
37 set These methods will return a false negative. That is, they will say the
36 set These methods will return a false negative. That is, they will say the
38 event loop is not running, when is actually is. To work around this limitation
37 event loop is not running, when is actually is. To work around this limitation
39 we proposed the following informal protocol:
38 we proposed the following informal protocol:
40
39
41 * Whenever someone starts the event loop, they *must* set the ``_in_event_loop``
40 * Whenever someone starts the event loop, they *must* set the ``_in_event_loop``
42 attribute of the main application object to ``True``. This should be done
41 attribute of the main application object to ``True``. This should be done
43 regardless of how the event loop is actually run.
42 regardless of how the event loop is actually run.
44 * Whenever someone stops the event loop, they *must* set the ``_in_event_loop``
43 * Whenever someone stops the event loop, they *must* set the ``_in_event_loop``
45 attribute of the main application object to ``False``.
44 attribute of the main application object to ``False``.
46 * If you want to see if the event loop is running, you *must* use ``hasattr``
45 * If you want to see if the event loop is running, you *must* use ``hasattr``
47 to see if ``_in_event_loop`` attribute has been set. If it is set, you
46 to see if ``_in_event_loop`` attribute has been set. If it is set, you
48 *must* use its value. If it has not been set, you can query the toolkit
47 *must* use its value. If it has not been set, you can query the toolkit
49 in the normal manner.
48 in the normal manner.
50 * If you want GUI support and no one else has created an application or
49 * If you want GUI support and no one else has created an application or
51 started the event loop you *must* do this. We don't want projects to
50 started the event loop you *must* do this. We don't want projects to
52 attempt to defer these things to someone else if they themselves need it.
51 attempt to defer these things to someone else if they themselves need it.
53
52
54 The functions below implement this logic for each GUI toolkit. If you need
53 The functions below implement this logic for each GUI toolkit. If you need
55 to create custom application subclasses, you will likely have to modify this
54 to create custom application subclasses, you will likely have to modify this
56 code for your own purposes. This code can be copied into your own project
55 code for your own purposes. This code can be copied into your own project
57 so you don't have to depend on IPython.
56 so you don't have to depend on IPython.
58
57
59 """
58 """
60
59
61 #-----------------------------------------------------------------------------
60 #-----------------------------------------------------------------------------
62 # Copyright (C) 2008-2010 The IPython Development Team
61 # Copyright (C) 2008-2010 The IPython Development Team
63 #
62 #
64 # Distributed under the terms of the BSD License. The full license is in
63 # Distributed under the terms of the BSD License. The full license is in
65 # the file COPYING, distributed as part of this software.
64 # the file COPYING, distributed as part of this software.
66 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
67
66
68 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
69 # Imports
68 # Imports
70 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
71
70
72 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
73 # wx
72 # wx
74 #-----------------------------------------------------------------------------
73 #-----------------------------------------------------------------------------
75
74
76 def get_app_wx(*args, **kwargs):
75 def get_app_wx(*args, **kwargs):
77 """Create a new wx app or return an exiting one."""
76 """Create a new wx app or return an exiting one."""
78 import wx
77 import wx
79 app = wx.GetApp()
78 app = wx.GetApp()
80 if app is None:
79 if app is None:
81 if not kwargs.has_key('redirect'):
80 if not kwargs.has_key('redirect'):
82 kwargs['redirect'] = False
81 kwargs['redirect'] = False
83 app = wx.PySimpleApp(*args, **kwargs)
82 app = wx.PySimpleApp(*args, **kwargs)
84 return app
83 return app
85
84
86 def is_event_loop_running_wx(app=None):
85 def is_event_loop_running_wx(app=None):
87 """Is the wx event loop running."""
86 """Is the wx event loop running."""
88 if app is None:
87 if app is None:
89 app = get_app_wx()
88 app = get_app_wx()
90 if hasattr(app, '_in_event_loop'):
89 if hasattr(app, '_in_event_loop'):
91 return app._in_event_loop
90 return app._in_event_loop
92 else:
91 else:
93 return app.IsMainLoopRunning()
92 return app.IsMainLoopRunning()
94
93
95 def start_event_loop_wx(app=None):
94 def start_event_loop_wx(app=None):
96 """Start the wx event loop in a consistent manner."""
95 """Start the wx event loop in a consistent manner."""
97 if app is None:
96 if app is None:
98 app = get_app_wx()
97 app = get_app_wx()
99 if not is_event_loop_running_wx(app):
98 if not is_event_loop_running_wx(app):
100 app._in_event_loop = True
99 app._in_event_loop = True
101 app.MainLoop()
100 app.MainLoop()
102 app._in_event_loop = False
101 app._in_event_loop = False
103 else:
102 else:
104 app._in_event_loop = True
103 app._in_event_loop = True
105
104
106 #-----------------------------------------------------------------------------
105 #-----------------------------------------------------------------------------
107 # qt4
106 # qt4
108 #-----------------------------------------------------------------------------
107 #-----------------------------------------------------------------------------
109
108
110 def get_app_qt4(*args, **kwargs):
109 def get_app_qt4(*args, **kwargs):
111 """Create a new qt4 app or return an existing one."""
110 """Create a new qt4 app or return an existing one."""
112 from IPython.external.qt_for_kernel import QtGui
111 from IPython.external.qt_for_kernel import QtGui
113 app = QtGui.QApplication.instance()
112 app = QtGui.QApplication.instance()
114 if app is None:
113 if app is None:
115 if not args:
114 if not args:
116 args = ([''],)
115 args = ([''],)
117 app = QtGui.QApplication(*args, **kwargs)
116 app = QtGui.QApplication(*args, **kwargs)
118 return app
117 return app
119
118
120 def is_event_loop_running_qt4(app=None):
119 def is_event_loop_running_qt4(app=None):
121 """Is the qt4 event loop running."""
120 """Is the qt4 event loop running."""
122 if app is None:
121 if app is None:
123 app = get_app_qt4([''])
122 app = get_app_qt4([''])
124 if hasattr(app, '_in_event_loop'):
123 if hasattr(app, '_in_event_loop'):
125 return app._in_event_loop
124 return app._in_event_loop
126 else:
125 else:
127 # Does qt4 provide a other way to detect this?
126 # Does qt4 provide a other way to detect this?
128 return False
127 return False
129
128
130 def start_event_loop_qt4(app=None):
129 def start_event_loop_qt4(app=None):
131 """Start the qt4 event loop in a consistent manner."""
130 """Start the qt4 event loop in a consistent manner."""
132 if app is None:
131 if app is None:
133 app = get_app_qt4([''])
132 app = get_app_qt4([''])
134 if not is_event_loop_running_qt4(app):
133 if not is_event_loop_running_qt4(app):
135 app._in_event_loop = True
134 app._in_event_loop = True
136 app.exec_()
135 app.exec_()
137 app._in_event_loop = False
136 app._in_event_loop = False
138 else:
137 else:
139 app._in_event_loop = True
138 app._in_event_loop = True
140
139
141 #-----------------------------------------------------------------------------
140 #-----------------------------------------------------------------------------
142 # Tk
141 # Tk
143 #-----------------------------------------------------------------------------
142 #-----------------------------------------------------------------------------
144
143
145 #-----------------------------------------------------------------------------
144 #-----------------------------------------------------------------------------
146 # gtk
145 # gtk
147 #-----------------------------------------------------------------------------
146 #-----------------------------------------------------------------------------
@@ -1,345 +1,344 b''
1 #!/usr/bin/env python
2 # coding: utf-8
1 # coding: utf-8
3 """
2 """
4 Inputhook management for GUI event loop integration.
3 Inputhook management for GUI event loop integration.
5 """
4 """
6
5
7 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
9 #
8 #
10 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13
12
14 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
15 # Imports
14 # Imports
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17
16
18 import ctypes
17 import ctypes
19 import sys
18 import sys
20 import warnings
19 import warnings
21
20
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23 # Constants
22 # Constants
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25
24
26 # Constants for identifying the GUI toolkits.
25 # Constants for identifying the GUI toolkits.
27 GUI_WX = 'wx'
26 GUI_WX = 'wx'
28 GUI_QT = 'qt'
27 GUI_QT = 'qt'
29 GUI_QT4 = 'qt4'
28 GUI_QT4 = 'qt4'
30 GUI_GTK = 'gtk'
29 GUI_GTK = 'gtk'
31 GUI_TK = 'tk'
30 GUI_TK = 'tk'
32 GUI_OSX = 'osx'
31 GUI_OSX = 'osx'
33
32
34 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
35 # Utility classes
34 # Utility classes
36 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
37
36
38
37
39 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
40 # Main InputHookManager class
39 # Main InputHookManager class
41 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
42
41
43
42
44 class InputHookManager(object):
43 class InputHookManager(object):
45 """Manage PyOS_InputHook for different GUI toolkits.
44 """Manage PyOS_InputHook for different GUI toolkits.
46
45
47 This class installs various hooks under ``PyOSInputHook`` to handle
46 This class installs various hooks under ``PyOSInputHook`` to handle
48 GUI event loop integration.
47 GUI event loop integration.
49 """
48 """
50
49
51 def __init__(self):
50 def __init__(self):
52 self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
51 self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
53 self._apps = {}
52 self._apps = {}
54 self._reset()
53 self._reset()
55
54
56 def _reset(self):
55 def _reset(self):
57 self._callback_pyfunctype = None
56 self._callback_pyfunctype = None
58 self._callback = None
57 self._callback = None
59 self._installed = False
58 self._installed = False
60 self._current_gui = None
59 self._current_gui = None
61
60
62 def get_pyos_inputhook(self):
61 def get_pyos_inputhook(self):
63 """Return the current PyOS_InputHook as a ctypes.c_void_p."""
62 """Return the current PyOS_InputHook as a ctypes.c_void_p."""
64 return ctypes.c_void_p.in_dll(ctypes.pythonapi,"PyOS_InputHook")
63 return ctypes.c_void_p.in_dll(ctypes.pythonapi,"PyOS_InputHook")
65
64
66 def get_pyos_inputhook_as_func(self):
65 def get_pyos_inputhook_as_func(self):
67 """Return the current PyOS_InputHook as a ctypes.PYFUNCYPE."""
66 """Return the current PyOS_InputHook as a ctypes.PYFUNCYPE."""
68 return self.PYFUNC.in_dll(ctypes.pythonapi,"PyOS_InputHook")
67 return self.PYFUNC.in_dll(ctypes.pythonapi,"PyOS_InputHook")
69
68
70 def set_inputhook(self, callback):
69 def set_inputhook(self, callback):
71 """Set PyOS_InputHook to callback and return the previous one."""
70 """Set PyOS_InputHook to callback and return the previous one."""
72 self._callback = callback
71 self._callback = callback
73 self._callback_pyfunctype = self.PYFUNC(callback)
72 self._callback_pyfunctype = self.PYFUNC(callback)
74 pyos_inputhook_ptr = self.get_pyos_inputhook()
73 pyos_inputhook_ptr = self.get_pyos_inputhook()
75 original = self.get_pyos_inputhook_as_func()
74 original = self.get_pyos_inputhook_as_func()
76 pyos_inputhook_ptr.value = \
75 pyos_inputhook_ptr.value = \
77 ctypes.cast(self._callback_pyfunctype, ctypes.c_void_p).value
76 ctypes.cast(self._callback_pyfunctype, ctypes.c_void_p).value
78 self._installed = True
77 self._installed = True
79 return original
78 return original
80
79
81 def clear_inputhook(self, app=None):
80 def clear_inputhook(self, app=None):
82 """Set PyOS_InputHook to NULL and return the previous one.
81 """Set PyOS_InputHook to NULL and return the previous one.
83
82
84 Parameters
83 Parameters
85 ----------
84 ----------
86 app : optional, ignored
85 app : optional, ignored
87 This parameter is allowed only so that clear_inputhook() can be
86 This parameter is allowed only so that clear_inputhook() can be
88 called with a similar interface as all the ``enable_*`` methods. But
87 called with a similar interface as all the ``enable_*`` methods. But
89 the actual value of the parameter is ignored. This uniform interface
88 the actual value of the parameter is ignored. This uniform interface
90 makes it easier to have user-level entry points in the main IPython
89 makes it easier to have user-level entry points in the main IPython
91 app like :meth:`enable_gui`."""
90 app like :meth:`enable_gui`."""
92 pyos_inputhook_ptr = self.get_pyos_inputhook()
91 pyos_inputhook_ptr = self.get_pyos_inputhook()
93 original = self.get_pyos_inputhook_as_func()
92 original = self.get_pyos_inputhook_as_func()
94 pyos_inputhook_ptr.value = ctypes.c_void_p(None).value
93 pyos_inputhook_ptr.value = ctypes.c_void_p(None).value
95 self._reset()
94 self._reset()
96 return original
95 return original
97
96
98 def clear_app_refs(self, gui=None):
97 def clear_app_refs(self, gui=None):
99 """Clear IPython's internal reference to an application instance.
98 """Clear IPython's internal reference to an application instance.
100
99
101 Whenever we create an app for a user on qt4 or wx, we hold a
100 Whenever we create an app for a user on qt4 or wx, we hold a
102 reference to the app. This is needed because in some cases bad things
101 reference to the app. This is needed because in some cases bad things
103 can happen if a user doesn't hold a reference themselves. This
102 can happen if a user doesn't hold a reference themselves. This
104 method is provided to clear the references we are holding.
103 method is provided to clear the references we are holding.
105
104
106 Parameters
105 Parameters
107 ----------
106 ----------
108 gui : None or str
107 gui : None or str
109 If None, clear all app references. If ('wx', 'qt4') clear
108 If None, clear all app references. If ('wx', 'qt4') clear
110 the app for that toolkit. References are not held for gtk or tk
109 the app for that toolkit. References are not held for gtk or tk
111 as those toolkits don't have the notion of an app.
110 as those toolkits don't have the notion of an app.
112 """
111 """
113 if gui is None:
112 if gui is None:
114 self._apps = {}
113 self._apps = {}
115 elif self._apps.has_key(gui):
114 elif self._apps.has_key(gui):
116 del self._apps[gui]
115 del self._apps[gui]
117
116
118 def enable_wx(self, app=None):
117 def enable_wx(self, app=None):
119 """Enable event loop integration with wxPython.
118 """Enable event loop integration with wxPython.
120
119
121 Parameters
120 Parameters
122 ----------
121 ----------
123 app : WX Application, optional.
122 app : WX Application, optional.
124 Running application to use. If not given, we probe WX for an
123 Running application to use. If not given, we probe WX for an
125 existing application object, and create a new one if none is found.
124 existing application object, and create a new one if none is found.
126
125
127 Notes
126 Notes
128 -----
127 -----
129 This methods sets the ``PyOS_InputHook`` for wxPython, which allows
128 This methods sets the ``PyOS_InputHook`` for wxPython, which allows
130 the wxPython to integrate with terminal based applications like
129 the wxPython to integrate with terminal based applications like
131 IPython.
130 IPython.
132
131
133 If ``app`` is not given we probe for an existing one, and return it if
132 If ``app`` is not given we probe for an existing one, and return it if
134 found. If no existing app is found, we create an :class:`wx.App` as
133 found. If no existing app is found, we create an :class:`wx.App` as
135 follows::
134 follows::
136
135
137 import wx
136 import wx
138 app = wx.App(redirect=False, clearSigInt=False)
137 app = wx.App(redirect=False, clearSigInt=False)
139 """
138 """
140 from IPython.lib.inputhookwx import inputhook_wx
139 from IPython.lib.inputhookwx import inputhook_wx
141 self.set_inputhook(inputhook_wx)
140 self.set_inputhook(inputhook_wx)
142 self._current_gui = GUI_WX
141 self._current_gui = GUI_WX
143 import wx
142 import wx
144 if app is None:
143 if app is None:
145 app = wx.GetApp()
144 app = wx.GetApp()
146 if app is None:
145 if app is None:
147 app = wx.App(redirect=False, clearSigInt=False)
146 app = wx.App(redirect=False, clearSigInt=False)
148 app._in_event_loop = True
147 app._in_event_loop = True
149 self._apps[GUI_WX] = app
148 self._apps[GUI_WX] = app
150 return app
149 return app
151
150
152 def disable_wx(self):
151 def disable_wx(self):
153 """Disable event loop integration with wxPython.
152 """Disable event loop integration with wxPython.
154
153
155 This merely sets PyOS_InputHook to NULL.
154 This merely sets PyOS_InputHook to NULL.
156 """
155 """
157 if self._apps.has_key(GUI_WX):
156 if self._apps.has_key(GUI_WX):
158 self._apps[GUI_WX]._in_event_loop = False
157 self._apps[GUI_WX]._in_event_loop = False
159 self.clear_inputhook()
158 self.clear_inputhook()
160
159
161 def enable_qt4(self, app=None):
160 def enable_qt4(self, app=None):
162 """Enable event loop integration with PyQt4.
161 """Enable event loop integration with PyQt4.
163
162
164 Parameters
163 Parameters
165 ----------
164 ----------
166 app : Qt Application, optional.
165 app : Qt Application, optional.
167 Running application to use. If not given, we probe Qt for an
166 Running application to use. If not given, we probe Qt for an
168 existing application object, and create a new one if none is found.
167 existing application object, and create a new one if none is found.
169
168
170 Notes
169 Notes
171 -----
170 -----
172 This methods sets the PyOS_InputHook for PyQt4, which allows
171 This methods sets the PyOS_InputHook for PyQt4, which allows
173 the PyQt4 to integrate with terminal based applications like
172 the PyQt4 to integrate with terminal based applications like
174 IPython.
173 IPython.
175
174
176 If ``app`` is not given we probe for an existing one, and return it if
175 If ``app`` is not given we probe for an existing one, and return it if
177 found. If no existing app is found, we create an :class:`QApplication`
176 found. If no existing app is found, we create an :class:`QApplication`
178 as follows::
177 as follows::
179
178
180 from PyQt4 import QtCore
179 from PyQt4 import QtCore
181 app = QtGui.QApplication(sys.argv)
180 app = QtGui.QApplication(sys.argv)
182 """
181 """
183 from IPython.external.qt_for_kernel import QtCore, QtGui
182 from IPython.external.qt_for_kernel import QtCore, QtGui
184
183
185 if 'pyreadline' in sys.modules:
184 if 'pyreadline' in sys.modules:
186 # see IPython GitHub Issue #281 for more info on this issue
185 # see IPython GitHub Issue #281 for more info on this issue
187 # Similar intermittent behavior has been reported on OSX,
186 # Similar intermittent behavior has been reported on OSX,
188 # but not consistently reproducible
187 # but not consistently reproducible
189 warnings.warn("""PyReadline's inputhook can conflict with Qt, causing delays
188 warnings.warn("""PyReadline's inputhook can conflict with Qt, causing delays
190 in interactive input. If you do see this issue, we recommend using another GUI
189 in interactive input. If you do see this issue, we recommend using another GUI
191 toolkit if you can, or disable readline with the configuration option
190 toolkit if you can, or disable readline with the configuration option
192 'TerminalInteractiveShell.readline_use=False', specified in a config file or
191 'TerminalInteractiveShell.readline_use=False', specified in a config file or
193 at the command-line""",
192 at the command-line""",
194 RuntimeWarning)
193 RuntimeWarning)
195
194
196 # PyQt4 has had this since 4.3.1. In version 4.2, PyOS_InputHook
195 # PyQt4 has had this since 4.3.1. In version 4.2, PyOS_InputHook
197 # was set when QtCore was imported, but if it ever got removed,
196 # was set when QtCore was imported, but if it ever got removed,
198 # you couldn't reset it. For earlier versions we can
197 # you couldn't reset it. For earlier versions we can
199 # probably implement a ctypes version.
198 # probably implement a ctypes version.
200 try:
199 try:
201 QtCore.pyqtRestoreInputHook()
200 QtCore.pyqtRestoreInputHook()
202 except AttributeError:
201 except AttributeError:
203 pass
202 pass
204
203
205 self._current_gui = GUI_QT4
204 self._current_gui = GUI_QT4
206 if app is None:
205 if app is None:
207 app = QtCore.QCoreApplication.instance()
206 app = QtCore.QCoreApplication.instance()
208 if app is None:
207 if app is None:
209 app = QtGui.QApplication([" "])
208 app = QtGui.QApplication([" "])
210 app._in_event_loop = True
209 app._in_event_loop = True
211 self._apps[GUI_QT4] = app
210 self._apps[GUI_QT4] = app
212 return app
211 return app
213
212
214 def disable_qt4(self):
213 def disable_qt4(self):
215 """Disable event loop integration with PyQt4.
214 """Disable event loop integration with PyQt4.
216
215
217 This merely sets PyOS_InputHook to NULL.
216 This merely sets PyOS_InputHook to NULL.
218 """
217 """
219 if self._apps.has_key(GUI_QT4):
218 if self._apps.has_key(GUI_QT4):
220 self._apps[GUI_QT4]._in_event_loop = False
219 self._apps[GUI_QT4]._in_event_loop = False
221 self.clear_inputhook()
220 self.clear_inputhook()
222
221
223 def enable_gtk(self, app=None):
222 def enable_gtk(self, app=None):
224 """Enable event loop integration with PyGTK.
223 """Enable event loop integration with PyGTK.
225
224
226 Parameters
225 Parameters
227 ----------
226 ----------
228 app : ignored
227 app : ignored
229 Ignored, it's only a placeholder to keep the call signature of all
228 Ignored, it's only a placeholder to keep the call signature of all
230 gui activation methods consistent, which simplifies the logic of
229 gui activation methods consistent, which simplifies the logic of
231 supporting magics.
230 supporting magics.
232
231
233 Notes
232 Notes
234 -----
233 -----
235 This methods sets the PyOS_InputHook for PyGTK, which allows
234 This methods sets the PyOS_InputHook for PyGTK, which allows
236 the PyGTK to integrate with terminal based applications like
235 the PyGTK to integrate with terminal based applications like
237 IPython.
236 IPython.
238 """
237 """
239 import gtk
238 import gtk
240 try:
239 try:
241 gtk.set_interactive(True)
240 gtk.set_interactive(True)
242 self._current_gui = GUI_GTK
241 self._current_gui = GUI_GTK
243 except AttributeError:
242 except AttributeError:
244 # For older versions of gtk, use our own ctypes version
243 # For older versions of gtk, use our own ctypes version
245 from IPython.lib.inputhookgtk import inputhook_gtk
244 from IPython.lib.inputhookgtk import inputhook_gtk
246 self.set_inputhook(inputhook_gtk)
245 self.set_inputhook(inputhook_gtk)
247 self._current_gui = GUI_GTK
246 self._current_gui = GUI_GTK
248
247
249 def disable_gtk(self):
248 def disable_gtk(self):
250 """Disable event loop integration with PyGTK.
249 """Disable event loop integration with PyGTK.
251
250
252 This merely sets PyOS_InputHook to NULL.
251 This merely sets PyOS_InputHook to NULL.
253 """
252 """
254 self.clear_inputhook()
253 self.clear_inputhook()
255
254
256 def enable_tk(self, app=None):
255 def enable_tk(self, app=None):
257 """Enable event loop integration with Tk.
256 """Enable event loop integration with Tk.
258
257
259 Parameters
258 Parameters
260 ----------
259 ----------
261 app : toplevel :class:`Tkinter.Tk` widget, optional.
260 app : toplevel :class:`Tkinter.Tk` widget, optional.
262 Running toplevel widget to use. If not given, we probe Tk for an
261 Running toplevel widget to use. If not given, we probe Tk for an
263 existing one, and create a new one if none is found.
262 existing one, and create a new one if none is found.
264
263
265 Notes
264 Notes
266 -----
265 -----
267 If you have already created a :class:`Tkinter.Tk` object, the only
266 If you have already created a :class:`Tkinter.Tk` object, the only
268 thing done by this method is to register with the
267 thing done by this method is to register with the
269 :class:`InputHookManager`, since creating that object automatically
268 :class:`InputHookManager`, since creating that object automatically
270 sets ``PyOS_InputHook``.
269 sets ``PyOS_InputHook``.
271 """
270 """
272 self._current_gui = GUI_TK
271 self._current_gui = GUI_TK
273 if app is None:
272 if app is None:
274 import Tkinter
273 import Tkinter
275 app = Tkinter.Tk()
274 app = Tkinter.Tk()
276 app.withdraw()
275 app.withdraw()
277 self._apps[GUI_TK] = app
276 self._apps[GUI_TK] = app
278 return app
277 return app
279
278
280 def disable_tk(self):
279 def disable_tk(self):
281 """Disable event loop integration with Tkinter.
280 """Disable event loop integration with Tkinter.
282
281
283 This merely sets PyOS_InputHook to NULL.
282 This merely sets PyOS_InputHook to NULL.
284 """
283 """
285 self.clear_inputhook()
284 self.clear_inputhook()
286
285
287 def current_gui(self):
286 def current_gui(self):
288 """Return a string indicating the currently active GUI or None."""
287 """Return a string indicating the currently active GUI or None."""
289 return self._current_gui
288 return self._current_gui
290
289
291 inputhook_manager = InputHookManager()
290 inputhook_manager = InputHookManager()
292
291
293 enable_wx = inputhook_manager.enable_wx
292 enable_wx = inputhook_manager.enable_wx
294 disable_wx = inputhook_manager.disable_wx
293 disable_wx = inputhook_manager.disable_wx
295 enable_qt4 = inputhook_manager.enable_qt4
294 enable_qt4 = inputhook_manager.enable_qt4
296 disable_qt4 = inputhook_manager.disable_qt4
295 disable_qt4 = inputhook_manager.disable_qt4
297 enable_gtk = inputhook_manager.enable_gtk
296 enable_gtk = inputhook_manager.enable_gtk
298 disable_gtk = inputhook_manager.disable_gtk
297 disable_gtk = inputhook_manager.disable_gtk
299 enable_tk = inputhook_manager.enable_tk
298 enable_tk = inputhook_manager.enable_tk
300 disable_tk = inputhook_manager.disable_tk
299 disable_tk = inputhook_manager.disable_tk
301 clear_inputhook = inputhook_manager.clear_inputhook
300 clear_inputhook = inputhook_manager.clear_inputhook
302 set_inputhook = inputhook_manager.set_inputhook
301 set_inputhook = inputhook_manager.set_inputhook
303 current_gui = inputhook_manager.current_gui
302 current_gui = inputhook_manager.current_gui
304 clear_app_refs = inputhook_manager.clear_app_refs
303 clear_app_refs = inputhook_manager.clear_app_refs
305
304
306
305
307 # Convenience function to switch amongst them
306 # Convenience function to switch amongst them
308 def enable_gui(gui=None, app=None):
307 def enable_gui(gui=None, app=None):
309 """Switch amongst GUI input hooks by name.
308 """Switch amongst GUI input hooks by name.
310
309
311 This is just a utility wrapper around the methods of the InputHookManager
310 This is just a utility wrapper around the methods of the InputHookManager
312 object.
311 object.
313
312
314 Parameters
313 Parameters
315 ----------
314 ----------
316 gui : optional, string or None
315 gui : optional, string or None
317 If None, clears input hook, otherwise it must be one of the recognized
316 If None, clears input hook, otherwise it must be one of the recognized
318 GUI names (see ``GUI_*`` constants in module).
317 GUI names (see ``GUI_*`` constants in module).
319
318
320 app : optional, existing application object.
319 app : optional, existing application object.
321 For toolkits that have the concept of a global app, you can supply an
320 For toolkits that have the concept of a global app, you can supply an
322 existing one. If not given, the toolkit will be probed for one, and if
321 existing one. If not given, the toolkit will be probed for one, and if
323 none is found, a new one will be created. Note that GTK does not have
322 none is found, a new one will be created. Note that GTK does not have
324 this concept, and passing an app if `gui`=="GTK" will raise an error.
323 this concept, and passing an app if `gui`=="GTK" will raise an error.
325
324
326 Returns
325 Returns
327 -------
326 -------
328 The output of the underlying gui switch routine, typically the actual
327 The output of the underlying gui switch routine, typically the actual
329 PyOS_InputHook wrapper object or the GUI toolkit app created, if there was
328 PyOS_InputHook wrapper object or the GUI toolkit app created, if there was
330 one.
329 one.
331 """
330 """
332 guis = {None: clear_inputhook,
331 guis = {None: clear_inputhook,
333 GUI_OSX: lambda app=False: None,
332 GUI_OSX: lambda app=False: None,
334 GUI_TK: enable_tk,
333 GUI_TK: enable_tk,
335 GUI_GTK: enable_gtk,
334 GUI_GTK: enable_gtk,
336 GUI_WX: enable_wx,
335 GUI_WX: enable_wx,
337 GUI_QT: enable_qt4, # qt3 not supported
336 GUI_QT: enable_qt4, # qt3 not supported
338 GUI_QT4: enable_qt4 }
337 GUI_QT4: enable_qt4 }
339 try:
338 try:
340 gui_hook = guis[gui]
339 gui_hook = guis[gui]
341 except KeyError:
340 except KeyError:
342 e = "Invalid GUI request %r, valid ones are:%s" % (gui, guis.keys())
341 e = "Invalid GUI request %r, valid ones are:%s" % (gui, guis.keys())
343 raise ValueError(e)
342 raise ValueError(e)
344 return gui_hook(app)
343 return gui_hook(app)
345
344
@@ -1,36 +1,35 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Enable pygtk to be used interacive by setting PyOS_InputHook.
3 Enable pygtk to be used interacive by setting PyOS_InputHook.
5
4
6 Authors: Brian Granger
5 Authors: Brian Granger
7 """
6 """
8
7
9 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
10 # Copyright (C) 2008-2009 The IPython Development Team
9 # Copyright (C) 2008-2009 The IPython Development Team
11 #
10 #
12 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
15
14
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17 # Imports
16 # Imports
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19
18
20 import sys
19 import sys
21 import gtk, gobject
20 import gtk, gobject
22
21
23 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
24 # Code
23 # Code
25 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
26
25
27
26
28 def _main_quit(*args, **kwargs):
27 def _main_quit(*args, **kwargs):
29 gtk.main_quit()
28 gtk.main_quit()
30 return False
29 return False
31
30
32 def inputhook_gtk():
31 def inputhook_gtk():
33 gobject.io_add_watch(sys.stdin, gobject.IO_IN, _main_quit)
32 gobject.io_add_watch(sys.stdin, gobject.IO_IN, _main_quit)
34 gtk.main()
33 gtk.main()
35 return 0
34 return 0
36
35
@@ -1,179 +1,178 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3
2
4 """
3 """
5 Enable wxPython to be used interacive by setting PyOS_InputHook.
4 Enable wxPython to be used interacive by setting PyOS_InputHook.
6
5
7 Authors: Robin Dunn, Brian Granger, Ondrej Certik
6 Authors: Robin Dunn, Brian Granger, Ondrej Certik
8 """
7 """
9
8
10 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
11 # Copyright (C) 2008-2009 The IPython Development Team
10 # Copyright (C) 2008-2009 The IPython Development Team
12 #
11 #
13 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Imports
17 # Imports
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20
19
21 import os
20 import os
22 import signal
21 import signal
23 import sys
22 import sys
24 import time
23 import time
25 from timeit import default_timer as clock
24 from timeit import default_timer as clock
26 import wx
25 import wx
27
26
28 if os.name == 'posix':
27 if os.name == 'posix':
29 import select
28 import select
30 elif sys.platform == 'win32':
29 elif sys.platform == 'win32':
31 import msvcrt
30 import msvcrt
32
31
33 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
34 # Code
33 # Code
35 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
36
35
37 def stdin_ready():
36 def stdin_ready():
38 if os.name == 'posix':
37 if os.name == 'posix':
39 infds, outfds, erfds = select.select([sys.stdin],[],[],0)
38 infds, outfds, erfds = select.select([sys.stdin],[],[],0)
40 if infds:
39 if infds:
41 return True
40 return True
42 else:
41 else:
43 return False
42 return False
44 elif sys.platform == 'win32':
43 elif sys.platform == 'win32':
45 return msvcrt.kbhit()
44 return msvcrt.kbhit()
46
45
47
46
48 def inputhook_wx1():
47 def inputhook_wx1():
49 """Run the wx event loop by processing pending events only.
48 """Run the wx event loop by processing pending events only.
50
49
51 This approach seems to work, but its performance is not great as it
50 This approach seems to work, but its performance is not great as it
52 relies on having PyOS_InputHook called regularly.
51 relies on having PyOS_InputHook called regularly.
53 """
52 """
54 try:
53 try:
55 app = wx.GetApp()
54 app = wx.GetApp()
56 if app is not None:
55 if app is not None:
57 assert wx.Thread_IsMain()
56 assert wx.Thread_IsMain()
58
57
59 # Make a temporary event loop and process system events until
58 # Make a temporary event loop and process system events until
60 # there are no more waiting, then allow idle events (which
59 # there are no more waiting, then allow idle events (which
61 # will also deal with pending or posted wx events.)
60 # will also deal with pending or posted wx events.)
62 evtloop = wx.EventLoop()
61 evtloop = wx.EventLoop()
63 ea = wx.EventLoopActivator(evtloop)
62 ea = wx.EventLoopActivator(evtloop)
64 while evtloop.Pending():
63 while evtloop.Pending():
65 evtloop.Dispatch()
64 evtloop.Dispatch()
66 app.ProcessIdle()
65 app.ProcessIdle()
67 del ea
66 del ea
68 except KeyboardInterrupt:
67 except KeyboardInterrupt:
69 pass
68 pass
70 return 0
69 return 0
71
70
72 class EventLoopTimer(wx.Timer):
71 class EventLoopTimer(wx.Timer):
73
72
74 def __init__(self, func):
73 def __init__(self, func):
75 self.func = func
74 self.func = func
76 wx.Timer.__init__(self)
75 wx.Timer.__init__(self)
77
76
78 def Notify(self):
77 def Notify(self):
79 self.func()
78 self.func()
80
79
81 class EventLoopRunner(object):
80 class EventLoopRunner(object):
82
81
83 def Run(self, time):
82 def Run(self, time):
84 self.evtloop = wx.EventLoop()
83 self.evtloop = wx.EventLoop()
85 self.timer = EventLoopTimer(self.check_stdin)
84 self.timer = EventLoopTimer(self.check_stdin)
86 self.timer.Start(time)
85 self.timer.Start(time)
87 self.evtloop.Run()
86 self.evtloop.Run()
88
87
89 def check_stdin(self):
88 def check_stdin(self):
90 if stdin_ready():
89 if stdin_ready():
91 self.timer.Stop()
90 self.timer.Stop()
92 self.evtloop.Exit()
91 self.evtloop.Exit()
93
92
94 def inputhook_wx2():
93 def inputhook_wx2():
95 """Run the wx event loop, polling for stdin.
94 """Run the wx event loop, polling for stdin.
96
95
97 This version runs the wx eventloop for an undetermined amount of time,
96 This version runs the wx eventloop for an undetermined amount of time,
98 during which it periodically checks to see if anything is ready on
97 during which it periodically checks to see if anything is ready on
99 stdin. If anything is ready on stdin, the event loop exits.
98 stdin. If anything is ready on stdin, the event loop exits.
100
99
101 The argument to elr.Run controls how often the event loop looks at stdin.
100 The argument to elr.Run controls how often the event loop looks at stdin.
102 This determines the responsiveness at the keyboard. A setting of 1000
101 This determines the responsiveness at the keyboard. A setting of 1000
103 enables a user to type at most 1 char per second. I have found that a
102 enables a user to type at most 1 char per second. I have found that a
104 setting of 10 gives good keyboard response. We can shorten it further,
103 setting of 10 gives good keyboard response. We can shorten it further,
105 but eventually performance would suffer from calling select/kbhit too
104 but eventually performance would suffer from calling select/kbhit too
106 often.
105 often.
107 """
106 """
108 try:
107 try:
109 app = wx.GetApp()
108 app = wx.GetApp()
110 if app is not None:
109 if app is not None:
111 assert wx.Thread_IsMain()
110 assert wx.Thread_IsMain()
112 elr = EventLoopRunner()
111 elr = EventLoopRunner()
113 # As this time is made shorter, keyboard response improves, but idle
112 # As this time is made shorter, keyboard response improves, but idle
114 # CPU load goes up. 10 ms seems like a good compromise.
113 # CPU load goes up. 10 ms seems like a good compromise.
115 elr.Run(time=10) # CHANGE time here to control polling interval
114 elr.Run(time=10) # CHANGE time here to control polling interval
116 except KeyboardInterrupt:
115 except KeyboardInterrupt:
117 pass
116 pass
118 return 0
117 return 0
119
118
120 def inputhook_wx3():
119 def inputhook_wx3():
121 """Run the wx event loop by processing pending events only.
120 """Run the wx event loop by processing pending events only.
122
121
123 This is like inputhook_wx1, but it keeps processing pending events
122 This is like inputhook_wx1, but it keeps processing pending events
124 until stdin is ready. After processing all pending events, a call to
123 until stdin is ready. After processing all pending events, a call to
125 time.sleep is inserted. This is needed, otherwise, CPU usage is at 100%.
124 time.sleep is inserted. This is needed, otherwise, CPU usage is at 100%.
126 This sleep time should be tuned though for best performance.
125 This sleep time should be tuned though for best performance.
127 """
126 """
128 # We need to protect against a user pressing Control-C when IPython is
127 # We need to protect against a user pressing Control-C when IPython is
129 # idle and this is running. We trap KeyboardInterrupt and pass.
128 # idle and this is running. We trap KeyboardInterrupt and pass.
130 try:
129 try:
131 app = wx.GetApp()
130 app = wx.GetApp()
132 if app is not None:
131 if app is not None:
133 assert wx.Thread_IsMain()
132 assert wx.Thread_IsMain()
134
133
135 # The import of wx on Linux sets the handler for signal.SIGINT
134 # The import of wx on Linux sets the handler for signal.SIGINT
136 # to 0. This is a bug in wx or gtk. We fix by just setting it
135 # to 0. This is a bug in wx or gtk. We fix by just setting it
137 # back to the Python default.
136 # back to the Python default.
138 if not callable(signal.getsignal(signal.SIGINT)):
137 if not callable(signal.getsignal(signal.SIGINT)):
139 signal.signal(signal.SIGINT, signal.default_int_handler)
138 signal.signal(signal.SIGINT, signal.default_int_handler)
140
139
141 evtloop = wx.EventLoop()
140 evtloop = wx.EventLoop()
142 ea = wx.EventLoopActivator(evtloop)
141 ea = wx.EventLoopActivator(evtloop)
143 t = clock()
142 t = clock()
144 while not stdin_ready():
143 while not stdin_ready():
145 while evtloop.Pending():
144 while evtloop.Pending():
146 t = clock()
145 t = clock()
147 evtloop.Dispatch()
146 evtloop.Dispatch()
148 app.ProcessIdle()
147 app.ProcessIdle()
149 # We need to sleep at this point to keep the idle CPU load
148 # We need to sleep at this point to keep the idle CPU load
150 # low. However, if sleep to long, GUI response is poor. As
149 # low. However, if sleep to long, GUI response is poor. As
151 # a compromise, we watch how often GUI events are being processed
150 # a compromise, we watch how often GUI events are being processed
152 # and switch between a short and long sleep time. Here are some
151 # and switch between a short and long sleep time. Here are some
153 # stats useful in helping to tune this.
152 # stats useful in helping to tune this.
154 # time CPU load
153 # time CPU load
155 # 0.001 13%
154 # 0.001 13%
156 # 0.005 3%
155 # 0.005 3%
157 # 0.01 1.5%
156 # 0.01 1.5%
158 # 0.05 0.5%
157 # 0.05 0.5%
159 used_time = clock() - t
158 used_time = clock() - t
160 if used_time > 5*60.0:
159 if used_time > 5*60.0:
161 # print 'Sleep for 5 s' # dbg
160 # print 'Sleep for 5 s' # dbg
162 time.sleep(5.0)
161 time.sleep(5.0)
163 elif used_time > 10.0:
162 elif used_time > 10.0:
164 # print 'Sleep for 1 s' # dbg
163 # print 'Sleep for 1 s' # dbg
165 time.sleep(1.0)
164 time.sleep(1.0)
166 elif used_time > 0.1:
165 elif used_time > 0.1:
167 # Few GUI events coming in, so we can sleep longer
166 # Few GUI events coming in, so we can sleep longer
168 # print 'Sleep for 0.05 s' # dbg
167 # print 'Sleep for 0.05 s' # dbg
169 time.sleep(0.05)
168 time.sleep(0.05)
170 else:
169 else:
171 # Many GUI events coming in, so sleep only very little
170 # Many GUI events coming in, so sleep only very little
172 time.sleep(0.001)
171 time.sleep(0.001)
173 del ea
172 del ea
174 except KeyboardInterrupt:
173 except KeyboardInterrupt:
175 pass
174 pass
176 return 0
175 return 0
177
176
178 # This is our default implementation
177 # This is our default implementation
179 inputhook_wx = inputhook_wx3
178 inputhook_wx = inputhook_wx3
@@ -1,14 +1,13 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3
2
4 def test_import_backgroundjobs():
3 def test_import_backgroundjobs():
5 from IPython.lib import backgroundjobs
4 from IPython.lib import backgroundjobs
6
5
7 def test_import_deepreload():
6 def test_import_deepreload():
8 from IPython.lib import deepreload
7 from IPython.lib import deepreload
9
8
10 def test_import_demo():
9 def test_import_demo():
11 from IPython.lib import demo
10 from IPython.lib import demo
12
11
13 def test_import_irunner():
12 def test_import_irunner():
14 from IPython.lib import demo
13 from IPython.lib import demo
@@ -1,241 +1,240 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 The Base Application class for IPython.parallel apps
3 The Base Application class for IPython.parallel apps
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Min RK
8 * Min RK
10
9
11 """
10 """
12
11
13 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
15 #
14 #
16 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19
18
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21 # Imports
20 # Imports
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23
22
24 from __future__ import with_statement
23 from __future__ import with_statement
25
24
26 import os
25 import os
27 import logging
26 import logging
28 import re
27 import re
29 import sys
28 import sys
30
29
31 from subprocess import Popen, PIPE
30 from subprocess import Popen, PIPE
32
31
33 from IPython.core import release
32 from IPython.core import release
34 from IPython.core.crashhandler import CrashHandler
33 from IPython.core.crashhandler import CrashHandler
35 from IPython.core.application import (
34 from IPython.core.application import (
36 BaseIPythonApplication,
35 BaseIPythonApplication,
37 base_aliases as base_ip_aliases,
36 base_aliases as base_ip_aliases,
38 base_flags as base_ip_flags
37 base_flags as base_ip_flags
39 )
38 )
40 from IPython.utils.path import expand_path
39 from IPython.utils.path import expand_path
41
40
42 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
41 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
43
42
44 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
45 # Module errors
44 # Module errors
46 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
47
46
48 class PIDFileError(Exception):
47 class PIDFileError(Exception):
49 pass
48 pass
50
49
51
50
52 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
53 # Crash handler for this application
52 # Crash handler for this application
54 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
55
54
56 class ParallelCrashHandler(CrashHandler):
55 class ParallelCrashHandler(CrashHandler):
57 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
56 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
58
57
59 def __init__(self, app):
58 def __init__(self, app):
60 contact_name = release.authors['Min'][0]
59 contact_name = release.authors['Min'][0]
61 contact_email = release.authors['Min'][1]
60 contact_email = release.authors['Min'][1]
62 bug_tracker = 'http://github.com/ipython/ipython/issues'
61 bug_tracker = 'http://github.com/ipython/ipython/issues'
63 super(ParallelCrashHandler,self).__init__(
62 super(ParallelCrashHandler,self).__init__(
64 app, contact_name, contact_email, bug_tracker
63 app, contact_name, contact_email, bug_tracker
65 )
64 )
66
65
67
66
68 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
69 # Main application
68 # Main application
70 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
71 base_aliases = {}
70 base_aliases = {}
72 base_aliases.update(base_ip_aliases)
71 base_aliases.update(base_ip_aliases)
73 base_aliases.update({
72 base_aliases.update({
74 'profile-dir' : 'ProfileDir.location',
73 'profile-dir' : 'ProfileDir.location',
75 'work-dir' : 'BaseParallelApplication.work_dir',
74 'work-dir' : 'BaseParallelApplication.work_dir',
76 'log-to-file' : 'BaseParallelApplication.log_to_file',
75 'log-to-file' : 'BaseParallelApplication.log_to_file',
77 'clean-logs' : 'BaseParallelApplication.clean_logs',
76 'clean-logs' : 'BaseParallelApplication.clean_logs',
78 'log-url' : 'BaseParallelApplication.log_url',
77 'log-url' : 'BaseParallelApplication.log_url',
79 })
78 })
80
79
81 base_flags = {
80 base_flags = {
82 'log-to-file' : (
81 'log-to-file' : (
83 {'BaseParallelApplication' : {'log_to_file' : True}},
82 {'BaseParallelApplication' : {'log_to_file' : True}},
84 "send log output to a file"
83 "send log output to a file"
85 )
84 )
86 }
85 }
87 base_flags.update(base_ip_flags)
86 base_flags.update(base_ip_flags)
88
87
89 class BaseParallelApplication(BaseIPythonApplication):
88 class BaseParallelApplication(BaseIPythonApplication):
90 """The base Application for IPython.parallel apps
89 """The base Application for IPython.parallel apps
91
90
92 Principle extensions to BaseIPyythonApplication:
91 Principle extensions to BaseIPyythonApplication:
93
92
94 * work_dir
93 * work_dir
95 * remote logging via pyzmq
94 * remote logging via pyzmq
96 * IOLoop instance
95 * IOLoop instance
97 """
96 """
98
97
99 crash_handler_class = ParallelCrashHandler
98 crash_handler_class = ParallelCrashHandler
100
99
101 def _log_level_default(self):
100 def _log_level_default(self):
102 # temporarily override default_log_level to INFO
101 # temporarily override default_log_level to INFO
103 return logging.INFO
102 return logging.INFO
104
103
105 work_dir = Unicode(os.getcwdu(), config=True,
104 work_dir = Unicode(os.getcwdu(), config=True,
106 help='Set the working dir for the process.'
105 help='Set the working dir for the process.'
107 )
106 )
108 def _work_dir_changed(self, name, old, new):
107 def _work_dir_changed(self, name, old, new):
109 self.work_dir = unicode(expand_path(new))
108 self.work_dir = unicode(expand_path(new))
110
109
111 log_to_file = Bool(config=True,
110 log_to_file = Bool(config=True,
112 help="whether to log to a file")
111 help="whether to log to a file")
113
112
114 clean_logs = Bool(False, config=True,
113 clean_logs = Bool(False, config=True,
115 help="whether to cleanup old logfiles before starting")
114 help="whether to cleanup old logfiles before starting")
116
115
117 log_url = Unicode('', config=True,
116 log_url = Unicode('', config=True,
118 help="The ZMQ URL of the iplogger to aggregate logging.")
117 help="The ZMQ URL of the iplogger to aggregate logging.")
119
118
120 def _config_files_default(self):
119 def _config_files_default(self):
121 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
120 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
122
121
123 loop = Instance('zmq.eventloop.ioloop.IOLoop')
122 loop = Instance('zmq.eventloop.ioloop.IOLoop')
124 def _loop_default(self):
123 def _loop_default(self):
125 from zmq.eventloop.ioloop import IOLoop
124 from zmq.eventloop.ioloop import IOLoop
126 return IOLoop.instance()
125 return IOLoop.instance()
127
126
128 aliases = Dict(base_aliases)
127 aliases = Dict(base_aliases)
129 flags = Dict(base_flags)
128 flags = Dict(base_flags)
130
129
131 def initialize(self, argv=None):
130 def initialize(self, argv=None):
132 """initialize the app"""
131 """initialize the app"""
133 super(BaseParallelApplication, self).initialize(argv)
132 super(BaseParallelApplication, self).initialize(argv)
134 self.to_work_dir()
133 self.to_work_dir()
135 self.reinit_logging()
134 self.reinit_logging()
136
135
137 def to_work_dir(self):
136 def to_work_dir(self):
138 wd = self.work_dir
137 wd = self.work_dir
139 if unicode(wd) != os.getcwdu():
138 if unicode(wd) != os.getcwdu():
140 os.chdir(wd)
139 os.chdir(wd)
141 self.log.info("Changing to working dir: %s" % wd)
140 self.log.info("Changing to working dir: %s" % wd)
142 # This is the working dir by now.
141 # This is the working dir by now.
143 sys.path.insert(0, '')
142 sys.path.insert(0, '')
144
143
145 def reinit_logging(self):
144 def reinit_logging(self):
146 # Remove old log files
145 # Remove old log files
147 log_dir = self.profile_dir.log_dir
146 log_dir = self.profile_dir.log_dir
148 if self.clean_logs:
147 if self.clean_logs:
149 for f in os.listdir(log_dir):
148 for f in os.listdir(log_dir):
150 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
149 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
151 os.remove(os.path.join(log_dir, f))
150 os.remove(os.path.join(log_dir, f))
152 if self.log_to_file:
151 if self.log_to_file:
153 # Start logging to the new log file
152 # Start logging to the new log file
154 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
153 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
155 logfile = os.path.join(log_dir, log_filename)
154 logfile = os.path.join(log_dir, log_filename)
156 open_log_file = open(logfile, 'w')
155 open_log_file = open(logfile, 'w')
157 else:
156 else:
158 open_log_file = None
157 open_log_file = None
159 if open_log_file is not None:
158 if open_log_file is not None:
160 self.log.removeHandler(self._log_handler)
159 self.log.removeHandler(self._log_handler)
161 self._log_handler = logging.StreamHandler(open_log_file)
160 self._log_handler = logging.StreamHandler(open_log_file)
162 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
161 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
163 self._log_handler.setFormatter(self._log_formatter)
162 self._log_handler.setFormatter(self._log_formatter)
164 self.log.addHandler(self._log_handler)
163 self.log.addHandler(self._log_handler)
165
164
166 def write_pid_file(self, overwrite=False):
165 def write_pid_file(self, overwrite=False):
167 """Create a .pid file in the pid_dir with my pid.
166 """Create a .pid file in the pid_dir with my pid.
168
167
169 This must be called after pre_construct, which sets `self.pid_dir`.
168 This must be called after pre_construct, which sets `self.pid_dir`.
170 This raises :exc:`PIDFileError` if the pid file exists already.
169 This raises :exc:`PIDFileError` if the pid file exists already.
171 """
170 """
172 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
171 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
173 if os.path.isfile(pid_file):
172 if os.path.isfile(pid_file):
174 pid = self.get_pid_from_file()
173 pid = self.get_pid_from_file()
175 if not overwrite:
174 if not overwrite:
176 raise PIDFileError(
175 raise PIDFileError(
177 'The pid file [%s] already exists. \nThis could mean that this '
176 'The pid file [%s] already exists. \nThis could mean that this '
178 'server is already running with [pid=%s].' % (pid_file, pid)
177 'server is already running with [pid=%s].' % (pid_file, pid)
179 )
178 )
180 with open(pid_file, 'w') as f:
179 with open(pid_file, 'w') as f:
181 self.log.info("Creating pid file: %s" % pid_file)
180 self.log.info("Creating pid file: %s" % pid_file)
182 f.write(repr(os.getpid())+'\n')
181 f.write(repr(os.getpid())+'\n')
183
182
184 def remove_pid_file(self):
183 def remove_pid_file(self):
185 """Remove the pid file.
184 """Remove the pid file.
186
185
187 This should be called at shutdown by registering a callback with
186 This should be called at shutdown by registering a callback with
188 :func:`reactor.addSystemEventTrigger`. This needs to return
187 :func:`reactor.addSystemEventTrigger`. This needs to return
189 ``None``.
188 ``None``.
190 """
189 """
191 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
190 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
192 if os.path.isfile(pid_file):
191 if os.path.isfile(pid_file):
193 try:
192 try:
194 self.log.info("Removing pid file: %s" % pid_file)
193 self.log.info("Removing pid file: %s" % pid_file)
195 os.remove(pid_file)
194 os.remove(pid_file)
196 except:
195 except:
197 self.log.warn("Error removing the pid file: %s" % pid_file)
196 self.log.warn("Error removing the pid file: %s" % pid_file)
198
197
199 def get_pid_from_file(self):
198 def get_pid_from_file(self):
200 """Get the pid from the pid file.
199 """Get the pid from the pid file.
201
200
202 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
201 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
203 """
202 """
204 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
203 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
205 if os.path.isfile(pid_file):
204 if os.path.isfile(pid_file):
206 with open(pid_file, 'r') as f:
205 with open(pid_file, 'r') as f:
207 s = f.read().strip()
206 s = f.read().strip()
208 try:
207 try:
209 pid = int(s)
208 pid = int(s)
210 except:
209 except:
211 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
210 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
212 return pid
211 return pid
213 else:
212 else:
214 raise PIDFileError('pid file not found: %s' % pid_file)
213 raise PIDFileError('pid file not found: %s' % pid_file)
215
214
216 def check_pid(self, pid):
215 def check_pid(self, pid):
217 if os.name == 'nt':
216 if os.name == 'nt':
218 try:
217 try:
219 import ctypes
218 import ctypes
220 # returns 0 if no such process (of ours) exists
219 # returns 0 if no such process (of ours) exists
221 # positive int otherwise
220 # positive int otherwise
222 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
221 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
223 except Exception:
222 except Exception:
224 self.log.warn(
223 self.log.warn(
225 "Could not determine whether pid %i is running via `OpenProcess`. "
224 "Could not determine whether pid %i is running via `OpenProcess`. "
226 " Making the likely assumption that it is."%pid
225 " Making the likely assumption that it is."%pid
227 )
226 )
228 return True
227 return True
229 return bool(p)
228 return bool(p)
230 else:
229 else:
231 try:
230 try:
232 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
231 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
233 output,_ = p.communicate()
232 output,_ = p.communicate()
234 except OSError:
233 except OSError:
235 self.log.warn(
234 self.log.warn(
236 "Could not determine whether pid %i is running via `ps x`. "
235 "Could not determine whether pid %i is running via `ps x`. "
237 " Making the likely assumption that it is."%pid
236 " Making the likely assumption that it is."%pid
238 )
237 )
239 return True
238 return True
240 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
239 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
241 return pid in pids
240 return pid in pids
@@ -1,1142 +1,1141 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Facilities for launching IPython processes asynchronously.
3 Facilities for launching IPython processes asynchronously.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * MinRK
8 * MinRK
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 import copy
22 import copy
24 import logging
23 import logging
25 import os
24 import os
26 import re
25 import re
27 import stat
26 import stat
28
27
29 # signal imports, handling various platforms, versions
28 # signal imports, handling various platforms, versions
30
29
31 from signal import SIGINT, SIGTERM
30 from signal import SIGINT, SIGTERM
32 try:
31 try:
33 from signal import SIGKILL
32 from signal import SIGKILL
34 except ImportError:
33 except ImportError:
35 # Windows
34 # Windows
36 SIGKILL=SIGTERM
35 SIGKILL=SIGTERM
37
36
38 try:
37 try:
39 # Windows >= 2.7, 3.2
38 # Windows >= 2.7, 3.2
40 from signal import CTRL_C_EVENT as SIGINT
39 from signal import CTRL_C_EVENT as SIGINT
41 except ImportError:
40 except ImportError:
42 pass
41 pass
43
42
44 from subprocess import Popen, PIPE, STDOUT
43 from subprocess import Popen, PIPE, STDOUT
45 try:
44 try:
46 from subprocess import check_output
45 from subprocess import check_output
47 except ImportError:
46 except ImportError:
48 # pre-2.7, define check_output with Popen
47 # pre-2.7, define check_output with Popen
49 def check_output(*args, **kwargs):
48 def check_output(*args, **kwargs):
50 kwargs.update(dict(stdout=PIPE))
49 kwargs.update(dict(stdout=PIPE))
51 p = Popen(*args, **kwargs)
50 p = Popen(*args, **kwargs)
52 out,err = p.communicate()
51 out,err = p.communicate()
53 return out
52 return out
54
53
55 from zmq.eventloop import ioloop
54 from zmq.eventloop import ioloop
56
55
57 from IPython.config.application import Application
56 from IPython.config.application import Application
58 from IPython.config.configurable import LoggingConfigurable
57 from IPython.config.configurable import LoggingConfigurable
59 from IPython.utils.text import EvalFormatter
58 from IPython.utils.text import EvalFormatter
60 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
59 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
61 from IPython.utils.path import get_ipython_module_path
60 from IPython.utils.path import get_ipython_module_path
62 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
61 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
63
62
64 from .win32support import forward_read_events
63 from .win32support import forward_read_events
65
64
66 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
65 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
67
66
68 WINDOWS = os.name == 'nt'
67 WINDOWS = os.name == 'nt'
69
68
70 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
71 # Paths to the kernel apps
70 # Paths to the kernel apps
72 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
73
72
74
73
75 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
74 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
76 'IPython.parallel.apps.ipclusterapp'
75 'IPython.parallel.apps.ipclusterapp'
77 ))
76 ))
78
77
79 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
78 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
80 'IPython.parallel.apps.ipengineapp'
79 'IPython.parallel.apps.ipengineapp'
81 ))
80 ))
82
81
83 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
82 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
84 'IPython.parallel.apps.ipcontrollerapp'
83 'IPython.parallel.apps.ipcontrollerapp'
85 ))
84 ))
86
85
87 #-----------------------------------------------------------------------------
86 #-----------------------------------------------------------------------------
88 # Base launchers and errors
87 # Base launchers and errors
89 #-----------------------------------------------------------------------------
88 #-----------------------------------------------------------------------------
90
89
91
90
92 class LauncherError(Exception):
91 class LauncherError(Exception):
93 pass
92 pass
94
93
95
94
96 class ProcessStateError(LauncherError):
95 class ProcessStateError(LauncherError):
97 pass
96 pass
98
97
99
98
100 class UnknownStatus(LauncherError):
99 class UnknownStatus(LauncherError):
101 pass
100 pass
102
101
103
102
104 class BaseLauncher(LoggingConfigurable):
103 class BaseLauncher(LoggingConfigurable):
105 """An asbtraction for starting, stopping and signaling a process."""
104 """An asbtraction for starting, stopping and signaling a process."""
106
105
107 # In all of the launchers, the work_dir is where child processes will be
106 # In all of the launchers, the work_dir is where child processes will be
108 # run. This will usually be the profile_dir, but may not be. any work_dir
107 # run. This will usually be the profile_dir, but may not be. any work_dir
109 # passed into the __init__ method will override the config value.
108 # passed into the __init__ method will override the config value.
110 # This should not be used to set the work_dir for the actual engine
109 # This should not be used to set the work_dir for the actual engine
111 # and controller. Instead, use their own config files or the
110 # and controller. Instead, use their own config files or the
112 # controller_args, engine_args attributes of the launchers to add
111 # controller_args, engine_args attributes of the launchers to add
113 # the work_dir option.
112 # the work_dir option.
114 work_dir = Unicode(u'.')
113 work_dir = Unicode(u'.')
115 loop = Instance('zmq.eventloop.ioloop.IOLoop')
114 loop = Instance('zmq.eventloop.ioloop.IOLoop')
116
115
117 start_data = Any()
116 start_data = Any()
118 stop_data = Any()
117 stop_data = Any()
119
118
120 def _loop_default(self):
119 def _loop_default(self):
121 return ioloop.IOLoop.instance()
120 return ioloop.IOLoop.instance()
122
121
123 def __init__(self, work_dir=u'.', config=None, **kwargs):
122 def __init__(self, work_dir=u'.', config=None, **kwargs):
124 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
123 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
125 self.state = 'before' # can be before, running, after
124 self.state = 'before' # can be before, running, after
126 self.stop_callbacks = []
125 self.stop_callbacks = []
127 self.start_data = None
126 self.start_data = None
128 self.stop_data = None
127 self.stop_data = None
129
128
130 @property
129 @property
131 def args(self):
130 def args(self):
132 """A list of cmd and args that will be used to start the process.
131 """A list of cmd and args that will be used to start the process.
133
132
134 This is what is passed to :func:`spawnProcess` and the first element
133 This is what is passed to :func:`spawnProcess` and the first element
135 will be the process name.
134 will be the process name.
136 """
135 """
137 return self.find_args()
136 return self.find_args()
138
137
139 def find_args(self):
138 def find_args(self):
140 """The ``.args`` property calls this to find the args list.
139 """The ``.args`` property calls this to find the args list.
141
140
142 Subcommand should implement this to construct the cmd and args.
141 Subcommand should implement this to construct the cmd and args.
143 """
142 """
144 raise NotImplementedError('find_args must be implemented in a subclass')
143 raise NotImplementedError('find_args must be implemented in a subclass')
145
144
146 @property
145 @property
147 def arg_str(self):
146 def arg_str(self):
148 """The string form of the program arguments."""
147 """The string form of the program arguments."""
149 return ' '.join(self.args)
148 return ' '.join(self.args)
150
149
151 @property
150 @property
152 def running(self):
151 def running(self):
153 """Am I running."""
152 """Am I running."""
154 if self.state == 'running':
153 if self.state == 'running':
155 return True
154 return True
156 else:
155 else:
157 return False
156 return False
158
157
159 def start(self):
158 def start(self):
160 """Start the process."""
159 """Start the process."""
161 raise NotImplementedError('start must be implemented in a subclass')
160 raise NotImplementedError('start must be implemented in a subclass')
162
161
163 def stop(self):
162 def stop(self):
164 """Stop the process and notify observers of stopping.
163 """Stop the process and notify observers of stopping.
165
164
166 This method will return None immediately.
165 This method will return None immediately.
167 To observe the actual process stopping, see :meth:`on_stop`.
166 To observe the actual process stopping, see :meth:`on_stop`.
168 """
167 """
169 raise NotImplementedError('stop must be implemented in a subclass')
168 raise NotImplementedError('stop must be implemented in a subclass')
170
169
171 def on_stop(self, f):
170 def on_stop(self, f):
172 """Register a callback to be called with this Launcher's stop_data
171 """Register a callback to be called with this Launcher's stop_data
173 when the process actually finishes.
172 when the process actually finishes.
174 """
173 """
175 if self.state=='after':
174 if self.state=='after':
176 return f(self.stop_data)
175 return f(self.stop_data)
177 else:
176 else:
178 self.stop_callbacks.append(f)
177 self.stop_callbacks.append(f)
179
178
180 def notify_start(self, data):
179 def notify_start(self, data):
181 """Call this to trigger startup actions.
180 """Call this to trigger startup actions.
182
181
183 This logs the process startup and sets the state to 'running'. It is
182 This logs the process startup and sets the state to 'running'. It is
184 a pass-through so it can be used as a callback.
183 a pass-through so it can be used as a callback.
185 """
184 """
186
185
187 self.log.info('Process %r started: %r' % (self.args[0], data))
186 self.log.info('Process %r started: %r' % (self.args[0], data))
188 self.start_data = data
187 self.start_data = data
189 self.state = 'running'
188 self.state = 'running'
190 return data
189 return data
191
190
192 def notify_stop(self, data):
191 def notify_stop(self, data):
193 """Call this to trigger process stop actions.
192 """Call this to trigger process stop actions.
194
193
195 This logs the process stopping and sets the state to 'after'. Call
194 This logs the process stopping and sets the state to 'after'. Call
196 this to trigger callbacks registered via :meth:`on_stop`."""
195 this to trigger callbacks registered via :meth:`on_stop`."""
197
196
198 self.log.info('Process %r stopped: %r' % (self.args[0], data))
197 self.log.info('Process %r stopped: %r' % (self.args[0], data))
199 self.stop_data = data
198 self.stop_data = data
200 self.state = 'after'
199 self.state = 'after'
201 for i in range(len(self.stop_callbacks)):
200 for i in range(len(self.stop_callbacks)):
202 d = self.stop_callbacks.pop()
201 d = self.stop_callbacks.pop()
203 d(data)
202 d(data)
204 return data
203 return data
205
204
206 def signal(self, sig):
205 def signal(self, sig):
207 """Signal the process.
206 """Signal the process.
208
207
209 Parameters
208 Parameters
210 ----------
209 ----------
211 sig : str or int
210 sig : str or int
212 'KILL', 'INT', etc., or any signal number
211 'KILL', 'INT', etc., or any signal number
213 """
212 """
214 raise NotImplementedError('signal must be implemented in a subclass')
213 raise NotImplementedError('signal must be implemented in a subclass')
215
214
216
215
217 #-----------------------------------------------------------------------------
216 #-----------------------------------------------------------------------------
218 # Local process launchers
217 # Local process launchers
219 #-----------------------------------------------------------------------------
218 #-----------------------------------------------------------------------------
220
219
221
220
222 class LocalProcessLauncher(BaseLauncher):
221 class LocalProcessLauncher(BaseLauncher):
223 """Start and stop an external process in an asynchronous manner.
222 """Start and stop an external process in an asynchronous manner.
224
223
225 This will launch the external process with a working directory of
224 This will launch the external process with a working directory of
226 ``self.work_dir``.
225 ``self.work_dir``.
227 """
226 """
228
227
229 # This is used to to construct self.args, which is passed to
228 # This is used to to construct self.args, which is passed to
230 # spawnProcess.
229 # spawnProcess.
231 cmd_and_args = List([])
230 cmd_and_args = List([])
232 poll_frequency = Int(100) # in ms
231 poll_frequency = Int(100) # in ms
233
232
234 def __init__(self, work_dir=u'.', config=None, **kwargs):
233 def __init__(self, work_dir=u'.', config=None, **kwargs):
235 super(LocalProcessLauncher, self).__init__(
234 super(LocalProcessLauncher, self).__init__(
236 work_dir=work_dir, config=config, **kwargs
235 work_dir=work_dir, config=config, **kwargs
237 )
236 )
238 self.process = None
237 self.process = None
239 self.poller = None
238 self.poller = None
240
239
241 def find_args(self):
240 def find_args(self):
242 return self.cmd_and_args
241 return self.cmd_and_args
243
242
244 def start(self):
243 def start(self):
245 if self.state == 'before':
244 if self.state == 'before':
246 self.process = Popen(self.args,
245 self.process = Popen(self.args,
247 stdout=PIPE,stderr=PIPE,stdin=PIPE,
246 stdout=PIPE,stderr=PIPE,stdin=PIPE,
248 env=os.environ,
247 env=os.environ,
249 cwd=self.work_dir
248 cwd=self.work_dir
250 )
249 )
251 if WINDOWS:
250 if WINDOWS:
252 self.stdout = forward_read_events(self.process.stdout)
251 self.stdout = forward_read_events(self.process.stdout)
253 self.stderr = forward_read_events(self.process.stderr)
252 self.stderr = forward_read_events(self.process.stderr)
254 else:
253 else:
255 self.stdout = self.process.stdout.fileno()
254 self.stdout = self.process.stdout.fileno()
256 self.stderr = self.process.stderr.fileno()
255 self.stderr = self.process.stderr.fileno()
257 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
256 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
258 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
257 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
259 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
258 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
260 self.poller.start()
259 self.poller.start()
261 self.notify_start(self.process.pid)
260 self.notify_start(self.process.pid)
262 else:
261 else:
263 s = 'The process was already started and has state: %r' % self.state
262 s = 'The process was already started and has state: %r' % self.state
264 raise ProcessStateError(s)
263 raise ProcessStateError(s)
265
264
266 def stop(self):
265 def stop(self):
267 return self.interrupt_then_kill()
266 return self.interrupt_then_kill()
268
267
269 def signal(self, sig):
268 def signal(self, sig):
270 if self.state == 'running':
269 if self.state == 'running':
271 if WINDOWS and sig != SIGINT:
270 if WINDOWS and sig != SIGINT:
272 # use Windows tree-kill for better child cleanup
271 # use Windows tree-kill for better child cleanup
273 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
272 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
274 else:
273 else:
275 self.process.send_signal(sig)
274 self.process.send_signal(sig)
276
275
277 def interrupt_then_kill(self, delay=2.0):
276 def interrupt_then_kill(self, delay=2.0):
278 """Send INT, wait a delay and then send KILL."""
277 """Send INT, wait a delay and then send KILL."""
279 try:
278 try:
280 self.signal(SIGINT)
279 self.signal(SIGINT)
281 except Exception:
280 except Exception:
282 self.log.debug("interrupt failed")
281 self.log.debug("interrupt failed")
283 pass
282 pass
284 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
283 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
285 self.killer.start()
284 self.killer.start()
286
285
287 # callbacks, etc:
286 # callbacks, etc:
288
287
289 def handle_stdout(self, fd, events):
288 def handle_stdout(self, fd, events):
290 if WINDOWS:
289 if WINDOWS:
291 line = self.stdout.recv()
290 line = self.stdout.recv()
292 else:
291 else:
293 line = self.process.stdout.readline()
292 line = self.process.stdout.readline()
294 # a stopped process will be readable but return empty strings
293 # a stopped process will be readable but return empty strings
295 if line:
294 if line:
296 self.log.info(line[:-1])
295 self.log.info(line[:-1])
297 else:
296 else:
298 self.poll()
297 self.poll()
299
298
300 def handle_stderr(self, fd, events):
299 def handle_stderr(self, fd, events):
301 if WINDOWS:
300 if WINDOWS:
302 line = self.stderr.recv()
301 line = self.stderr.recv()
303 else:
302 else:
304 line = self.process.stderr.readline()
303 line = self.process.stderr.readline()
305 # a stopped process will be readable but return empty strings
304 # a stopped process will be readable but return empty strings
306 if line:
305 if line:
307 self.log.error(line[:-1])
306 self.log.error(line[:-1])
308 else:
307 else:
309 self.poll()
308 self.poll()
310
309
311 def poll(self):
310 def poll(self):
312 status = self.process.poll()
311 status = self.process.poll()
313 if status is not None:
312 if status is not None:
314 self.poller.stop()
313 self.poller.stop()
315 self.loop.remove_handler(self.stdout)
314 self.loop.remove_handler(self.stdout)
316 self.loop.remove_handler(self.stderr)
315 self.loop.remove_handler(self.stderr)
317 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
316 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
318 return status
317 return status
319
318
320 class LocalControllerLauncher(LocalProcessLauncher):
319 class LocalControllerLauncher(LocalProcessLauncher):
321 """Launch a controller as a regular external process."""
320 """Launch a controller as a regular external process."""
322
321
323 controller_cmd = List(ipcontroller_cmd_argv, config=True,
322 controller_cmd = List(ipcontroller_cmd_argv, config=True,
324 help="""Popen command to launch ipcontroller.""")
323 help="""Popen command to launch ipcontroller.""")
325 # Command line arguments to ipcontroller.
324 # Command line arguments to ipcontroller.
326 controller_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
325 controller_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
327 help="""command-line args to pass to ipcontroller""")
326 help="""command-line args to pass to ipcontroller""")
328
327
329 def find_args(self):
328 def find_args(self):
330 return self.controller_cmd + self.controller_args
329 return self.controller_cmd + self.controller_args
331
330
332 def start(self, profile_dir):
331 def start(self, profile_dir):
333 """Start the controller by profile_dir."""
332 """Start the controller by profile_dir."""
334 self.controller_args.extend(['--profile-dir=%s'%profile_dir])
333 self.controller_args.extend(['--profile-dir=%s'%profile_dir])
335 self.profile_dir = unicode(profile_dir)
334 self.profile_dir = unicode(profile_dir)
336 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
335 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
337 return super(LocalControllerLauncher, self).start()
336 return super(LocalControllerLauncher, self).start()
338
337
339
338
340 class LocalEngineLauncher(LocalProcessLauncher):
339 class LocalEngineLauncher(LocalProcessLauncher):
341 """Launch a single engine as a regular externall process."""
340 """Launch a single engine as a regular externall process."""
342
341
343 engine_cmd = List(ipengine_cmd_argv, config=True,
342 engine_cmd = List(ipengine_cmd_argv, config=True,
344 help="""command to launch the Engine.""")
343 help="""command to launch the Engine.""")
345 # Command line arguments for ipengine.
344 # Command line arguments for ipengine.
346 engine_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
345 engine_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
347 help="command-line arguments to pass to ipengine"
346 help="command-line arguments to pass to ipengine"
348 )
347 )
349
348
350 def find_args(self):
349 def find_args(self):
351 return self.engine_cmd + self.engine_args
350 return self.engine_cmd + self.engine_args
352
351
353 def start(self, profile_dir):
352 def start(self, profile_dir):
354 """Start the engine by profile_dir."""
353 """Start the engine by profile_dir."""
355 self.engine_args.extend(['--profile-dir=%s'%profile_dir])
354 self.engine_args.extend(['--profile-dir=%s'%profile_dir])
356 self.profile_dir = unicode(profile_dir)
355 self.profile_dir = unicode(profile_dir)
357 return super(LocalEngineLauncher, self).start()
356 return super(LocalEngineLauncher, self).start()
358
357
359
358
360 class LocalEngineSetLauncher(BaseLauncher):
359 class LocalEngineSetLauncher(BaseLauncher):
361 """Launch a set of engines as regular external processes."""
360 """Launch a set of engines as regular external processes."""
362
361
363 # Command line arguments for ipengine.
362 # Command line arguments for ipengine.
364 engine_args = List(
363 engine_args = List(
365 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
364 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
366 help="command-line arguments to pass to ipengine"
365 help="command-line arguments to pass to ipengine"
367 )
366 )
368 # launcher class
367 # launcher class
369 launcher_class = LocalEngineLauncher
368 launcher_class = LocalEngineLauncher
370
369
371 launchers = Dict()
370 launchers = Dict()
372 stop_data = Dict()
371 stop_data = Dict()
373
372
374 def __init__(self, work_dir=u'.', config=None, **kwargs):
373 def __init__(self, work_dir=u'.', config=None, **kwargs):
375 super(LocalEngineSetLauncher, self).__init__(
374 super(LocalEngineSetLauncher, self).__init__(
376 work_dir=work_dir, config=config, **kwargs
375 work_dir=work_dir, config=config, **kwargs
377 )
376 )
378 self.stop_data = {}
377 self.stop_data = {}
379
378
380 def start(self, n, profile_dir):
379 def start(self, n, profile_dir):
381 """Start n engines by profile or profile_dir."""
380 """Start n engines by profile or profile_dir."""
382 self.profile_dir = unicode(profile_dir)
381 self.profile_dir = unicode(profile_dir)
383 dlist = []
382 dlist = []
384 for i in range(n):
383 for i in range(n):
385 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
384 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
386 # Copy the engine args over to each engine launcher.
385 # Copy the engine args over to each engine launcher.
387 el.engine_args = copy.deepcopy(self.engine_args)
386 el.engine_args = copy.deepcopy(self.engine_args)
388 el.on_stop(self._notice_engine_stopped)
387 el.on_stop(self._notice_engine_stopped)
389 d = el.start(profile_dir)
388 d = el.start(profile_dir)
390 if i==0:
389 if i==0:
391 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
390 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
392 self.launchers[i] = el
391 self.launchers[i] = el
393 dlist.append(d)
392 dlist.append(d)
394 self.notify_start(dlist)
393 self.notify_start(dlist)
395 # The consumeErrors here could be dangerous
394 # The consumeErrors here could be dangerous
396 # dfinal = gatherBoth(dlist, consumeErrors=True)
395 # dfinal = gatherBoth(dlist, consumeErrors=True)
397 # dfinal.addCallback(self.notify_start)
396 # dfinal.addCallback(self.notify_start)
398 return dlist
397 return dlist
399
398
400 def find_args(self):
399 def find_args(self):
401 return ['engine set']
400 return ['engine set']
402
401
403 def signal(self, sig):
402 def signal(self, sig):
404 dlist = []
403 dlist = []
405 for el in self.launchers.itervalues():
404 for el in self.launchers.itervalues():
406 d = el.signal(sig)
405 d = el.signal(sig)
407 dlist.append(d)
406 dlist.append(d)
408 # dfinal = gatherBoth(dlist, consumeErrors=True)
407 # dfinal = gatherBoth(dlist, consumeErrors=True)
409 return dlist
408 return dlist
410
409
411 def interrupt_then_kill(self, delay=1.0):
410 def interrupt_then_kill(self, delay=1.0):
412 dlist = []
411 dlist = []
413 for el in self.launchers.itervalues():
412 for el in self.launchers.itervalues():
414 d = el.interrupt_then_kill(delay)
413 d = el.interrupt_then_kill(delay)
415 dlist.append(d)
414 dlist.append(d)
416 # dfinal = gatherBoth(dlist, consumeErrors=True)
415 # dfinal = gatherBoth(dlist, consumeErrors=True)
417 return dlist
416 return dlist
418
417
419 def stop(self):
418 def stop(self):
420 return self.interrupt_then_kill()
419 return self.interrupt_then_kill()
421
420
422 def _notice_engine_stopped(self, data):
421 def _notice_engine_stopped(self, data):
423 pid = data['pid']
422 pid = data['pid']
424 for idx,el in self.launchers.iteritems():
423 for idx,el in self.launchers.iteritems():
425 if el.process.pid == pid:
424 if el.process.pid == pid:
426 break
425 break
427 self.launchers.pop(idx)
426 self.launchers.pop(idx)
428 self.stop_data[idx] = data
427 self.stop_data[idx] = data
429 if not self.launchers:
428 if not self.launchers:
430 self.notify_stop(self.stop_data)
429 self.notify_stop(self.stop_data)
431
430
432
431
433 #-----------------------------------------------------------------------------
432 #-----------------------------------------------------------------------------
434 # MPIExec launchers
433 # MPIExec launchers
435 #-----------------------------------------------------------------------------
434 #-----------------------------------------------------------------------------
436
435
437
436
438 class MPIExecLauncher(LocalProcessLauncher):
437 class MPIExecLauncher(LocalProcessLauncher):
439 """Launch an external process using mpiexec."""
438 """Launch an external process using mpiexec."""
440
439
441 mpi_cmd = List(['mpiexec'], config=True,
440 mpi_cmd = List(['mpiexec'], config=True,
442 help="The mpiexec command to use in starting the process."
441 help="The mpiexec command to use in starting the process."
443 )
442 )
444 mpi_args = List([], config=True,
443 mpi_args = List([], config=True,
445 help="The command line arguments to pass to mpiexec."
444 help="The command line arguments to pass to mpiexec."
446 )
445 )
447 program = List(['date'], config=True,
446 program = List(['date'], config=True,
448 help="The program to start via mpiexec.")
447 help="The program to start via mpiexec.")
449 program_args = List([], config=True,
448 program_args = List([], config=True,
450 help="The command line argument to the program."
449 help="The command line argument to the program."
451 )
450 )
452 n = Int(1)
451 n = Int(1)
453
452
454 def find_args(self):
453 def find_args(self):
455 """Build self.args using all the fields."""
454 """Build self.args using all the fields."""
456 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
455 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
457 self.program + self.program_args
456 self.program + self.program_args
458
457
459 def start(self, n):
458 def start(self, n):
460 """Start n instances of the program using mpiexec."""
459 """Start n instances of the program using mpiexec."""
461 self.n = n
460 self.n = n
462 return super(MPIExecLauncher, self).start()
461 return super(MPIExecLauncher, self).start()
463
462
464
463
465 class MPIExecControllerLauncher(MPIExecLauncher):
464 class MPIExecControllerLauncher(MPIExecLauncher):
466 """Launch a controller using mpiexec."""
465 """Launch a controller using mpiexec."""
467
466
468 controller_cmd = List(ipcontroller_cmd_argv, config=True,
467 controller_cmd = List(ipcontroller_cmd_argv, config=True,
469 help="Popen command to launch the Contropper"
468 help="Popen command to launch the Contropper"
470 )
469 )
471 controller_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
470 controller_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
472 help="Command line arguments to pass to ipcontroller."
471 help="Command line arguments to pass to ipcontroller."
473 )
472 )
474 n = Int(1)
473 n = Int(1)
475
474
476 def start(self, profile_dir):
475 def start(self, profile_dir):
477 """Start the controller by profile_dir."""
476 """Start the controller by profile_dir."""
478 self.controller_args.extend(['--profile-dir=%s'%profile_dir])
477 self.controller_args.extend(['--profile-dir=%s'%profile_dir])
479 self.profile_dir = unicode(profile_dir)
478 self.profile_dir = unicode(profile_dir)
480 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
479 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
481 return super(MPIExecControllerLauncher, self).start(1)
480 return super(MPIExecControllerLauncher, self).start(1)
482
481
483 def find_args(self):
482 def find_args(self):
484 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
483 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
485 self.controller_cmd + self.controller_args
484 self.controller_cmd + self.controller_args
486
485
487
486
488 class MPIExecEngineSetLauncher(MPIExecLauncher):
487 class MPIExecEngineSetLauncher(MPIExecLauncher):
489
488
490 program = List(ipengine_cmd_argv, config=True,
489 program = List(ipengine_cmd_argv, config=True,
491 help="Popen command for ipengine"
490 help="Popen command for ipengine"
492 )
491 )
493 program_args = List(
492 program_args = List(
494 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
493 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
495 help="Command line arguments for ipengine."
494 help="Command line arguments for ipengine."
496 )
495 )
497 n = Int(1)
496 n = Int(1)
498
497
499 def start(self, n, profile_dir):
498 def start(self, n, profile_dir):
500 """Start n engines by profile or profile_dir."""
499 """Start n engines by profile or profile_dir."""
501 self.program_args.extend(['--profile-dir=%s'%profile_dir])
500 self.program_args.extend(['--profile-dir=%s'%profile_dir])
502 self.profile_dir = unicode(profile_dir)
501 self.profile_dir = unicode(profile_dir)
503 self.n = n
502 self.n = n
504 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
503 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
505 return super(MPIExecEngineSetLauncher, self).start(n)
504 return super(MPIExecEngineSetLauncher, self).start(n)
506
505
507 #-----------------------------------------------------------------------------
506 #-----------------------------------------------------------------------------
508 # SSH launchers
507 # SSH launchers
509 #-----------------------------------------------------------------------------
508 #-----------------------------------------------------------------------------
510
509
511 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
510 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
512
511
513 class SSHLauncher(LocalProcessLauncher):
512 class SSHLauncher(LocalProcessLauncher):
514 """A minimal launcher for ssh.
513 """A minimal launcher for ssh.
515
514
516 To be useful this will probably have to be extended to use the ``sshx``
515 To be useful this will probably have to be extended to use the ``sshx``
517 idea for environment variables. There could be other things this needs
516 idea for environment variables. There could be other things this needs
518 as well.
517 as well.
519 """
518 """
520
519
521 ssh_cmd = List(['ssh'], config=True,
520 ssh_cmd = List(['ssh'], config=True,
522 help="command for starting ssh")
521 help="command for starting ssh")
523 ssh_args = List(['-tt'], config=True,
522 ssh_args = List(['-tt'], config=True,
524 help="args to pass to ssh")
523 help="args to pass to ssh")
525 program = List(['date'], config=True,
524 program = List(['date'], config=True,
526 help="Program to launch via ssh")
525 help="Program to launch via ssh")
527 program_args = List([], config=True,
526 program_args = List([], config=True,
528 help="args to pass to remote program")
527 help="args to pass to remote program")
529 hostname = Unicode('', config=True,
528 hostname = Unicode('', config=True,
530 help="hostname on which to launch the program")
529 help="hostname on which to launch the program")
531 user = Unicode('', config=True,
530 user = Unicode('', config=True,
532 help="username for ssh")
531 help="username for ssh")
533 location = Unicode('', config=True,
532 location = Unicode('', config=True,
534 help="user@hostname location for ssh in one setting")
533 help="user@hostname location for ssh in one setting")
535
534
536 def _hostname_changed(self, name, old, new):
535 def _hostname_changed(self, name, old, new):
537 if self.user:
536 if self.user:
538 self.location = u'%s@%s' % (self.user, new)
537 self.location = u'%s@%s' % (self.user, new)
539 else:
538 else:
540 self.location = new
539 self.location = new
541
540
542 def _user_changed(self, name, old, new):
541 def _user_changed(self, name, old, new):
543 self.location = u'%s@%s' % (new, self.hostname)
542 self.location = u'%s@%s' % (new, self.hostname)
544
543
545 def find_args(self):
544 def find_args(self):
546 return self.ssh_cmd + self.ssh_args + [self.location] + \
545 return self.ssh_cmd + self.ssh_args + [self.location] + \
547 self.program + self.program_args
546 self.program + self.program_args
548
547
549 def start(self, profile_dir, hostname=None, user=None):
548 def start(self, profile_dir, hostname=None, user=None):
550 self.profile_dir = unicode(profile_dir)
549 self.profile_dir = unicode(profile_dir)
551 if hostname is not None:
550 if hostname is not None:
552 self.hostname = hostname
551 self.hostname = hostname
553 if user is not None:
552 if user is not None:
554 self.user = user
553 self.user = user
555
554
556 return super(SSHLauncher, self).start()
555 return super(SSHLauncher, self).start()
557
556
558 def signal(self, sig):
557 def signal(self, sig):
559 if self.state == 'running':
558 if self.state == 'running':
560 # send escaped ssh connection-closer
559 # send escaped ssh connection-closer
561 self.process.stdin.write('~.')
560 self.process.stdin.write('~.')
562 self.process.stdin.flush()
561 self.process.stdin.flush()
563
562
564
563
565
564
566 class SSHControllerLauncher(SSHLauncher):
565 class SSHControllerLauncher(SSHLauncher):
567
566
568 program = List(ipcontroller_cmd_argv, config=True,
567 program = List(ipcontroller_cmd_argv, config=True,
569 help="remote ipcontroller command.")
568 help="remote ipcontroller command.")
570 program_args = List(['--reuse-files', '--log-to-file','--log-level=%i'%logging.INFO], config=True,
569 program_args = List(['--reuse-files', '--log-to-file','--log-level=%i'%logging.INFO], config=True,
571 help="Command line arguments to ipcontroller.")
570 help="Command line arguments to ipcontroller.")
572
571
573
572
574 class SSHEngineLauncher(SSHLauncher):
573 class SSHEngineLauncher(SSHLauncher):
575 program = List(ipengine_cmd_argv, config=True,
574 program = List(ipengine_cmd_argv, config=True,
576 help="remote ipengine command.")
575 help="remote ipengine command.")
577 # Command line arguments for ipengine.
576 # Command line arguments for ipengine.
578 program_args = List(
577 program_args = List(
579 ['--log-to-file','--log_level=%i'%logging.INFO], config=True,
578 ['--log-to-file','--log_level=%i'%logging.INFO], config=True,
580 help="Command line arguments to ipengine."
579 help="Command line arguments to ipengine."
581 )
580 )
582
581
583 class SSHEngineSetLauncher(LocalEngineSetLauncher):
582 class SSHEngineSetLauncher(LocalEngineSetLauncher):
584 launcher_class = SSHEngineLauncher
583 launcher_class = SSHEngineLauncher
585 engines = Dict(config=True,
584 engines = Dict(config=True,
586 help="""dict of engines to launch. This is a dict by hostname of ints,
585 help="""dict of engines to launch. This is a dict by hostname of ints,
587 corresponding to the number of engines to start on that host.""")
586 corresponding to the number of engines to start on that host.""")
588
587
589 def start(self, n, profile_dir):
588 def start(self, n, profile_dir):
590 """Start engines by profile or profile_dir.
589 """Start engines by profile or profile_dir.
591 `n` is ignored, and the `engines` config property is used instead.
590 `n` is ignored, and the `engines` config property is used instead.
592 """
591 """
593
592
594 self.profile_dir = unicode(profile_dir)
593 self.profile_dir = unicode(profile_dir)
595 dlist = []
594 dlist = []
596 for host, n in self.engines.iteritems():
595 for host, n in self.engines.iteritems():
597 if isinstance(n, (tuple, list)):
596 if isinstance(n, (tuple, list)):
598 n, args = n
597 n, args = n
599 else:
598 else:
600 args = copy.deepcopy(self.engine_args)
599 args = copy.deepcopy(self.engine_args)
601
600
602 if '@' in host:
601 if '@' in host:
603 user,host = host.split('@',1)
602 user,host = host.split('@',1)
604 else:
603 else:
605 user=None
604 user=None
606 for i in range(n):
605 for i in range(n):
607 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
606 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
608
607
609 # Copy the engine args over to each engine launcher.
608 # Copy the engine args over to each engine launcher.
610 i
609 i
611 el.program_args = args
610 el.program_args = args
612 el.on_stop(self._notice_engine_stopped)
611 el.on_stop(self._notice_engine_stopped)
613 d = el.start(profile_dir, user=user, hostname=host)
612 d = el.start(profile_dir, user=user, hostname=host)
614 if i==0:
613 if i==0:
615 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
614 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
616 self.launchers[host+str(i)] = el
615 self.launchers[host+str(i)] = el
617 dlist.append(d)
616 dlist.append(d)
618 self.notify_start(dlist)
617 self.notify_start(dlist)
619 return dlist
618 return dlist
620
619
621
620
622
621
623 #-----------------------------------------------------------------------------
622 #-----------------------------------------------------------------------------
624 # Windows HPC Server 2008 scheduler launchers
623 # Windows HPC Server 2008 scheduler launchers
625 #-----------------------------------------------------------------------------
624 #-----------------------------------------------------------------------------
626
625
627
626
628 # This is only used on Windows.
627 # This is only used on Windows.
629 def find_job_cmd():
628 def find_job_cmd():
630 if WINDOWS:
629 if WINDOWS:
631 try:
630 try:
632 return find_cmd('job')
631 return find_cmd('job')
633 except (FindCmdError, ImportError):
632 except (FindCmdError, ImportError):
634 # ImportError will be raised if win32api is not installed
633 # ImportError will be raised if win32api is not installed
635 return 'job'
634 return 'job'
636 else:
635 else:
637 return 'job'
636 return 'job'
638
637
639
638
640 class WindowsHPCLauncher(BaseLauncher):
639 class WindowsHPCLauncher(BaseLauncher):
641
640
642 job_id_regexp = Unicode(r'\d+', config=True,
641 job_id_regexp = Unicode(r'\d+', config=True,
643 help="""A regular expression used to get the job id from the output of the
642 help="""A regular expression used to get the job id from the output of the
644 submit_command. """
643 submit_command. """
645 )
644 )
646 job_file_name = Unicode(u'ipython_job.xml', config=True,
645 job_file_name = Unicode(u'ipython_job.xml', config=True,
647 help="The filename of the instantiated job script.")
646 help="The filename of the instantiated job script.")
648 # The full path to the instantiated job script. This gets made dynamically
647 # The full path to the instantiated job script. This gets made dynamically
649 # by combining the work_dir with the job_file_name.
648 # by combining the work_dir with the job_file_name.
650 job_file = Unicode(u'')
649 job_file = Unicode(u'')
651 scheduler = Unicode('', config=True,
650 scheduler = Unicode('', config=True,
652 help="The hostname of the scheduler to submit the job to.")
651 help="The hostname of the scheduler to submit the job to.")
653 job_cmd = Unicode(find_job_cmd(), config=True,
652 job_cmd = Unicode(find_job_cmd(), config=True,
654 help="The command for submitting jobs.")
653 help="The command for submitting jobs.")
655
654
656 def __init__(self, work_dir=u'.', config=None, **kwargs):
655 def __init__(self, work_dir=u'.', config=None, **kwargs):
657 super(WindowsHPCLauncher, self).__init__(
656 super(WindowsHPCLauncher, self).__init__(
658 work_dir=work_dir, config=config, **kwargs
657 work_dir=work_dir, config=config, **kwargs
659 )
658 )
660
659
661 @property
660 @property
662 def job_file(self):
661 def job_file(self):
663 return os.path.join(self.work_dir, self.job_file_name)
662 return os.path.join(self.work_dir, self.job_file_name)
664
663
665 def write_job_file(self, n):
664 def write_job_file(self, n):
666 raise NotImplementedError("Implement write_job_file in a subclass.")
665 raise NotImplementedError("Implement write_job_file in a subclass.")
667
666
668 def find_args(self):
667 def find_args(self):
669 return [u'job.exe']
668 return [u'job.exe']
670
669
671 def parse_job_id(self, output):
670 def parse_job_id(self, output):
672 """Take the output of the submit command and return the job id."""
671 """Take the output of the submit command and return the job id."""
673 m = re.search(self.job_id_regexp, output)
672 m = re.search(self.job_id_regexp, output)
674 if m is not None:
673 if m is not None:
675 job_id = m.group()
674 job_id = m.group()
676 else:
675 else:
677 raise LauncherError("Job id couldn't be determined: %s" % output)
676 raise LauncherError("Job id couldn't be determined: %s" % output)
678 self.job_id = job_id
677 self.job_id = job_id
679 self.log.info('Job started with job id: %r' % job_id)
678 self.log.info('Job started with job id: %r' % job_id)
680 return job_id
679 return job_id
681
680
682 def start(self, n):
681 def start(self, n):
683 """Start n copies of the process using the Win HPC job scheduler."""
682 """Start n copies of the process using the Win HPC job scheduler."""
684 self.write_job_file(n)
683 self.write_job_file(n)
685 args = [
684 args = [
686 'submit',
685 'submit',
687 '/jobfile:%s' % self.job_file,
686 '/jobfile:%s' % self.job_file,
688 '/scheduler:%s' % self.scheduler
687 '/scheduler:%s' % self.scheduler
689 ]
688 ]
690 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
689 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
691
690
692 output = check_output([self.job_cmd]+args,
691 output = check_output([self.job_cmd]+args,
693 env=os.environ,
692 env=os.environ,
694 cwd=self.work_dir,
693 cwd=self.work_dir,
695 stderr=STDOUT
694 stderr=STDOUT
696 )
695 )
697 job_id = self.parse_job_id(output)
696 job_id = self.parse_job_id(output)
698 self.notify_start(job_id)
697 self.notify_start(job_id)
699 return job_id
698 return job_id
700
699
701 def stop(self):
700 def stop(self):
702 args = [
701 args = [
703 'cancel',
702 'cancel',
704 self.job_id,
703 self.job_id,
705 '/scheduler:%s' % self.scheduler
704 '/scheduler:%s' % self.scheduler
706 ]
705 ]
707 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
706 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
708 try:
707 try:
709 output = check_output([self.job_cmd]+args,
708 output = check_output([self.job_cmd]+args,
710 env=os.environ,
709 env=os.environ,
711 cwd=self.work_dir,
710 cwd=self.work_dir,
712 stderr=STDOUT
711 stderr=STDOUT
713 )
712 )
714 except:
713 except:
715 output = 'The job already appears to be stoppped: %r' % self.job_id
714 output = 'The job already appears to be stoppped: %r' % self.job_id
716 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
715 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
717 return output
716 return output
718
717
719
718
720 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
719 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
721
720
722 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
721 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
723 help="WinHPC xml job file.")
722 help="WinHPC xml job file.")
724 extra_args = List([], config=False,
723 extra_args = List([], config=False,
725 help="extra args to pass to ipcontroller")
724 help="extra args to pass to ipcontroller")
726
725
727 def write_job_file(self, n):
726 def write_job_file(self, n):
728 job = IPControllerJob(config=self.config)
727 job = IPControllerJob(config=self.config)
729
728
730 t = IPControllerTask(config=self.config)
729 t = IPControllerTask(config=self.config)
731 # The tasks work directory is *not* the actual work directory of
730 # The tasks work directory is *not* the actual work directory of
732 # the controller. It is used as the base path for the stdout/stderr
731 # the controller. It is used as the base path for the stdout/stderr
733 # files that the scheduler redirects to.
732 # files that the scheduler redirects to.
734 t.work_directory = self.profile_dir
733 t.work_directory = self.profile_dir
735 # Add the profile_dir and from self.start().
734 # Add the profile_dir and from self.start().
736 t.controller_args.extend(self.extra_args)
735 t.controller_args.extend(self.extra_args)
737 job.add_task(t)
736 job.add_task(t)
738
737
739 self.log.info("Writing job description file: %s" % self.job_file)
738 self.log.info("Writing job description file: %s" % self.job_file)
740 job.write(self.job_file)
739 job.write(self.job_file)
741
740
742 @property
741 @property
743 def job_file(self):
742 def job_file(self):
744 return os.path.join(self.profile_dir, self.job_file_name)
743 return os.path.join(self.profile_dir, self.job_file_name)
745
744
746 def start(self, profile_dir):
745 def start(self, profile_dir):
747 """Start the controller by profile_dir."""
746 """Start the controller by profile_dir."""
748 self.extra_args = ['--profile-dir=%s'%profile_dir]
747 self.extra_args = ['--profile-dir=%s'%profile_dir]
749 self.profile_dir = unicode(profile_dir)
748 self.profile_dir = unicode(profile_dir)
750 return super(WindowsHPCControllerLauncher, self).start(1)
749 return super(WindowsHPCControllerLauncher, self).start(1)
751
750
752
751
753 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
752 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
754
753
755 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
754 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
756 help="jobfile for ipengines job")
755 help="jobfile for ipengines job")
757 extra_args = List([], config=False,
756 extra_args = List([], config=False,
758 help="extra args to pas to ipengine")
757 help="extra args to pas to ipengine")
759
758
760 def write_job_file(self, n):
759 def write_job_file(self, n):
761 job = IPEngineSetJob(config=self.config)
760 job = IPEngineSetJob(config=self.config)
762
761
763 for i in range(n):
762 for i in range(n):
764 t = IPEngineTask(config=self.config)
763 t = IPEngineTask(config=self.config)
765 # The tasks work directory is *not* the actual work directory of
764 # The tasks work directory is *not* the actual work directory of
766 # the engine. It is used as the base path for the stdout/stderr
765 # the engine. It is used as the base path for the stdout/stderr
767 # files that the scheduler redirects to.
766 # files that the scheduler redirects to.
768 t.work_directory = self.profile_dir
767 t.work_directory = self.profile_dir
769 # Add the profile_dir and from self.start().
768 # Add the profile_dir and from self.start().
770 t.engine_args.extend(self.extra_args)
769 t.engine_args.extend(self.extra_args)
771 job.add_task(t)
770 job.add_task(t)
772
771
773 self.log.info("Writing job description file: %s" % self.job_file)
772 self.log.info("Writing job description file: %s" % self.job_file)
774 job.write(self.job_file)
773 job.write(self.job_file)
775
774
776 @property
775 @property
777 def job_file(self):
776 def job_file(self):
778 return os.path.join(self.profile_dir, self.job_file_name)
777 return os.path.join(self.profile_dir, self.job_file_name)
779
778
780 def start(self, n, profile_dir):
779 def start(self, n, profile_dir):
781 """Start the controller by profile_dir."""
780 """Start the controller by profile_dir."""
782 self.extra_args = ['--profile-dir=%s'%profile_dir]
781 self.extra_args = ['--profile-dir=%s'%profile_dir]
783 self.profile_dir = unicode(profile_dir)
782 self.profile_dir = unicode(profile_dir)
784 return super(WindowsHPCEngineSetLauncher, self).start(n)
783 return super(WindowsHPCEngineSetLauncher, self).start(n)
785
784
786
785
787 #-----------------------------------------------------------------------------
786 #-----------------------------------------------------------------------------
788 # Batch (PBS) system launchers
787 # Batch (PBS) system launchers
789 #-----------------------------------------------------------------------------
788 #-----------------------------------------------------------------------------
790
789
791 class BatchSystemLauncher(BaseLauncher):
790 class BatchSystemLauncher(BaseLauncher):
792 """Launch an external process using a batch system.
791 """Launch an external process using a batch system.
793
792
794 This class is designed to work with UNIX batch systems like PBS, LSF,
793 This class is designed to work with UNIX batch systems like PBS, LSF,
795 GridEngine, etc. The overall model is that there are different commands
794 GridEngine, etc. The overall model is that there are different commands
796 like qsub, qdel, etc. that handle the starting and stopping of the process.
795 like qsub, qdel, etc. that handle the starting and stopping of the process.
797
796
798 This class also has the notion of a batch script. The ``batch_template``
797 This class also has the notion of a batch script. The ``batch_template``
799 attribute can be set to a string that is a template for the batch script.
798 attribute can be set to a string that is a template for the batch script.
800 This template is instantiated using string formatting. Thus the template can
799 This template is instantiated using string formatting. Thus the template can
801 use {n} fot the number of instances. Subclasses can add additional variables
800 use {n} fot the number of instances. Subclasses can add additional variables
802 to the template dict.
801 to the template dict.
803 """
802 """
804
803
805 # Subclasses must fill these in. See PBSEngineSet
804 # Subclasses must fill these in. See PBSEngineSet
806 submit_command = List([''], config=True,
805 submit_command = List([''], config=True,
807 help="The name of the command line program used to submit jobs.")
806 help="The name of the command line program used to submit jobs.")
808 delete_command = List([''], config=True,
807 delete_command = List([''], config=True,
809 help="The name of the command line program used to delete jobs.")
808 help="The name of the command line program used to delete jobs.")
810 job_id_regexp = Unicode('', config=True,
809 job_id_regexp = Unicode('', config=True,
811 help="""A regular expression used to get the job id from the output of the
810 help="""A regular expression used to get the job id from the output of the
812 submit_command.""")
811 submit_command.""")
813 batch_template = Unicode('', config=True,
812 batch_template = Unicode('', config=True,
814 help="The string that is the batch script template itself.")
813 help="The string that is the batch script template itself.")
815 batch_template_file = Unicode(u'', config=True,
814 batch_template_file = Unicode(u'', config=True,
816 help="The file that contains the batch template.")
815 help="The file that contains the batch template.")
817 batch_file_name = Unicode(u'batch_script', config=True,
816 batch_file_name = Unicode(u'batch_script', config=True,
818 help="The filename of the instantiated batch script.")
817 help="The filename of the instantiated batch script.")
819 queue = Unicode(u'', config=True,
818 queue = Unicode(u'', config=True,
820 help="The PBS Queue.")
819 help="The PBS Queue.")
821
820
822 # not configurable, override in subclasses
821 # not configurable, override in subclasses
823 # PBS Job Array regex
822 # PBS Job Array regex
824 job_array_regexp = Unicode('')
823 job_array_regexp = Unicode('')
825 job_array_template = Unicode('')
824 job_array_template = Unicode('')
826 # PBS Queue regex
825 # PBS Queue regex
827 queue_regexp = Unicode('')
826 queue_regexp = Unicode('')
828 queue_template = Unicode('')
827 queue_template = Unicode('')
829 # The default batch template, override in subclasses
828 # The default batch template, override in subclasses
830 default_template = Unicode('')
829 default_template = Unicode('')
831 # The full path to the instantiated batch script.
830 # The full path to the instantiated batch script.
832 batch_file = Unicode(u'')
831 batch_file = Unicode(u'')
833 # the format dict used with batch_template:
832 # the format dict used with batch_template:
834 context = Dict()
833 context = Dict()
835 # the Formatter instance for rendering the templates:
834 # the Formatter instance for rendering the templates:
836 formatter = Instance(EvalFormatter, (), {})
835 formatter = Instance(EvalFormatter, (), {})
837
836
838
837
839 def find_args(self):
838 def find_args(self):
840 return self.submit_command + [self.batch_file]
839 return self.submit_command + [self.batch_file]
841
840
842 def __init__(self, work_dir=u'.', config=None, **kwargs):
841 def __init__(self, work_dir=u'.', config=None, **kwargs):
843 super(BatchSystemLauncher, self).__init__(
842 super(BatchSystemLauncher, self).__init__(
844 work_dir=work_dir, config=config, **kwargs
843 work_dir=work_dir, config=config, **kwargs
845 )
844 )
846 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
845 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
847
846
848 def parse_job_id(self, output):
847 def parse_job_id(self, output):
849 """Take the output of the submit command and return the job id."""
848 """Take the output of the submit command and return the job id."""
850 m = re.search(self.job_id_regexp, output)
849 m = re.search(self.job_id_regexp, output)
851 if m is not None:
850 if m is not None:
852 job_id = m.group()
851 job_id = m.group()
853 else:
852 else:
854 raise LauncherError("Job id couldn't be determined: %s" % output)
853 raise LauncherError("Job id couldn't be determined: %s" % output)
855 self.job_id = job_id
854 self.job_id = job_id
856 self.log.info('Job submitted with job id: %r' % job_id)
855 self.log.info('Job submitted with job id: %r' % job_id)
857 return job_id
856 return job_id
858
857
859 def write_batch_script(self, n):
858 def write_batch_script(self, n):
860 """Instantiate and write the batch script to the work_dir."""
859 """Instantiate and write the batch script to the work_dir."""
861 self.context['n'] = n
860 self.context['n'] = n
862 self.context['queue'] = self.queue
861 self.context['queue'] = self.queue
863 # first priority is batch_template if set
862 # first priority is batch_template if set
864 if self.batch_template_file and not self.batch_template:
863 if self.batch_template_file and not self.batch_template:
865 # second priority is batch_template_file
864 # second priority is batch_template_file
866 with open(self.batch_template_file) as f:
865 with open(self.batch_template_file) as f:
867 self.batch_template = f.read()
866 self.batch_template = f.read()
868 if not self.batch_template:
867 if not self.batch_template:
869 # third (last) priority is default_template
868 # third (last) priority is default_template
870 self.batch_template = self.default_template
869 self.batch_template = self.default_template
871
870
872 # add jobarray or queue lines to user-specified template
871 # add jobarray or queue lines to user-specified template
873 # note that this is *only* when user did not specify a template.
872 # note that this is *only* when user did not specify a template.
874 regex = re.compile(self.job_array_regexp)
873 regex = re.compile(self.job_array_regexp)
875 # print regex.search(self.batch_template)
874 # print regex.search(self.batch_template)
876 if not regex.search(self.batch_template):
875 if not regex.search(self.batch_template):
877 self.log.info("adding job array settings to batch script")
876 self.log.info("adding job array settings to batch script")
878 firstline, rest = self.batch_template.split('\n',1)
877 firstline, rest = self.batch_template.split('\n',1)
879 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
878 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
880
879
881 regex = re.compile(self.queue_regexp)
880 regex = re.compile(self.queue_regexp)
882 # print regex.search(self.batch_template)
881 # print regex.search(self.batch_template)
883 if self.queue and not regex.search(self.batch_template):
882 if self.queue and not regex.search(self.batch_template):
884 self.log.info("adding PBS queue settings to batch script")
883 self.log.info("adding PBS queue settings to batch script")
885 firstline, rest = self.batch_template.split('\n',1)
884 firstline, rest = self.batch_template.split('\n',1)
886 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
885 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
887
886
888 script_as_string = self.formatter.format(self.batch_template, **self.context)
887 script_as_string = self.formatter.format(self.batch_template, **self.context)
889 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
888 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
890
889
891 with open(self.batch_file, 'w') as f:
890 with open(self.batch_file, 'w') as f:
892 f.write(script_as_string)
891 f.write(script_as_string)
893 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
892 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
894
893
895 def start(self, n, profile_dir):
894 def start(self, n, profile_dir):
896 """Start n copies of the process using a batch system."""
895 """Start n copies of the process using a batch system."""
897 # Here we save profile_dir in the context so they
896 # Here we save profile_dir in the context so they
898 # can be used in the batch script template as {profile_dir}
897 # can be used in the batch script template as {profile_dir}
899 self.context['profile_dir'] = profile_dir
898 self.context['profile_dir'] = profile_dir
900 self.profile_dir = unicode(profile_dir)
899 self.profile_dir = unicode(profile_dir)
901 self.write_batch_script(n)
900 self.write_batch_script(n)
902 output = check_output(self.args, env=os.environ)
901 output = check_output(self.args, env=os.environ)
903
902
904 job_id = self.parse_job_id(output)
903 job_id = self.parse_job_id(output)
905 self.notify_start(job_id)
904 self.notify_start(job_id)
906 return job_id
905 return job_id
907
906
908 def stop(self):
907 def stop(self):
909 output = check_output(self.delete_command+[self.job_id], env=os.environ)
908 output = check_output(self.delete_command+[self.job_id], env=os.environ)
910 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
909 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
911 return output
910 return output
912
911
913
912
914 class PBSLauncher(BatchSystemLauncher):
913 class PBSLauncher(BatchSystemLauncher):
915 """A BatchSystemLauncher subclass for PBS."""
914 """A BatchSystemLauncher subclass for PBS."""
916
915
917 submit_command = List(['qsub'], config=True,
916 submit_command = List(['qsub'], config=True,
918 help="The PBS submit command ['qsub']")
917 help="The PBS submit command ['qsub']")
919 delete_command = List(['qdel'], config=True,
918 delete_command = List(['qdel'], config=True,
920 help="The PBS delete command ['qsub']")
919 help="The PBS delete command ['qsub']")
921 job_id_regexp = Unicode(r'\d+', config=True,
920 job_id_regexp = Unicode(r'\d+', config=True,
922 help="Regular expresion for identifying the job ID [r'\d+']")
921 help="Regular expresion for identifying the job ID [r'\d+']")
923
922
924 batch_file = Unicode(u'')
923 batch_file = Unicode(u'')
925 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
924 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
926 job_array_template = Unicode('#PBS -t 1-{n}')
925 job_array_template = Unicode('#PBS -t 1-{n}')
927 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
926 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
928 queue_template = Unicode('#PBS -q {queue}')
927 queue_template = Unicode('#PBS -q {queue}')
929
928
930
929
931 class PBSControllerLauncher(PBSLauncher):
930 class PBSControllerLauncher(PBSLauncher):
932 """Launch a controller using PBS."""
931 """Launch a controller using PBS."""
933
932
934 batch_file_name = Unicode(u'pbs_controller', config=True,
933 batch_file_name = Unicode(u'pbs_controller', config=True,
935 help="batch file name for the controller job.")
934 help="batch file name for the controller job.")
936 default_template= Unicode("""#!/bin/sh
935 default_template= Unicode("""#!/bin/sh
937 #PBS -V
936 #PBS -V
938 #PBS -N ipcontroller
937 #PBS -N ipcontroller
939 %s --log-to-file --profile-dir={profile_dir}
938 %s --log-to-file --profile-dir={profile_dir}
940 """%(' '.join(ipcontroller_cmd_argv)))
939 """%(' '.join(ipcontroller_cmd_argv)))
941
940
942 def start(self, profile_dir):
941 def start(self, profile_dir):
943 """Start the controller by profile or profile_dir."""
942 """Start the controller by profile or profile_dir."""
944 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
943 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
945 return super(PBSControllerLauncher, self).start(1, profile_dir)
944 return super(PBSControllerLauncher, self).start(1, profile_dir)
946
945
947
946
948 class PBSEngineSetLauncher(PBSLauncher):
947 class PBSEngineSetLauncher(PBSLauncher):
949 """Launch Engines using PBS"""
948 """Launch Engines using PBS"""
950 batch_file_name = Unicode(u'pbs_engines', config=True,
949 batch_file_name = Unicode(u'pbs_engines', config=True,
951 help="batch file name for the engine(s) job.")
950 help="batch file name for the engine(s) job.")
952 default_template= Unicode(u"""#!/bin/sh
951 default_template= Unicode(u"""#!/bin/sh
953 #PBS -V
952 #PBS -V
954 #PBS -N ipengine
953 #PBS -N ipengine
955 %s --profile-dir={profile_dir}
954 %s --profile-dir={profile_dir}
956 """%(' '.join(ipengine_cmd_argv)))
955 """%(' '.join(ipengine_cmd_argv)))
957
956
958 def start(self, n, profile_dir):
957 def start(self, n, profile_dir):
959 """Start n engines by profile or profile_dir."""
958 """Start n engines by profile or profile_dir."""
960 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
959 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
961 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
960 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
962
961
963 #SGE is very similar to PBS
962 #SGE is very similar to PBS
964
963
965 class SGELauncher(PBSLauncher):
964 class SGELauncher(PBSLauncher):
966 """Sun GridEngine is a PBS clone with slightly different syntax"""
965 """Sun GridEngine is a PBS clone with slightly different syntax"""
967 job_array_regexp = Unicode('#\$\W+\-t')
966 job_array_regexp = Unicode('#\$\W+\-t')
968 job_array_template = Unicode('#$ -t 1-{n}')
967 job_array_template = Unicode('#$ -t 1-{n}')
969 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
968 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
970 queue_template = Unicode('#$ -q {queue}')
969 queue_template = Unicode('#$ -q {queue}')
971
970
972 class SGEControllerLauncher(SGELauncher):
971 class SGEControllerLauncher(SGELauncher):
973 """Launch a controller using SGE."""
972 """Launch a controller using SGE."""
974
973
975 batch_file_name = Unicode(u'sge_controller', config=True,
974 batch_file_name = Unicode(u'sge_controller', config=True,
976 help="batch file name for the ipontroller job.")
975 help="batch file name for the ipontroller job.")
977 default_template= Unicode(u"""#$ -V
976 default_template= Unicode(u"""#$ -V
978 #$ -S /bin/sh
977 #$ -S /bin/sh
979 #$ -N ipcontroller
978 #$ -N ipcontroller
980 %s --log-to-file --profile-dir={profile_dir}
979 %s --log-to-file --profile-dir={profile_dir}
981 """%(' '.join(ipcontroller_cmd_argv)))
980 """%(' '.join(ipcontroller_cmd_argv)))
982
981
983 def start(self, profile_dir):
982 def start(self, profile_dir):
984 """Start the controller by profile or profile_dir."""
983 """Start the controller by profile or profile_dir."""
985 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
984 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
986 return super(SGEControllerLauncher, self).start(1, profile_dir)
985 return super(SGEControllerLauncher, self).start(1, profile_dir)
987
986
988 class SGEEngineSetLauncher(SGELauncher):
987 class SGEEngineSetLauncher(SGELauncher):
989 """Launch Engines with SGE"""
988 """Launch Engines with SGE"""
990 batch_file_name = Unicode(u'sge_engines', config=True,
989 batch_file_name = Unicode(u'sge_engines', config=True,
991 help="batch file name for the engine(s) job.")
990 help="batch file name for the engine(s) job.")
992 default_template = Unicode("""#$ -V
991 default_template = Unicode("""#$ -V
993 #$ -S /bin/sh
992 #$ -S /bin/sh
994 #$ -N ipengine
993 #$ -N ipengine
995 %s --profile-dir={profile_dir}
994 %s --profile-dir={profile_dir}
996 """%(' '.join(ipengine_cmd_argv)))
995 """%(' '.join(ipengine_cmd_argv)))
997
996
998 def start(self, n, profile_dir):
997 def start(self, n, profile_dir):
999 """Start n engines by profile or profile_dir."""
998 """Start n engines by profile or profile_dir."""
1000 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
999 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
1001 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1000 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1002
1001
1003
1002
1004 # LSF launchers
1003 # LSF launchers
1005
1004
1006 class LSFLauncher(BatchSystemLauncher):
1005 class LSFLauncher(BatchSystemLauncher):
1007 """A BatchSystemLauncher subclass for LSF."""
1006 """A BatchSystemLauncher subclass for LSF."""
1008
1007
1009 submit_command = List(['bsub'], config=True,
1008 submit_command = List(['bsub'], config=True,
1010 help="The PBS submit command ['bsub']")
1009 help="The PBS submit command ['bsub']")
1011 delete_command = List(['bkill'], config=True,
1010 delete_command = List(['bkill'], config=True,
1012 help="The PBS delete command ['bkill']")
1011 help="The PBS delete command ['bkill']")
1013 job_id_regexp = Unicode(r'\d+', config=True,
1012 job_id_regexp = Unicode(r'\d+', config=True,
1014 help="Regular expresion for identifying the job ID [r'\d+']")
1013 help="Regular expresion for identifying the job ID [r'\d+']")
1015
1014
1016 batch_file = Unicode(u'')
1015 batch_file = Unicode(u'')
1017 job_array_regexp = Unicode('#BSUB[ \t]-J+\w+\[\d+-\d+\]')
1016 job_array_regexp = Unicode('#BSUB[ \t]-J+\w+\[\d+-\d+\]')
1018 job_array_template = Unicode('#BSUB -J ipengine[1-{n}]')
1017 job_array_template = Unicode('#BSUB -J ipengine[1-{n}]')
1019 queue_regexp = Unicode('#BSUB[ \t]+-q[ \t]+\w+')
1018 queue_regexp = Unicode('#BSUB[ \t]+-q[ \t]+\w+')
1020 queue_template = Unicode('#BSUB -q {queue}')
1019 queue_template = Unicode('#BSUB -q {queue}')
1021
1020
1022 def start(self, n, profile_dir):
1021 def start(self, n, profile_dir):
1023 """Start n copies of the process using LSF batch system.
1022 """Start n copies of the process using LSF batch system.
1024 This cant inherit from the base class because bsub expects
1023 This cant inherit from the base class because bsub expects
1025 to be piped a shell script in order to honor the #BSUB directives :
1024 to be piped a shell script in order to honor the #BSUB directives :
1026 bsub < script
1025 bsub < script
1027 """
1026 """
1028 # Here we save profile_dir in the context so they
1027 # Here we save profile_dir in the context so they
1029 # can be used in the batch script template as {profile_dir}
1028 # can be used in the batch script template as {profile_dir}
1030 self.context['profile_dir'] = profile_dir
1029 self.context['profile_dir'] = profile_dir
1031 self.profile_dir = unicode(profile_dir)
1030 self.profile_dir = unicode(profile_dir)
1032 self.write_batch_script(n)
1031 self.write_batch_script(n)
1033 #output = check_output(self.args, env=os.environ)
1032 #output = check_output(self.args, env=os.environ)
1034 piped_cmd = self.args[0]+'<\"'+self.args[1]+'\"'
1033 piped_cmd = self.args[0]+'<\"'+self.args[1]+'\"'
1035 p = Popen(piped_cmd, shell=True,env=os.environ,stdout=PIPE)
1034 p = Popen(piped_cmd, shell=True,env=os.environ,stdout=PIPE)
1036 output,err = p.communicate()
1035 output,err = p.communicate()
1037 job_id = self.parse_job_id(output)
1036 job_id = self.parse_job_id(output)
1038 self.notify_start(job_id)
1037 self.notify_start(job_id)
1039 return job_id
1038 return job_id
1040
1039
1041
1040
1042 class LSFControllerLauncher(LSFLauncher):
1041 class LSFControllerLauncher(LSFLauncher):
1043 """Launch a controller using LSF."""
1042 """Launch a controller using LSF."""
1044
1043
1045 batch_file_name = Unicode(u'lsf_controller', config=True,
1044 batch_file_name = Unicode(u'lsf_controller', config=True,
1046 help="batch file name for the controller job.")
1045 help="batch file name for the controller job.")
1047 default_template= Unicode("""#!/bin/sh
1046 default_template= Unicode("""#!/bin/sh
1048 #BSUB -J ipcontroller
1047 #BSUB -J ipcontroller
1049 #BSUB -oo ipcontroller.o.%%J
1048 #BSUB -oo ipcontroller.o.%%J
1050 #BSUB -eo ipcontroller.e.%%J
1049 #BSUB -eo ipcontroller.e.%%J
1051 %s --log-to-file --profile-dir={profile_dir}
1050 %s --log-to-file --profile-dir={profile_dir}
1052 """%(' '.join(ipcontroller_cmd_argv)))
1051 """%(' '.join(ipcontroller_cmd_argv)))
1053
1052
1054 def start(self, profile_dir):
1053 def start(self, profile_dir):
1055 """Start the controller by profile or profile_dir."""
1054 """Start the controller by profile or profile_dir."""
1056 self.log.info("Starting LSFControllerLauncher: %r" % self.args)
1055 self.log.info("Starting LSFControllerLauncher: %r" % self.args)
1057 return super(LSFControllerLauncher, self).start(1, profile_dir)
1056 return super(LSFControllerLauncher, self).start(1, profile_dir)
1058
1057
1059
1058
1060 class LSFEngineSetLauncher(LSFLauncher):
1059 class LSFEngineSetLauncher(LSFLauncher):
1061 """Launch Engines using LSF"""
1060 """Launch Engines using LSF"""
1062 batch_file_name = Unicode(u'lsf_engines', config=True,
1061 batch_file_name = Unicode(u'lsf_engines', config=True,
1063 help="batch file name for the engine(s) job.")
1062 help="batch file name for the engine(s) job.")
1064 default_template= Unicode(u"""#!/bin/sh
1063 default_template= Unicode(u"""#!/bin/sh
1065 #BSUB -oo ipengine.o.%%J
1064 #BSUB -oo ipengine.o.%%J
1066 #BSUB -eo ipengine.e.%%J
1065 #BSUB -eo ipengine.e.%%J
1067 %s --profile-dir={profile_dir}
1066 %s --profile-dir={profile_dir}
1068 """%(' '.join(ipengine_cmd_argv)))
1067 """%(' '.join(ipengine_cmd_argv)))
1069
1068
1070 def start(self, n, profile_dir):
1069 def start(self, n, profile_dir):
1071 """Start n engines by profile or profile_dir."""
1070 """Start n engines by profile or profile_dir."""
1072 self.log.info('Starting %i engines with LSFEngineSetLauncher: %r' % (n, self.args))
1071 self.log.info('Starting %i engines with LSFEngineSetLauncher: %r' % (n, self.args))
1073 return super(LSFEngineSetLauncher, self).start(n, profile_dir)
1072 return super(LSFEngineSetLauncher, self).start(n, profile_dir)
1074
1073
1075
1074
1076 #-----------------------------------------------------------------------------
1075 #-----------------------------------------------------------------------------
1077 # A launcher for ipcluster itself!
1076 # A launcher for ipcluster itself!
1078 #-----------------------------------------------------------------------------
1077 #-----------------------------------------------------------------------------
1079
1078
1080
1079
1081 class IPClusterLauncher(LocalProcessLauncher):
1080 class IPClusterLauncher(LocalProcessLauncher):
1082 """Launch the ipcluster program in an external process."""
1081 """Launch the ipcluster program in an external process."""
1083
1082
1084 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1083 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1085 help="Popen command for ipcluster")
1084 help="Popen command for ipcluster")
1086 ipcluster_args = List(
1085 ipcluster_args = List(
1087 ['--clean-logs', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1086 ['--clean-logs', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1088 help="Command line arguments to pass to ipcluster.")
1087 help="Command line arguments to pass to ipcluster.")
1089 ipcluster_subcommand = Unicode('start')
1088 ipcluster_subcommand = Unicode('start')
1090 ipcluster_n = Int(2)
1089 ipcluster_n = Int(2)
1091
1090
1092 def find_args(self):
1091 def find_args(self):
1093 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1092 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1094 ['--n=%i'%self.ipcluster_n] + self.ipcluster_args
1093 ['--n=%i'%self.ipcluster_n] + self.ipcluster_args
1095
1094
1096 def start(self):
1095 def start(self):
1097 self.log.info("Starting ipcluster: %r" % self.args)
1096 self.log.info("Starting ipcluster: %r" % self.args)
1098 return super(IPClusterLauncher, self).start()
1097 return super(IPClusterLauncher, self).start()
1099
1098
1100 #-----------------------------------------------------------------------------
1099 #-----------------------------------------------------------------------------
1101 # Collections of launchers
1100 # Collections of launchers
1102 #-----------------------------------------------------------------------------
1101 #-----------------------------------------------------------------------------
1103
1102
1104 local_launchers = [
1103 local_launchers = [
1105 LocalControllerLauncher,
1104 LocalControllerLauncher,
1106 LocalEngineLauncher,
1105 LocalEngineLauncher,
1107 LocalEngineSetLauncher,
1106 LocalEngineSetLauncher,
1108 ]
1107 ]
1109 mpi_launchers = [
1108 mpi_launchers = [
1110 MPIExecLauncher,
1109 MPIExecLauncher,
1111 MPIExecControllerLauncher,
1110 MPIExecControllerLauncher,
1112 MPIExecEngineSetLauncher,
1111 MPIExecEngineSetLauncher,
1113 ]
1112 ]
1114 ssh_launchers = [
1113 ssh_launchers = [
1115 SSHLauncher,
1114 SSHLauncher,
1116 SSHControllerLauncher,
1115 SSHControllerLauncher,
1117 SSHEngineLauncher,
1116 SSHEngineLauncher,
1118 SSHEngineSetLauncher,
1117 SSHEngineSetLauncher,
1119 ]
1118 ]
1120 winhpc_launchers = [
1119 winhpc_launchers = [
1121 WindowsHPCLauncher,
1120 WindowsHPCLauncher,
1122 WindowsHPCControllerLauncher,
1121 WindowsHPCControllerLauncher,
1123 WindowsHPCEngineSetLauncher,
1122 WindowsHPCEngineSetLauncher,
1124 ]
1123 ]
1125 pbs_launchers = [
1124 pbs_launchers = [
1126 PBSLauncher,
1125 PBSLauncher,
1127 PBSControllerLauncher,
1126 PBSControllerLauncher,
1128 PBSEngineSetLauncher,
1127 PBSEngineSetLauncher,
1129 ]
1128 ]
1130 sge_launchers = [
1129 sge_launchers = [
1131 SGELauncher,
1130 SGELauncher,
1132 SGEControllerLauncher,
1131 SGEControllerLauncher,
1133 SGEEngineSetLauncher,
1132 SGEEngineSetLauncher,
1134 ]
1133 ]
1135 lsf_launchers = [
1134 lsf_launchers = [
1136 LSFLauncher,
1135 LSFLauncher,
1137 LSFControllerLauncher,
1136 LSFControllerLauncher,
1138 LSFEngineSetLauncher,
1137 LSFEngineSetLauncher,
1139 ]
1138 ]
1140 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1139 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1141 + pbs_launchers + sge_launchers + lsf_launchers
1140 + pbs_launchers + sge_launchers + lsf_launchers
1142
1141
@@ -1,115 +1,114 b''
1 #!/usr/bin/env python
2 """
1 """
3 A simple logger object that consolidates messages incoming from ipcluster processes.
2 A simple logger object that consolidates messages incoming from ipcluster processes.
4
3
5 Authors:
4 Authors:
6
5
7 * MinRK
6 * MinRK
8
7
9 """
8 """
10
9
11 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
12 # Copyright (C) 2011 The IPython Development Team
11 # Copyright (C) 2011 The IPython Development Team
13 #
12 #
14 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17
16
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19 # Imports
18 # Imports
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21
20
22
21
23 import logging
22 import logging
24 import sys
23 import sys
25
24
26 import zmq
25 import zmq
27 from zmq.eventloop import ioloop, zmqstream
26 from zmq.eventloop import ioloop, zmqstream
28
27
29 from IPython.config.configurable import LoggingConfigurable
28 from IPython.config.configurable import LoggingConfigurable
30 from IPython.utils.traitlets import Int, Unicode, Instance, List
29 from IPython.utils.traitlets import Int, Unicode, Instance, List
31
30
32 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
33 # Classes
32 # Classes
34 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
35
34
36
35
37 class LogWatcher(LoggingConfigurable):
36 class LogWatcher(LoggingConfigurable):
38 """A simple class that receives messages on a SUB socket, as published
37 """A simple class that receives messages on a SUB socket, as published
39 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
38 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
40
39
41 This can subscribe to multiple topics, but defaults to all topics.
40 This can subscribe to multiple topics, but defaults to all topics.
42 """
41 """
43
42
44 # configurables
43 # configurables
45 topics = List([''], config=True,
44 topics = List([''], config=True,
46 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
45 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
47 url = Unicode('tcp://127.0.0.1:20202', config=True,
46 url = Unicode('tcp://127.0.0.1:20202', config=True,
48 help="ZMQ url on which to listen for log messages")
47 help="ZMQ url on which to listen for log messages")
49
48
50 # internals
49 # internals
51 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
50 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
52
51
53 context = Instance(zmq.Context)
52 context = Instance(zmq.Context)
54 def _context_default(self):
53 def _context_default(self):
55 return zmq.Context.instance()
54 return zmq.Context.instance()
56
55
57 loop = Instance(zmq.eventloop.ioloop.IOLoop)
56 loop = Instance(zmq.eventloop.ioloop.IOLoop)
58 def _loop_default(self):
57 def _loop_default(self):
59 return ioloop.IOLoop.instance()
58 return ioloop.IOLoop.instance()
60
59
61 def __init__(self, **kwargs):
60 def __init__(self, **kwargs):
62 super(LogWatcher, self).__init__(**kwargs)
61 super(LogWatcher, self).__init__(**kwargs)
63 s = self.context.socket(zmq.SUB)
62 s = self.context.socket(zmq.SUB)
64 s.bind(self.url)
63 s.bind(self.url)
65 self.stream = zmqstream.ZMQStream(s, self.loop)
64 self.stream = zmqstream.ZMQStream(s, self.loop)
66 self.subscribe()
65 self.subscribe()
67 self.on_trait_change(self.subscribe, 'topics')
66 self.on_trait_change(self.subscribe, 'topics')
68
67
69 def start(self):
68 def start(self):
70 self.stream.on_recv(self.log_message)
69 self.stream.on_recv(self.log_message)
71
70
72 def stop(self):
71 def stop(self):
73 self.stream.stop_on_recv()
72 self.stream.stop_on_recv()
74
73
75 def subscribe(self):
74 def subscribe(self):
76 """Update our SUB socket's subscriptions."""
75 """Update our SUB socket's subscriptions."""
77 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
76 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
78 if '' in self.topics:
77 if '' in self.topics:
79 self.log.debug("Subscribing to: everything")
78 self.log.debug("Subscribing to: everything")
80 self.stream.setsockopt(zmq.SUBSCRIBE, '')
79 self.stream.setsockopt(zmq.SUBSCRIBE, '')
81 else:
80 else:
82 for topic in self.topics:
81 for topic in self.topics:
83 self.log.debug("Subscribing to: %r"%(topic))
82 self.log.debug("Subscribing to: %r"%(topic))
84 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
83 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
85
84
86 def _extract_level(self, topic_str):
85 def _extract_level(self, topic_str):
87 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
86 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
88 topics = topic_str.split('.')
87 topics = topic_str.split('.')
89 for idx,t in enumerate(topics):
88 for idx,t in enumerate(topics):
90 level = getattr(logging, t, None)
89 level = getattr(logging, t, None)
91 if level is not None:
90 if level is not None:
92 break
91 break
93
92
94 if level is None:
93 if level is None:
95 level = logging.INFO
94 level = logging.INFO
96 else:
95 else:
97 topics.pop(idx)
96 topics.pop(idx)
98
97
99 return level, '.'.join(topics)
98 return level, '.'.join(topics)
100
99
101
100
102 def log_message(self, raw):
101 def log_message(self, raw):
103 """receive and parse a message, then log it."""
102 """receive and parse a message, then log it."""
104 if len(raw) != 2 or '.' not in raw[0]:
103 if len(raw) != 2 or '.' not in raw[0]:
105 self.log.error("Invalid log message: %s"%raw)
104 self.log.error("Invalid log message: %s"%raw)
106 return
105 return
107 else:
106 else:
108 topic, msg = raw
107 topic, msg = raw
109 # don't newline, since log messages always newline:
108 # don't newline, since log messages always newline:
110 topic,level_name = topic.rsplit('.',1)
109 topic,level_name = topic.rsplit('.',1)
111 level,topic = self._extract_level(topic)
110 level,topic = self._extract_level(topic)
112 if msg[-1] == '\n':
111 if msg[-1] == '\n':
113 msg = msg[:-1]
112 msg = msg[:-1]
114 self.log.log(level, "[%s] %s" % (topic, msg))
113 self.log.log(level, "[%s] %s" % (topic, msg))
115
114
@@ -1,73 +1,72 b''
1 #!/usr/bin/env python
2 """Utility for forwarding file read events over a zmq socket.
1 """Utility for forwarding file read events over a zmq socket.
3
2
4 This is necessary because select on Windows only supports sockets, not FDs.
3 This is necessary because select on Windows only supports sockets, not FDs.
5
4
6 Authors:
5 Authors:
7
6
8 * MinRK
7 * MinRK
9
8
10 """
9 """
11
10
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2011 The IPython Development Team
12 # Copyright (C) 2011 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 import uuid
22 import uuid
24 import zmq
23 import zmq
25
24
26 from threading import Thread
25 from threading import Thread
27
26
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29 # Code
28 # Code
30 #-----------------------------------------------------------------------------
29 #-----------------------------------------------------------------------------
31
30
32 class ForwarderThread(Thread):
31 class ForwarderThread(Thread):
33 def __init__(self, sock, fd):
32 def __init__(self, sock, fd):
34 Thread.__init__(self)
33 Thread.__init__(self)
35 self.daemon=True
34 self.daemon=True
36 self.sock = sock
35 self.sock = sock
37 self.fd = fd
36 self.fd = fd
38
37
39 def run(self):
38 def run(self):
40 """Loop through lines in self.fd, and send them over self.sock."""
39 """Loop through lines in self.fd, and send them over self.sock."""
41 line = self.fd.readline()
40 line = self.fd.readline()
42 # allow for files opened in unicode mode
41 # allow for files opened in unicode mode
43 if isinstance(line, unicode):
42 if isinstance(line, unicode):
44 send = self.sock.send_unicode
43 send = self.sock.send_unicode
45 else:
44 else:
46 send = self.sock.send
45 send = self.sock.send
47 while line:
46 while line:
48 send(line)
47 send(line)
49 line = self.fd.readline()
48 line = self.fd.readline()
50 # line == '' means EOF
49 # line == '' means EOF
51 self.fd.close()
50 self.fd.close()
52 self.sock.close()
51 self.sock.close()
53
52
54 def forward_read_events(fd, context=None):
53 def forward_read_events(fd, context=None):
55 """Forward read events from an FD over a socket.
54 """Forward read events from an FD over a socket.
56
55
57 This method wraps a file in a socket pair, so it can
56 This method wraps a file in a socket pair, so it can
58 be polled for read events by select (specifically zmq.eventloop.ioloop)
57 be polled for read events by select (specifically zmq.eventloop.ioloop)
59 """
58 """
60 if context is None:
59 if context is None:
61 context = zmq.Context.instance()
60 context = zmq.Context.instance()
62 push = context.socket(zmq.PUSH)
61 push = context.socket(zmq.PUSH)
63 push.setsockopt(zmq.LINGER, -1)
62 push.setsockopt(zmq.LINGER, -1)
64 pull = context.socket(zmq.PULL)
63 pull = context.socket(zmq.PULL)
65 addr='inproc://%s'%uuid.uuid4()
64 addr='inproc://%s'%uuid.uuid4()
66 push.bind(addr)
65 push.bind(addr)
67 pull.connect(addr)
66 pull.connect(addr)
68 forwarder = ForwarderThread(push, fd)
67 forwarder = ForwarderThread(push, fd)
69 forwarder.start()
68 forwarder.start()
70 return pull
69 return pull
71
70
72
71
73 __all__ = ['forward_read_events']
72 __all__ = ['forward_read_events']
@@ -1,320 +1,319 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Job and task components for writing .xml files that the Windows HPC Server
3 Job and task components for writing .xml files that the Windows HPC Server
5 2008 can use to start jobs.
4 2008 can use to start jobs.
6
5
7 Authors:
6 Authors:
8
7
9 * Brian Granger
8 * Brian Granger
10 * MinRK
9 * MinRK
11
10
12 """
11 """
13
12
14 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
15 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
16 #
15 #
17 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
18 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20
19
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22 # Imports
21 # Imports
23 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
24
23
25 import os
24 import os
26 import re
25 import re
27 import uuid
26 import uuid
28
27
29 from xml.etree import ElementTree as ET
28 from xml.etree import ElementTree as ET
30
29
31 from IPython.config.configurable import Configurable
30 from IPython.config.configurable import Configurable
32 from IPython.utils.traitlets import (
31 from IPython.utils.traitlets import (
33 Unicode, Int, List, Instance,
32 Unicode, Int, List, Instance,
34 Enum, Bool
33 Enum, Bool
35 )
34 )
36
35
37 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
38 # Job and Task classes
37 # Job and Task classes
39 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
40
39
41
40
42 def as_str(value):
41 def as_str(value):
43 if isinstance(value, str):
42 if isinstance(value, str):
44 return value
43 return value
45 elif isinstance(value, bool):
44 elif isinstance(value, bool):
46 if value:
45 if value:
47 return 'true'
46 return 'true'
48 else:
47 else:
49 return 'false'
48 return 'false'
50 elif isinstance(value, (int, float)):
49 elif isinstance(value, (int, float)):
51 return repr(value)
50 return repr(value)
52 else:
51 else:
53 return value
52 return value
54
53
55
54
56 def indent(elem, level=0):
55 def indent(elem, level=0):
57 i = "\n" + level*" "
56 i = "\n" + level*" "
58 if len(elem):
57 if len(elem):
59 if not elem.text or not elem.text.strip():
58 if not elem.text or not elem.text.strip():
60 elem.text = i + " "
59 elem.text = i + " "
61 if not elem.tail or not elem.tail.strip():
60 if not elem.tail or not elem.tail.strip():
62 elem.tail = i
61 elem.tail = i
63 for elem in elem:
62 for elem in elem:
64 indent(elem, level+1)
63 indent(elem, level+1)
65 if not elem.tail or not elem.tail.strip():
64 if not elem.tail or not elem.tail.strip():
66 elem.tail = i
65 elem.tail = i
67 else:
66 else:
68 if level and (not elem.tail or not elem.tail.strip()):
67 if level and (not elem.tail or not elem.tail.strip()):
69 elem.tail = i
68 elem.tail = i
70
69
71
70
72 def find_username():
71 def find_username():
73 domain = os.environ.get('USERDOMAIN')
72 domain = os.environ.get('USERDOMAIN')
74 username = os.environ.get('USERNAME','')
73 username = os.environ.get('USERNAME','')
75 if domain is None:
74 if domain is None:
76 return username
75 return username
77 else:
76 else:
78 return '%s\\%s' % (domain, username)
77 return '%s\\%s' % (domain, username)
79
78
80
79
81 class WinHPCJob(Configurable):
80 class WinHPCJob(Configurable):
82
81
83 job_id = Unicode('')
82 job_id = Unicode('')
84 job_name = Unicode('MyJob', config=True)
83 job_name = Unicode('MyJob', config=True)
85 min_cores = Int(1, config=True)
84 min_cores = Int(1, config=True)
86 max_cores = Int(1, config=True)
85 max_cores = Int(1, config=True)
87 min_sockets = Int(1, config=True)
86 min_sockets = Int(1, config=True)
88 max_sockets = Int(1, config=True)
87 max_sockets = Int(1, config=True)
89 min_nodes = Int(1, config=True)
88 min_nodes = Int(1, config=True)
90 max_nodes = Int(1, config=True)
89 max_nodes = Int(1, config=True)
91 unit_type = Unicode("Core", config=True)
90 unit_type = Unicode("Core", config=True)
92 auto_calculate_min = Bool(True, config=True)
91 auto_calculate_min = Bool(True, config=True)
93 auto_calculate_max = Bool(True, config=True)
92 auto_calculate_max = Bool(True, config=True)
94 run_until_canceled = Bool(False, config=True)
93 run_until_canceled = Bool(False, config=True)
95 is_exclusive = Bool(False, config=True)
94 is_exclusive = Bool(False, config=True)
96 username = Unicode(find_username(), config=True)
95 username = Unicode(find_username(), config=True)
97 job_type = Unicode('Batch', config=True)
96 job_type = Unicode('Batch', config=True)
98 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
97 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
99 default_value='Highest', config=True)
98 default_value='Highest', config=True)
100 requested_nodes = Unicode('', config=True)
99 requested_nodes = Unicode('', config=True)
101 project = Unicode('IPython', config=True)
100 project = Unicode('IPython', config=True)
102 xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
101 xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
103 version = Unicode("2.000")
102 version = Unicode("2.000")
104 tasks = List([])
103 tasks = List([])
105
104
106 @property
105 @property
107 def owner(self):
106 def owner(self):
108 return self.username
107 return self.username
109
108
110 def _write_attr(self, root, attr, key):
109 def _write_attr(self, root, attr, key):
111 s = as_str(getattr(self, attr, ''))
110 s = as_str(getattr(self, attr, ''))
112 if s:
111 if s:
113 root.set(key, s)
112 root.set(key, s)
114
113
115 def as_element(self):
114 def as_element(self):
116 # We have to add _A_ type things to get the right order than
115 # We have to add _A_ type things to get the right order than
117 # the MSFT XML parser expects.
116 # the MSFT XML parser expects.
118 root = ET.Element('Job')
117 root = ET.Element('Job')
119 self._write_attr(root, 'version', '_A_Version')
118 self._write_attr(root, 'version', '_A_Version')
120 self._write_attr(root, 'job_name', '_B_Name')
119 self._write_attr(root, 'job_name', '_B_Name')
121 self._write_attr(root, 'unit_type', '_C_UnitType')
120 self._write_attr(root, 'unit_type', '_C_UnitType')
122 self._write_attr(root, 'min_cores', '_D_MinCores')
121 self._write_attr(root, 'min_cores', '_D_MinCores')
123 self._write_attr(root, 'max_cores', '_E_MaxCores')
122 self._write_attr(root, 'max_cores', '_E_MaxCores')
124 self._write_attr(root, 'min_sockets', '_F_MinSockets')
123 self._write_attr(root, 'min_sockets', '_F_MinSockets')
125 self._write_attr(root, 'max_sockets', '_G_MaxSockets')
124 self._write_attr(root, 'max_sockets', '_G_MaxSockets')
126 self._write_attr(root, 'min_nodes', '_H_MinNodes')
125 self._write_attr(root, 'min_nodes', '_H_MinNodes')
127 self._write_attr(root, 'max_nodes', '_I_MaxNodes')
126 self._write_attr(root, 'max_nodes', '_I_MaxNodes')
128 self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
127 self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
129 self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
128 self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
130 self._write_attr(root, 'username', '_L_UserName')
129 self._write_attr(root, 'username', '_L_UserName')
131 self._write_attr(root, 'job_type', '_M_JobType')
130 self._write_attr(root, 'job_type', '_M_JobType')
132 self._write_attr(root, 'priority', '_N_Priority')
131 self._write_attr(root, 'priority', '_N_Priority')
133 self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
132 self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
134 self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
133 self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
135 self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
134 self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
136 self._write_attr(root, 'project', '_R_Project')
135 self._write_attr(root, 'project', '_R_Project')
137 self._write_attr(root, 'owner', '_S_Owner')
136 self._write_attr(root, 'owner', '_S_Owner')
138 self._write_attr(root, 'xmlns', '_T_xmlns')
137 self._write_attr(root, 'xmlns', '_T_xmlns')
139 dependencies = ET.SubElement(root, "Dependencies")
138 dependencies = ET.SubElement(root, "Dependencies")
140 etasks = ET.SubElement(root, "Tasks")
139 etasks = ET.SubElement(root, "Tasks")
141 for t in self.tasks:
140 for t in self.tasks:
142 etasks.append(t.as_element())
141 etasks.append(t.as_element())
143 return root
142 return root
144
143
145 def tostring(self):
144 def tostring(self):
146 """Return the string representation of the job description XML."""
145 """Return the string representation of the job description XML."""
147 root = self.as_element()
146 root = self.as_element()
148 indent(root)
147 indent(root)
149 txt = ET.tostring(root, encoding="utf-8")
148 txt = ET.tostring(root, encoding="utf-8")
150 # Now remove the tokens used to order the attributes.
149 # Now remove the tokens used to order the attributes.
151 txt = re.sub(r'_[A-Z]_','',txt)
150 txt = re.sub(r'_[A-Z]_','',txt)
152 txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
151 txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
153 return txt
152 return txt
154
153
155 def write(self, filename):
154 def write(self, filename):
156 """Write the XML job description to a file."""
155 """Write the XML job description to a file."""
157 txt = self.tostring()
156 txt = self.tostring()
158 with open(filename, 'w') as f:
157 with open(filename, 'w') as f:
159 f.write(txt)
158 f.write(txt)
160
159
161 def add_task(self, task):
160 def add_task(self, task):
162 """Add a task to the job.
161 """Add a task to the job.
163
162
164 Parameters
163 Parameters
165 ----------
164 ----------
166 task : :class:`WinHPCTask`
165 task : :class:`WinHPCTask`
167 The task object to add.
166 The task object to add.
168 """
167 """
169 self.tasks.append(task)
168 self.tasks.append(task)
170
169
171
170
172 class WinHPCTask(Configurable):
171 class WinHPCTask(Configurable):
173
172
174 task_id = Unicode('')
173 task_id = Unicode('')
175 task_name = Unicode('')
174 task_name = Unicode('')
176 version = Unicode("2.000")
175 version = Unicode("2.000")
177 min_cores = Int(1, config=True)
176 min_cores = Int(1, config=True)
178 max_cores = Int(1, config=True)
177 max_cores = Int(1, config=True)
179 min_sockets = Int(1, config=True)
178 min_sockets = Int(1, config=True)
180 max_sockets = Int(1, config=True)
179 max_sockets = Int(1, config=True)
181 min_nodes = Int(1, config=True)
180 min_nodes = Int(1, config=True)
182 max_nodes = Int(1, config=True)
181 max_nodes = Int(1, config=True)
183 unit_type = Unicode("Core", config=True)
182 unit_type = Unicode("Core", config=True)
184 command_line = Unicode('', config=True)
183 command_line = Unicode('', config=True)
185 work_directory = Unicode('', config=True)
184 work_directory = Unicode('', config=True)
186 is_rerunnaable = Bool(True, config=True)
185 is_rerunnaable = Bool(True, config=True)
187 std_out_file_path = Unicode('', config=True)
186 std_out_file_path = Unicode('', config=True)
188 std_err_file_path = Unicode('', config=True)
187 std_err_file_path = Unicode('', config=True)
189 is_parametric = Bool(False, config=True)
188 is_parametric = Bool(False, config=True)
190 environment_variables = Instance(dict, args=(), config=True)
189 environment_variables = Instance(dict, args=(), config=True)
191
190
192 def _write_attr(self, root, attr, key):
191 def _write_attr(self, root, attr, key):
193 s = as_str(getattr(self, attr, ''))
192 s = as_str(getattr(self, attr, ''))
194 if s:
193 if s:
195 root.set(key, s)
194 root.set(key, s)
196
195
197 def as_element(self):
196 def as_element(self):
198 root = ET.Element('Task')
197 root = ET.Element('Task')
199 self._write_attr(root, 'version', '_A_Version')
198 self._write_attr(root, 'version', '_A_Version')
200 self._write_attr(root, 'task_name', '_B_Name')
199 self._write_attr(root, 'task_name', '_B_Name')
201 self._write_attr(root, 'min_cores', '_C_MinCores')
200 self._write_attr(root, 'min_cores', '_C_MinCores')
202 self._write_attr(root, 'max_cores', '_D_MaxCores')
201 self._write_attr(root, 'max_cores', '_D_MaxCores')
203 self._write_attr(root, 'min_sockets', '_E_MinSockets')
202 self._write_attr(root, 'min_sockets', '_E_MinSockets')
204 self._write_attr(root, 'max_sockets', '_F_MaxSockets')
203 self._write_attr(root, 'max_sockets', '_F_MaxSockets')
205 self._write_attr(root, 'min_nodes', '_G_MinNodes')
204 self._write_attr(root, 'min_nodes', '_G_MinNodes')
206 self._write_attr(root, 'max_nodes', '_H_MaxNodes')
205 self._write_attr(root, 'max_nodes', '_H_MaxNodes')
207 self._write_attr(root, 'command_line', '_I_CommandLine')
206 self._write_attr(root, 'command_line', '_I_CommandLine')
208 self._write_attr(root, 'work_directory', '_J_WorkDirectory')
207 self._write_attr(root, 'work_directory', '_J_WorkDirectory')
209 self._write_attr(root, 'is_rerunnaable', '_K_IsRerunnable')
208 self._write_attr(root, 'is_rerunnaable', '_K_IsRerunnable')
210 self._write_attr(root, 'std_out_file_path', '_L_StdOutFilePath')
209 self._write_attr(root, 'std_out_file_path', '_L_StdOutFilePath')
211 self._write_attr(root, 'std_err_file_path', '_M_StdErrFilePath')
210 self._write_attr(root, 'std_err_file_path', '_M_StdErrFilePath')
212 self._write_attr(root, 'is_parametric', '_N_IsParametric')
211 self._write_attr(root, 'is_parametric', '_N_IsParametric')
213 self._write_attr(root, 'unit_type', '_O_UnitType')
212 self._write_attr(root, 'unit_type', '_O_UnitType')
214 root.append(self.get_env_vars())
213 root.append(self.get_env_vars())
215 return root
214 return root
216
215
217 def get_env_vars(self):
216 def get_env_vars(self):
218 env_vars = ET.Element('EnvironmentVariables')
217 env_vars = ET.Element('EnvironmentVariables')
219 for k, v in self.environment_variables.iteritems():
218 for k, v in self.environment_variables.iteritems():
220 variable = ET.SubElement(env_vars, "Variable")
219 variable = ET.SubElement(env_vars, "Variable")
221 name = ET.SubElement(variable, "Name")
220 name = ET.SubElement(variable, "Name")
222 name.text = k
221 name.text = k
223 value = ET.SubElement(variable, "Value")
222 value = ET.SubElement(variable, "Value")
224 value.text = v
223 value.text = v
225 return env_vars
224 return env_vars
226
225
227
226
228
227
229 # By declaring these, we can configure the controller and engine separately!
228 # By declaring these, we can configure the controller and engine separately!
230
229
231 class IPControllerJob(WinHPCJob):
230 class IPControllerJob(WinHPCJob):
232 job_name = Unicode('IPController', config=False)
231 job_name = Unicode('IPController', config=False)
233 is_exclusive = Bool(False, config=True)
232 is_exclusive = Bool(False, config=True)
234 username = Unicode(find_username(), config=True)
233 username = Unicode(find_username(), config=True)
235 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
234 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
236 default_value='Highest', config=True)
235 default_value='Highest', config=True)
237 requested_nodes = Unicode('', config=True)
236 requested_nodes = Unicode('', config=True)
238 project = Unicode('IPython', config=True)
237 project = Unicode('IPython', config=True)
239
238
240
239
241 class IPEngineSetJob(WinHPCJob):
240 class IPEngineSetJob(WinHPCJob):
242 job_name = Unicode('IPEngineSet', config=False)
241 job_name = Unicode('IPEngineSet', config=False)
243 is_exclusive = Bool(False, config=True)
242 is_exclusive = Bool(False, config=True)
244 username = Unicode(find_username(), config=True)
243 username = Unicode(find_username(), config=True)
245 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
244 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
246 default_value='Highest', config=True)
245 default_value='Highest', config=True)
247 requested_nodes = Unicode('', config=True)
246 requested_nodes = Unicode('', config=True)
248 project = Unicode('IPython', config=True)
247 project = Unicode('IPython', config=True)
249
248
250
249
251 class IPControllerTask(WinHPCTask):
250 class IPControllerTask(WinHPCTask):
252
251
253 task_name = Unicode('IPController', config=True)
252 task_name = Unicode('IPController', config=True)
254 controller_cmd = List(['ipcontroller.exe'], config=True)
253 controller_cmd = List(['ipcontroller.exe'], config=True)
255 controller_args = List(['--log-to-file', '--log-level=40'], config=True)
254 controller_args = List(['--log-to-file', '--log-level=40'], config=True)
256 # I don't want these to be configurable
255 # I don't want these to be configurable
257 std_out_file_path = Unicode('', config=False)
256 std_out_file_path = Unicode('', config=False)
258 std_err_file_path = Unicode('', config=False)
257 std_err_file_path = Unicode('', config=False)
259 min_cores = Int(1, config=False)
258 min_cores = Int(1, config=False)
260 max_cores = Int(1, config=False)
259 max_cores = Int(1, config=False)
261 min_sockets = Int(1, config=False)
260 min_sockets = Int(1, config=False)
262 max_sockets = Int(1, config=False)
261 max_sockets = Int(1, config=False)
263 min_nodes = Int(1, config=False)
262 min_nodes = Int(1, config=False)
264 max_nodes = Int(1, config=False)
263 max_nodes = Int(1, config=False)
265 unit_type = Unicode("Core", config=False)
264 unit_type = Unicode("Core", config=False)
266 work_directory = Unicode('', config=False)
265 work_directory = Unicode('', config=False)
267
266
268 def __init__(self, config=None):
267 def __init__(self, config=None):
269 super(IPControllerTask, self).__init__(config=config)
268 super(IPControllerTask, self).__init__(config=config)
270 the_uuid = uuid.uuid1()
269 the_uuid = uuid.uuid1()
271 self.std_out_file_path = os.path.join('log','ipcontroller-%s.out' % the_uuid)
270 self.std_out_file_path = os.path.join('log','ipcontroller-%s.out' % the_uuid)
272 self.std_err_file_path = os.path.join('log','ipcontroller-%s.err' % the_uuid)
271 self.std_err_file_path = os.path.join('log','ipcontroller-%s.err' % the_uuid)
273
272
274 @property
273 @property
275 def command_line(self):
274 def command_line(self):
276 return ' '.join(self.controller_cmd + self.controller_args)
275 return ' '.join(self.controller_cmd + self.controller_args)
277
276
278
277
279 class IPEngineTask(WinHPCTask):
278 class IPEngineTask(WinHPCTask):
280
279
281 task_name = Unicode('IPEngine', config=True)
280 task_name = Unicode('IPEngine', config=True)
282 engine_cmd = List(['ipengine.exe'], config=True)
281 engine_cmd = List(['ipengine.exe'], config=True)
283 engine_args = List(['--log-to-file', '--log-level=40'], config=True)
282 engine_args = List(['--log-to-file', '--log-level=40'], config=True)
284 # I don't want these to be configurable
283 # I don't want these to be configurable
285 std_out_file_path = Unicode('', config=False)
284 std_out_file_path = Unicode('', config=False)
286 std_err_file_path = Unicode('', config=False)
285 std_err_file_path = Unicode('', config=False)
287 min_cores = Int(1, config=False)
286 min_cores = Int(1, config=False)
288 max_cores = Int(1, config=False)
287 max_cores = Int(1, config=False)
289 min_sockets = Int(1, config=False)
288 min_sockets = Int(1, config=False)
290 max_sockets = Int(1, config=False)
289 max_sockets = Int(1, config=False)
291 min_nodes = Int(1, config=False)
290 min_nodes = Int(1, config=False)
292 max_nodes = Int(1, config=False)
291 max_nodes = Int(1, config=False)
293 unit_type = Unicode("Core", config=False)
292 unit_type = Unicode("Core", config=False)
294 work_directory = Unicode('', config=False)
293 work_directory = Unicode('', config=False)
295
294
296 def __init__(self, config=None):
295 def __init__(self, config=None):
297 super(IPEngineTask,self).__init__(config=config)
296 super(IPEngineTask,self).__init__(config=config)
298 the_uuid = uuid.uuid1()
297 the_uuid = uuid.uuid1()
299 self.std_out_file_path = os.path.join('log','ipengine-%s.out' % the_uuid)
298 self.std_out_file_path = os.path.join('log','ipengine-%s.out' % the_uuid)
300 self.std_err_file_path = os.path.join('log','ipengine-%s.err' % the_uuid)
299 self.std_err_file_path = os.path.join('log','ipengine-%s.err' % the_uuid)
301
300
302 @property
301 @property
303 def command_line(self):
302 def command_line(self):
304 return ' '.join(self.engine_cmd + self.engine_args)
303 return ' '.join(self.engine_cmd + self.engine_args)
305
304
306
305
307 # j = WinHPCJob(None)
306 # j = WinHPCJob(None)
308 # j.job_name = 'IPCluster'
307 # j.job_name = 'IPCluster'
309 # j.username = 'GNET\\bgranger'
308 # j.username = 'GNET\\bgranger'
310 # j.requested_nodes = 'GREEN'
309 # j.requested_nodes = 'GREEN'
311 #
310 #
312 # t = WinHPCTask(None)
311 # t = WinHPCTask(None)
313 # t.task_name = 'Controller'
312 # t.task_name = 'Controller'
314 # t.command_line = r"\\blue\domainusers$\bgranger\Python\Python25\Scripts\ipcontroller.exe --log-to-file -p default --log-level 10"
313 # t.command_line = r"\\blue\domainusers$\bgranger\Python\Python25\Scripts\ipcontroller.exe --log-to-file -p default --log-level 10"
315 # t.work_directory = r"\\blue\domainusers$\bgranger\.ipython\cluster_default"
314 # t.work_directory = r"\\blue\domainusers$\bgranger\.ipython\cluster_default"
316 # t.std_out_file_path = 'controller-out.txt'
315 # t.std_out_file_path = 'controller-out.txt'
317 # t.std_err_file_path = 'controller-err.txt'
316 # t.std_err_file_path = 'controller-err.txt'
318 # t.environment_variables['PYTHONPATH'] = r"\\blue\domainusers$\bgranger\Python\Python25\Lib\site-packages"
317 # t.environment_variables['PYTHONPATH'] = r"\\blue\domainusers$\bgranger\Python\Python25\Lib\site-packages"
319 # j.add_task(t)
318 # j.add_task(t)
320
319
1 NO CONTENT: modified file chmod 100644 => 100755
NO CONTENT: modified file chmod 100644 => 100755
@@ -1,1291 +1,1290 b''
1 #!/usr/bin/env python
2 """The IPython Controller Hub with 0MQ
1 """The IPython Controller Hub with 0MQ
3 This is the master object that handles connections from engines and clients,
2 This is the master object that handles connections from engines and clients,
4 and monitors traffic through the various queues.
3 and monitors traffic through the various queues.
5
4
6 Authors:
5 Authors:
7
6
8 * Min RK
7 * Min RK
9 """
8 """
10 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010 The IPython Development Team
10 # Copyright (C) 2010 The IPython Development Team
12 #
11 #
13 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16
15
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Imports
17 # Imports
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 from __future__ import print_function
19 from __future__ import print_function
21
20
22 import sys
21 import sys
23 import time
22 import time
24 from datetime import datetime
23 from datetime import datetime
25
24
26 import zmq
25 import zmq
27 from zmq.eventloop import ioloop
26 from zmq.eventloop import ioloop
28 from zmq.eventloop.zmqstream import ZMQStream
27 from zmq.eventloop.zmqstream import ZMQStream
29
28
30 # internal:
29 # internal:
31 from IPython.utils.importstring import import_item
30 from IPython.utils.importstring import import_item
32 from IPython.utils.traitlets import (
31 from IPython.utils.traitlets import (
33 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
32 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 )
33 )
35
34
36 from IPython.parallel import error, util
35 from IPython.parallel import error, util
37 from IPython.parallel.factory import RegistrationFactory
36 from IPython.parallel.factory import RegistrationFactory
38
37
39 from IPython.zmq.session import SessionFactory
38 from IPython.zmq.session import SessionFactory
40
39
41 from .heartmonitor import HeartMonitor
40 from .heartmonitor import HeartMonitor
42
41
43 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
44 # Code
43 # Code
45 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
46
45
47 def _passer(*args, **kwargs):
46 def _passer(*args, **kwargs):
48 return
47 return
49
48
50 def _printer(*args, **kwargs):
49 def _printer(*args, **kwargs):
51 print (args)
50 print (args)
52 print (kwargs)
51 print (kwargs)
53
52
54 def empty_record():
53 def empty_record():
55 """Return an empty dict with all record keys."""
54 """Return an empty dict with all record keys."""
56 return {
55 return {
57 'msg_id' : None,
56 'msg_id' : None,
58 'header' : None,
57 'header' : None,
59 'content': None,
58 'content': None,
60 'buffers': None,
59 'buffers': None,
61 'submitted': None,
60 'submitted': None,
62 'client_uuid' : None,
61 'client_uuid' : None,
63 'engine_uuid' : None,
62 'engine_uuid' : None,
64 'started': None,
63 'started': None,
65 'completed': None,
64 'completed': None,
66 'resubmitted': None,
65 'resubmitted': None,
67 'result_header' : None,
66 'result_header' : None,
68 'result_content' : None,
67 'result_content' : None,
69 'result_buffers' : None,
68 'result_buffers' : None,
70 'queue' : None,
69 'queue' : None,
71 'pyin' : None,
70 'pyin' : None,
72 'pyout': None,
71 'pyout': None,
73 'pyerr': None,
72 'pyerr': None,
74 'stdout': '',
73 'stdout': '',
75 'stderr': '',
74 'stderr': '',
76 }
75 }
77
76
78 def init_record(msg):
77 def init_record(msg):
79 """Initialize a TaskRecord based on a request."""
78 """Initialize a TaskRecord based on a request."""
80 header = msg['header']
79 header = msg['header']
81 return {
80 return {
82 'msg_id' : header['msg_id'],
81 'msg_id' : header['msg_id'],
83 'header' : header,
82 'header' : header,
84 'content': msg['content'],
83 'content': msg['content'],
85 'buffers': msg['buffers'],
84 'buffers': msg['buffers'],
86 'submitted': header['date'],
85 'submitted': header['date'],
87 'client_uuid' : None,
86 'client_uuid' : None,
88 'engine_uuid' : None,
87 'engine_uuid' : None,
89 'started': None,
88 'started': None,
90 'completed': None,
89 'completed': None,
91 'resubmitted': None,
90 'resubmitted': None,
92 'result_header' : None,
91 'result_header' : None,
93 'result_content' : None,
92 'result_content' : None,
94 'result_buffers' : None,
93 'result_buffers' : None,
95 'queue' : None,
94 'queue' : None,
96 'pyin' : None,
95 'pyin' : None,
97 'pyout': None,
96 'pyout': None,
98 'pyerr': None,
97 'pyerr': None,
99 'stdout': '',
98 'stdout': '',
100 'stderr': '',
99 'stderr': '',
101 }
100 }
102
101
103
102
104 class EngineConnector(HasTraits):
103 class EngineConnector(HasTraits):
105 """A simple object for accessing the various zmq connections of an object.
104 """A simple object for accessing the various zmq connections of an object.
106 Attributes are:
105 Attributes are:
107 id (int): engine ID
106 id (int): engine ID
108 uuid (str): uuid (unused?)
107 uuid (str): uuid (unused?)
109 queue (str): identity of queue's XREQ socket
108 queue (str): identity of queue's XREQ socket
110 registration (str): identity of registration XREQ socket
109 registration (str): identity of registration XREQ socket
111 heartbeat (str): identity of heartbeat XREQ socket
110 heartbeat (str): identity of heartbeat XREQ socket
112 """
111 """
113 id=Int(0)
112 id=Int(0)
114 queue=CBytes()
113 queue=CBytes()
115 control=CBytes()
114 control=CBytes()
116 registration=CBytes()
115 registration=CBytes()
117 heartbeat=CBytes()
116 heartbeat=CBytes()
118 pending=Set()
117 pending=Set()
119
118
120 class HubFactory(RegistrationFactory):
119 class HubFactory(RegistrationFactory):
121 """The Configurable for setting up a Hub."""
120 """The Configurable for setting up a Hub."""
122
121
123 # port-pairs for monitoredqueues:
122 # port-pairs for monitoredqueues:
124 hb = Tuple(Int,Int,config=True,
123 hb = Tuple(Int,Int,config=True,
125 help="""XREQ/SUB Port pair for Engine heartbeats""")
124 help="""XREQ/SUB Port pair for Engine heartbeats""")
126 def _hb_default(self):
125 def _hb_default(self):
127 return tuple(util.select_random_ports(2))
126 return tuple(util.select_random_ports(2))
128
127
129 mux = Tuple(Int,Int,config=True,
128 mux = Tuple(Int,Int,config=True,
130 help="""Engine/Client Port pair for MUX queue""")
129 help="""Engine/Client Port pair for MUX queue""")
131
130
132 def _mux_default(self):
131 def _mux_default(self):
133 return tuple(util.select_random_ports(2))
132 return tuple(util.select_random_ports(2))
134
133
135 task = Tuple(Int,Int,config=True,
134 task = Tuple(Int,Int,config=True,
136 help="""Engine/Client Port pair for Task queue""")
135 help="""Engine/Client Port pair for Task queue""")
137 def _task_default(self):
136 def _task_default(self):
138 return tuple(util.select_random_ports(2))
137 return tuple(util.select_random_ports(2))
139
138
140 control = Tuple(Int,Int,config=True,
139 control = Tuple(Int,Int,config=True,
141 help="""Engine/Client Port pair for Control queue""")
140 help="""Engine/Client Port pair for Control queue""")
142
141
143 def _control_default(self):
142 def _control_default(self):
144 return tuple(util.select_random_ports(2))
143 return tuple(util.select_random_ports(2))
145
144
146 iopub = Tuple(Int,Int,config=True,
145 iopub = Tuple(Int,Int,config=True,
147 help="""Engine/Client Port pair for IOPub relay""")
146 help="""Engine/Client Port pair for IOPub relay""")
148
147
149 def _iopub_default(self):
148 def _iopub_default(self):
150 return tuple(util.select_random_ports(2))
149 return tuple(util.select_random_ports(2))
151
150
152 # single ports:
151 # single ports:
153 mon_port = Int(config=True,
152 mon_port = Int(config=True,
154 help="""Monitor (SUB) port for queue traffic""")
153 help="""Monitor (SUB) port for queue traffic""")
155
154
156 def _mon_port_default(self):
155 def _mon_port_default(self):
157 return util.select_random_ports(1)[0]
156 return util.select_random_ports(1)[0]
158
157
159 notifier_port = Int(config=True,
158 notifier_port = Int(config=True,
160 help="""PUB port for sending engine status notifications""")
159 help="""PUB port for sending engine status notifications""")
161
160
162 def _notifier_port_default(self):
161 def _notifier_port_default(self):
163 return util.select_random_ports(1)[0]
162 return util.select_random_ports(1)[0]
164
163
165 engine_ip = Unicode('127.0.0.1', config=True,
164 engine_ip = Unicode('127.0.0.1', config=True,
166 help="IP on which to listen for engine connections. [default: loopback]")
165 help="IP on which to listen for engine connections. [default: loopback]")
167 engine_transport = Unicode('tcp', config=True,
166 engine_transport = Unicode('tcp', config=True,
168 help="0MQ transport for engine connections. [default: tcp]")
167 help="0MQ transport for engine connections. [default: tcp]")
169
168
170 client_ip = Unicode('127.0.0.1', config=True,
169 client_ip = Unicode('127.0.0.1', config=True,
171 help="IP on which to listen for client connections. [default: loopback]")
170 help="IP on which to listen for client connections. [default: loopback]")
172 client_transport = Unicode('tcp', config=True,
171 client_transport = Unicode('tcp', config=True,
173 help="0MQ transport for client connections. [default : tcp]")
172 help="0MQ transport for client connections. [default : tcp]")
174
173
175 monitor_ip = Unicode('127.0.0.1', config=True,
174 monitor_ip = Unicode('127.0.0.1', config=True,
176 help="IP on which to listen for monitor messages. [default: loopback]")
175 help="IP on which to listen for monitor messages. [default: loopback]")
177 monitor_transport = Unicode('tcp', config=True,
176 monitor_transport = Unicode('tcp', config=True,
178 help="0MQ transport for monitor messages. [default : tcp]")
177 help="0MQ transport for monitor messages. [default : tcp]")
179
178
180 monitor_url = Unicode('')
179 monitor_url = Unicode('')
181
180
182 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
181 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
183 config=True, help="""The class to use for the DB backend""")
182 config=True, help="""The class to use for the DB backend""")
184
183
185 # not configurable
184 # not configurable
186 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
185 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
187 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
186 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
188
187
189 def _ip_changed(self, name, old, new):
188 def _ip_changed(self, name, old, new):
190 self.engine_ip = new
189 self.engine_ip = new
191 self.client_ip = new
190 self.client_ip = new
192 self.monitor_ip = new
191 self.monitor_ip = new
193 self._update_monitor_url()
192 self._update_monitor_url()
194
193
195 def _update_monitor_url(self):
194 def _update_monitor_url(self):
196 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
195 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
197
196
198 def _transport_changed(self, name, old, new):
197 def _transport_changed(self, name, old, new):
199 self.engine_transport = new
198 self.engine_transport = new
200 self.client_transport = new
199 self.client_transport = new
201 self.monitor_transport = new
200 self.monitor_transport = new
202 self._update_monitor_url()
201 self._update_monitor_url()
203
202
204 def __init__(self, **kwargs):
203 def __init__(self, **kwargs):
205 super(HubFactory, self).__init__(**kwargs)
204 super(HubFactory, self).__init__(**kwargs)
206 self._update_monitor_url()
205 self._update_monitor_url()
207
206
208
207
209 def construct(self):
208 def construct(self):
210 self.init_hub()
209 self.init_hub()
211
210
212 def start(self):
211 def start(self):
213 self.heartmonitor.start()
212 self.heartmonitor.start()
214 self.log.info("Heartmonitor started")
213 self.log.info("Heartmonitor started")
215
214
216 def init_hub(self):
215 def init_hub(self):
217 """construct"""
216 """construct"""
218 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
217 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
219 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
218 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
220
219
221 ctx = self.context
220 ctx = self.context
222 loop = self.loop
221 loop = self.loop
223
222
224 # Registrar socket
223 # Registrar socket
225 q = ZMQStream(ctx.socket(zmq.XREP), loop)
224 q = ZMQStream(ctx.socket(zmq.XREP), loop)
226 q.bind(client_iface % self.regport)
225 q.bind(client_iface % self.regport)
227 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
226 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
228 if self.client_ip != self.engine_ip:
227 if self.client_ip != self.engine_ip:
229 q.bind(engine_iface % self.regport)
228 q.bind(engine_iface % self.regport)
230 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
229 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
231
230
232 ### Engine connections ###
231 ### Engine connections ###
233
232
234 # heartbeat
233 # heartbeat
235 hpub = ctx.socket(zmq.PUB)
234 hpub = ctx.socket(zmq.PUB)
236 hpub.bind(engine_iface % self.hb[0])
235 hpub.bind(engine_iface % self.hb[0])
237 hrep = ctx.socket(zmq.XREP)
236 hrep = ctx.socket(zmq.XREP)
238 hrep.bind(engine_iface % self.hb[1])
237 hrep.bind(engine_iface % self.hb[1])
239 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
238 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
240 pingstream=ZMQStream(hpub,loop),
239 pingstream=ZMQStream(hpub,loop),
241 pongstream=ZMQStream(hrep,loop)
240 pongstream=ZMQStream(hrep,loop)
242 )
241 )
243
242
244 ### Client connections ###
243 ### Client connections ###
245 # Notifier socket
244 # Notifier socket
246 n = ZMQStream(ctx.socket(zmq.PUB), loop)
245 n = ZMQStream(ctx.socket(zmq.PUB), loop)
247 n.bind(client_iface%self.notifier_port)
246 n.bind(client_iface%self.notifier_port)
248
247
249 ### build and launch the queues ###
248 ### build and launch the queues ###
250
249
251 # monitor socket
250 # monitor socket
252 sub = ctx.socket(zmq.SUB)
251 sub = ctx.socket(zmq.SUB)
253 sub.setsockopt(zmq.SUBSCRIBE, b"")
252 sub.setsockopt(zmq.SUBSCRIBE, b"")
254 sub.bind(self.monitor_url)
253 sub.bind(self.monitor_url)
255 sub.bind('inproc://monitor')
254 sub.bind('inproc://monitor')
256 sub = ZMQStream(sub, loop)
255 sub = ZMQStream(sub, loop)
257
256
258 # connect the db
257 # connect the db
259 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
258 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
260 # cdir = self.config.Global.cluster_dir
259 # cdir = self.config.Global.cluster_dir
261 self.db = import_item(str(self.db_class))(session=self.session.session,
260 self.db = import_item(str(self.db_class))(session=self.session.session,
262 config=self.config, log=self.log)
261 config=self.config, log=self.log)
263 time.sleep(.25)
262 time.sleep(.25)
264 try:
263 try:
265 scheme = self.config.TaskScheduler.scheme_name
264 scheme = self.config.TaskScheduler.scheme_name
266 except AttributeError:
265 except AttributeError:
267 from .scheduler import TaskScheduler
266 from .scheduler import TaskScheduler
268 scheme = TaskScheduler.scheme_name.get_default_value()
267 scheme = TaskScheduler.scheme_name.get_default_value()
269 # build connection dicts
268 # build connection dicts
270 self.engine_info = {
269 self.engine_info = {
271 'control' : engine_iface%self.control[1],
270 'control' : engine_iface%self.control[1],
272 'mux': engine_iface%self.mux[1],
271 'mux': engine_iface%self.mux[1],
273 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
272 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
274 'task' : engine_iface%self.task[1],
273 'task' : engine_iface%self.task[1],
275 'iopub' : engine_iface%self.iopub[1],
274 'iopub' : engine_iface%self.iopub[1],
276 # 'monitor' : engine_iface%self.mon_port,
275 # 'monitor' : engine_iface%self.mon_port,
277 }
276 }
278
277
279 self.client_info = {
278 self.client_info = {
280 'control' : client_iface%self.control[0],
279 'control' : client_iface%self.control[0],
281 'mux': client_iface%self.mux[0],
280 'mux': client_iface%self.mux[0],
282 'task' : (scheme, client_iface%self.task[0]),
281 'task' : (scheme, client_iface%self.task[0]),
283 'iopub' : client_iface%self.iopub[0],
282 'iopub' : client_iface%self.iopub[0],
284 'notification': client_iface%self.notifier_port
283 'notification': client_iface%self.notifier_port
285 }
284 }
286 self.log.debug("Hub engine addrs: %s"%self.engine_info)
285 self.log.debug("Hub engine addrs: %s"%self.engine_info)
287 self.log.debug("Hub client addrs: %s"%self.client_info)
286 self.log.debug("Hub client addrs: %s"%self.client_info)
288
287
289 # resubmit stream
288 # resubmit stream
290 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
289 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
291 url = util.disambiguate_url(self.client_info['task'][-1])
290 url = util.disambiguate_url(self.client_info['task'][-1])
292 r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
291 r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
293 r.connect(url)
292 r.connect(url)
294
293
295 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
294 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
296 query=q, notifier=n, resubmit=r, db=self.db,
295 query=q, notifier=n, resubmit=r, db=self.db,
297 engine_info=self.engine_info, client_info=self.client_info,
296 engine_info=self.engine_info, client_info=self.client_info,
298 log=self.log)
297 log=self.log)
299
298
300
299
301 class Hub(SessionFactory):
300 class Hub(SessionFactory):
302 """The IPython Controller Hub with 0MQ connections
301 """The IPython Controller Hub with 0MQ connections
303
302
304 Parameters
303 Parameters
305 ==========
304 ==========
306 loop: zmq IOLoop instance
305 loop: zmq IOLoop instance
307 session: Session object
306 session: Session object
308 <removed> context: zmq context for creating new connections (?)
307 <removed> context: zmq context for creating new connections (?)
309 queue: ZMQStream for monitoring the command queue (SUB)
308 queue: ZMQStream for monitoring the command queue (SUB)
310 query: ZMQStream for engine registration and client queries requests (XREP)
309 query: ZMQStream for engine registration and client queries requests (XREP)
311 heartbeat: HeartMonitor object checking the pulse of the engines
310 heartbeat: HeartMonitor object checking the pulse of the engines
312 notifier: ZMQStream for broadcasting engine registration changes (PUB)
311 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 db: connection to db for out of memory logging of commands
312 db: connection to db for out of memory logging of commands
314 NotImplemented
313 NotImplemented
315 engine_info: dict of zmq connection information for engines to connect
314 engine_info: dict of zmq connection information for engines to connect
316 to the queues.
315 to the queues.
317 client_info: dict of zmq connection information for engines to connect
316 client_info: dict of zmq connection information for engines to connect
318 to the queues.
317 to the queues.
319 """
318 """
320 # internal data structures:
319 # internal data structures:
321 ids=Set() # engine IDs
320 ids=Set() # engine IDs
322 keytable=Dict()
321 keytable=Dict()
323 by_ident=Dict()
322 by_ident=Dict()
324 engines=Dict()
323 engines=Dict()
325 clients=Dict()
324 clients=Dict()
326 hearts=Dict()
325 hearts=Dict()
327 pending=Set()
326 pending=Set()
328 queues=Dict() # pending msg_ids keyed by engine_id
327 queues=Dict() # pending msg_ids keyed by engine_id
329 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
328 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
330 completed=Dict() # completed msg_ids keyed by engine_id
329 completed=Dict() # completed msg_ids keyed by engine_id
331 all_completed=Set() # completed msg_ids keyed by engine_id
330 all_completed=Set() # completed msg_ids keyed by engine_id
332 dead_engines=Set() # completed msg_ids keyed by engine_id
331 dead_engines=Set() # completed msg_ids keyed by engine_id
333 unassigned=Set() # set of task msg_ds not yet assigned a destination
332 unassigned=Set() # set of task msg_ds not yet assigned a destination
334 incoming_registrations=Dict()
333 incoming_registrations=Dict()
335 registration_timeout=Int()
334 registration_timeout=Int()
336 _idcounter=Int(0)
335 _idcounter=Int(0)
337
336
338 # objects from constructor:
337 # objects from constructor:
339 query=Instance(ZMQStream)
338 query=Instance(ZMQStream)
340 monitor=Instance(ZMQStream)
339 monitor=Instance(ZMQStream)
341 notifier=Instance(ZMQStream)
340 notifier=Instance(ZMQStream)
342 resubmit=Instance(ZMQStream)
341 resubmit=Instance(ZMQStream)
343 heartmonitor=Instance(HeartMonitor)
342 heartmonitor=Instance(HeartMonitor)
344 db=Instance(object)
343 db=Instance(object)
345 client_info=Dict()
344 client_info=Dict()
346 engine_info=Dict()
345 engine_info=Dict()
347
346
348
347
349 def __init__(self, **kwargs):
348 def __init__(self, **kwargs):
350 """
349 """
351 # universal:
350 # universal:
352 loop: IOLoop for creating future connections
351 loop: IOLoop for creating future connections
353 session: streamsession for sending serialized data
352 session: streamsession for sending serialized data
354 # engine:
353 # engine:
355 queue: ZMQStream for monitoring queue messages
354 queue: ZMQStream for monitoring queue messages
356 query: ZMQStream for engine+client registration and client requests
355 query: ZMQStream for engine+client registration and client requests
357 heartbeat: HeartMonitor object for tracking engines
356 heartbeat: HeartMonitor object for tracking engines
358 # extra:
357 # extra:
359 db: ZMQStream for db connection (NotImplemented)
358 db: ZMQStream for db connection (NotImplemented)
360 engine_info: zmq address/protocol dict for engine connections
359 engine_info: zmq address/protocol dict for engine connections
361 client_info: zmq address/protocol dict for client connections
360 client_info: zmq address/protocol dict for client connections
362 """
361 """
363
362
364 super(Hub, self).__init__(**kwargs)
363 super(Hub, self).__init__(**kwargs)
365 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
364 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
366
365
367 # validate connection dicts:
366 # validate connection dicts:
368 for k,v in self.client_info.iteritems():
367 for k,v in self.client_info.iteritems():
369 if k == 'task':
368 if k == 'task':
370 util.validate_url_container(v[1])
369 util.validate_url_container(v[1])
371 else:
370 else:
372 util.validate_url_container(v)
371 util.validate_url_container(v)
373 # util.validate_url_container(self.client_info)
372 # util.validate_url_container(self.client_info)
374 util.validate_url_container(self.engine_info)
373 util.validate_url_container(self.engine_info)
375
374
376 # register our callbacks
375 # register our callbacks
377 self.query.on_recv(self.dispatch_query)
376 self.query.on_recv(self.dispatch_query)
378 self.monitor.on_recv(self.dispatch_monitor_traffic)
377 self.monitor.on_recv(self.dispatch_monitor_traffic)
379
378
380 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
379 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
381 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
380 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
382
381
383 self.monitor_handlers = {b'in' : self.save_queue_request,
382 self.monitor_handlers = {b'in' : self.save_queue_request,
384 b'out': self.save_queue_result,
383 b'out': self.save_queue_result,
385 b'intask': self.save_task_request,
384 b'intask': self.save_task_request,
386 b'outtask': self.save_task_result,
385 b'outtask': self.save_task_result,
387 b'tracktask': self.save_task_destination,
386 b'tracktask': self.save_task_destination,
388 b'incontrol': _passer,
387 b'incontrol': _passer,
389 b'outcontrol': _passer,
388 b'outcontrol': _passer,
390 b'iopub': self.save_iopub_message,
389 b'iopub': self.save_iopub_message,
391 }
390 }
392
391
393 self.query_handlers = {'queue_request': self.queue_status,
392 self.query_handlers = {'queue_request': self.queue_status,
394 'result_request': self.get_results,
393 'result_request': self.get_results,
395 'history_request': self.get_history,
394 'history_request': self.get_history,
396 'db_request': self.db_query,
395 'db_request': self.db_query,
397 'purge_request': self.purge_results,
396 'purge_request': self.purge_results,
398 'load_request': self.check_load,
397 'load_request': self.check_load,
399 'resubmit_request': self.resubmit_task,
398 'resubmit_request': self.resubmit_task,
400 'shutdown_request': self.shutdown_request,
399 'shutdown_request': self.shutdown_request,
401 'registration_request' : self.register_engine,
400 'registration_request' : self.register_engine,
402 'unregistration_request' : self.unregister_engine,
401 'unregistration_request' : self.unregister_engine,
403 'connection_request': self.connection_request,
402 'connection_request': self.connection_request,
404 }
403 }
405
404
406 # ignore resubmit replies
405 # ignore resubmit replies
407 self.resubmit.on_recv(lambda msg: None, copy=False)
406 self.resubmit.on_recv(lambda msg: None, copy=False)
408
407
409 self.log.info("hub::created hub")
408 self.log.info("hub::created hub")
410
409
411 @property
410 @property
412 def _next_id(self):
411 def _next_id(self):
413 """gemerate a new ID.
412 """gemerate a new ID.
414
413
415 No longer reuse old ids, just count from 0."""
414 No longer reuse old ids, just count from 0."""
416 newid = self._idcounter
415 newid = self._idcounter
417 self._idcounter += 1
416 self._idcounter += 1
418 return newid
417 return newid
419 # newid = 0
418 # newid = 0
420 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
419 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 # # print newid, self.ids, self.incoming_registrations
420 # # print newid, self.ids, self.incoming_registrations
422 # while newid in self.ids or newid in incoming:
421 # while newid in self.ids or newid in incoming:
423 # newid += 1
422 # newid += 1
424 # return newid
423 # return newid
425
424
426 #-----------------------------------------------------------------------------
425 #-----------------------------------------------------------------------------
427 # message validation
426 # message validation
428 #-----------------------------------------------------------------------------
427 #-----------------------------------------------------------------------------
429
428
430 def _validate_targets(self, targets):
429 def _validate_targets(self, targets):
431 """turn any valid targets argument into a list of integer ids"""
430 """turn any valid targets argument into a list of integer ids"""
432 if targets is None:
431 if targets is None:
433 # default to all
432 # default to all
434 targets = self.ids
433 targets = self.ids
435
434
436 if isinstance(targets, (int,str,unicode)):
435 if isinstance(targets, (int,str,unicode)):
437 # only one target specified
436 # only one target specified
438 targets = [targets]
437 targets = [targets]
439 _targets = []
438 _targets = []
440 for t in targets:
439 for t in targets:
441 # map raw identities to ids
440 # map raw identities to ids
442 if isinstance(t, (str,unicode)):
441 if isinstance(t, (str,unicode)):
443 t = self.by_ident.get(t, t)
442 t = self.by_ident.get(t, t)
444 _targets.append(t)
443 _targets.append(t)
445 targets = _targets
444 targets = _targets
446 bad_targets = [ t for t in targets if t not in self.ids ]
445 bad_targets = [ t for t in targets if t not in self.ids ]
447 if bad_targets:
446 if bad_targets:
448 raise IndexError("No Such Engine: %r"%bad_targets)
447 raise IndexError("No Such Engine: %r"%bad_targets)
449 if not targets:
448 if not targets:
450 raise IndexError("No Engines Registered")
449 raise IndexError("No Engines Registered")
451 return targets
450 return targets
452
451
453 #-----------------------------------------------------------------------------
452 #-----------------------------------------------------------------------------
454 # dispatch methods (1 per stream)
453 # dispatch methods (1 per stream)
455 #-----------------------------------------------------------------------------
454 #-----------------------------------------------------------------------------
456
455
457
456
458 def dispatch_monitor_traffic(self, msg):
457 def dispatch_monitor_traffic(self, msg):
459 """all ME and Task queue messages come through here, as well as
458 """all ME and Task queue messages come through here, as well as
460 IOPub traffic."""
459 IOPub traffic."""
461 self.log.debug("monitor traffic: %r"%msg[:2])
460 self.log.debug("monitor traffic: %r"%msg[:2])
462 switch = msg[0]
461 switch = msg[0]
463 try:
462 try:
464 idents, msg = self.session.feed_identities(msg[1:])
463 idents, msg = self.session.feed_identities(msg[1:])
465 except ValueError:
464 except ValueError:
466 idents=[]
465 idents=[]
467 if not idents:
466 if not idents:
468 self.log.error("Bad Monitor Message: %r"%msg)
467 self.log.error("Bad Monitor Message: %r"%msg)
469 return
468 return
470 handler = self.monitor_handlers.get(switch, None)
469 handler = self.monitor_handlers.get(switch, None)
471 if handler is not None:
470 if handler is not None:
472 handler(idents, msg)
471 handler(idents, msg)
473 else:
472 else:
474 self.log.error("Invalid monitor topic: %r"%switch)
473 self.log.error("Invalid monitor topic: %r"%switch)
475
474
476
475
477 def dispatch_query(self, msg):
476 def dispatch_query(self, msg):
478 """Route registration requests and queries from clients."""
477 """Route registration requests and queries from clients."""
479 try:
478 try:
480 idents, msg = self.session.feed_identities(msg)
479 idents, msg = self.session.feed_identities(msg)
481 except ValueError:
480 except ValueError:
482 idents = []
481 idents = []
483 if not idents:
482 if not idents:
484 self.log.error("Bad Query Message: %r"%msg)
483 self.log.error("Bad Query Message: %r"%msg)
485 return
484 return
486 client_id = idents[0]
485 client_id = idents[0]
487 try:
486 try:
488 msg = self.session.unpack_message(msg, content=True)
487 msg = self.session.unpack_message(msg, content=True)
489 except Exception:
488 except Exception:
490 content = error.wrap_exception()
489 content = error.wrap_exception()
491 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 self.session.send(self.query, "hub_error", ident=client_id,
491 self.session.send(self.query, "hub_error", ident=client_id,
493 content=content)
492 content=content)
494 return
493 return
495 # print client_id, header, parent, content
494 # print client_id, header, parent, content
496 #switch on message type:
495 #switch on message type:
497 msg_type = msg['msg_type']
496 msg_type = msg['msg_type']
498 self.log.info("client::client %r requested %r"%(client_id, msg_type))
497 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 handler = self.query_handlers.get(msg_type, None)
498 handler = self.query_handlers.get(msg_type, None)
500 try:
499 try:
501 assert handler is not None, "Bad Message Type: %r"%msg_type
500 assert handler is not None, "Bad Message Type: %r"%msg_type
502 except:
501 except:
503 content = error.wrap_exception()
502 content = error.wrap_exception()
504 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
503 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 self.session.send(self.query, "hub_error", ident=client_id,
504 self.session.send(self.query, "hub_error", ident=client_id,
506 content=content)
505 content=content)
507 return
506 return
508
507
509 else:
508 else:
510 handler(idents, msg)
509 handler(idents, msg)
511
510
512 def dispatch_db(self, msg):
511 def dispatch_db(self, msg):
513 """"""
512 """"""
514 raise NotImplementedError
513 raise NotImplementedError
515
514
516 #---------------------------------------------------------------------------
515 #---------------------------------------------------------------------------
517 # handler methods (1 per event)
516 # handler methods (1 per event)
518 #---------------------------------------------------------------------------
517 #---------------------------------------------------------------------------
519
518
520 #----------------------- Heartbeat --------------------------------------
519 #----------------------- Heartbeat --------------------------------------
521
520
522 def handle_new_heart(self, heart):
521 def handle_new_heart(self, heart):
523 """handler to attach to heartbeater.
522 """handler to attach to heartbeater.
524 Called when a new heart starts to beat.
523 Called when a new heart starts to beat.
525 Triggers completion of registration."""
524 Triggers completion of registration."""
526 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
525 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 if heart not in self.incoming_registrations:
526 if heart not in self.incoming_registrations:
528 self.log.info("heartbeat::ignoring new heart: %r"%heart)
527 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 else:
528 else:
530 self.finish_registration(heart)
529 self.finish_registration(heart)
531
530
532
531
533 def handle_heart_failure(self, heart):
532 def handle_heart_failure(self, heart):
534 """handler to attach to heartbeater.
533 """handler to attach to heartbeater.
535 called when a previously registered heart fails to respond to beat request.
534 called when a previously registered heart fails to respond to beat request.
536 triggers unregistration"""
535 triggers unregistration"""
537 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
536 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 eid = self.hearts.get(heart, None)
537 eid = self.hearts.get(heart, None)
539 queue = self.engines[eid].queue
538 queue = self.engines[eid].queue
540 if eid is None:
539 if eid is None:
541 self.log.info("heartbeat::ignoring heart failure %r"%heart)
540 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 else:
541 else:
543 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
542 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544
543
545 #----------------------- MUX Queue Traffic ------------------------------
544 #----------------------- MUX Queue Traffic ------------------------------
546
545
547 def save_queue_request(self, idents, msg):
546 def save_queue_request(self, idents, msg):
548 if len(idents) < 2:
547 if len(idents) < 2:
549 self.log.error("invalid identity prefix: %r"%idents)
548 self.log.error("invalid identity prefix: %r"%idents)
550 return
549 return
551 queue_id, client_id = idents[:2]
550 queue_id, client_id = idents[:2]
552 try:
551 try:
553 msg = self.session.unpack_message(msg)
552 msg = self.session.unpack_message(msg)
554 except Exception:
553 except Exception:
555 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
554 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 return
555 return
557
556
558 eid = self.by_ident.get(queue_id, None)
557 eid = self.by_ident.get(queue_id, None)
559 if eid is None:
558 if eid is None:
560 self.log.error("queue::target %r not registered"%queue_id)
559 self.log.error("queue::target %r not registered"%queue_id)
561 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
560 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
562 return
561 return
563 record = init_record(msg)
562 record = init_record(msg)
564 msg_id = record['msg_id']
563 msg_id = record['msg_id']
565 # Unicode in records
564 # Unicode in records
566 record['engine_uuid'] = queue_id.decode('ascii')
565 record['engine_uuid'] = queue_id.decode('ascii')
567 record['client_uuid'] = client_id.decode('ascii')
566 record['client_uuid'] = client_id.decode('ascii')
568 record['queue'] = 'mux'
567 record['queue'] = 'mux'
569
568
570 try:
569 try:
571 # it's posible iopub arrived first:
570 # it's posible iopub arrived first:
572 existing = self.db.get_record(msg_id)
571 existing = self.db.get_record(msg_id)
573 for key,evalue in existing.iteritems():
572 for key,evalue in existing.iteritems():
574 rvalue = record.get(key, None)
573 rvalue = record.get(key, None)
575 if evalue and rvalue and evalue != rvalue:
574 if evalue and rvalue and evalue != rvalue:
576 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
575 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
577 elif evalue and not rvalue:
576 elif evalue and not rvalue:
578 record[key] = evalue
577 record[key] = evalue
579 try:
578 try:
580 self.db.update_record(msg_id, record)
579 self.db.update_record(msg_id, record)
581 except Exception:
580 except Exception:
582 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
581 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
583 except KeyError:
582 except KeyError:
584 try:
583 try:
585 self.db.add_record(msg_id, record)
584 self.db.add_record(msg_id, record)
586 except Exception:
585 except Exception:
587 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
586 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
588
587
589
588
590 self.pending.add(msg_id)
589 self.pending.add(msg_id)
591 self.queues[eid].append(msg_id)
590 self.queues[eid].append(msg_id)
592
591
593 def save_queue_result(self, idents, msg):
592 def save_queue_result(self, idents, msg):
594 if len(idents) < 2:
593 if len(idents) < 2:
595 self.log.error("invalid identity prefix: %r"%idents)
594 self.log.error("invalid identity prefix: %r"%idents)
596 return
595 return
597
596
598 client_id, queue_id = idents[:2]
597 client_id, queue_id = idents[:2]
599 try:
598 try:
600 msg = self.session.unpack_message(msg)
599 msg = self.session.unpack_message(msg)
601 except Exception:
600 except Exception:
602 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
601 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
603 queue_id,client_id, msg), exc_info=True)
602 queue_id,client_id, msg), exc_info=True)
604 return
603 return
605
604
606 eid = self.by_ident.get(queue_id, None)
605 eid = self.by_ident.get(queue_id, None)
607 if eid is None:
606 if eid is None:
608 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
607 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
609 return
608 return
610
609
611 parent = msg['parent_header']
610 parent = msg['parent_header']
612 if not parent:
611 if not parent:
613 return
612 return
614 msg_id = parent['msg_id']
613 msg_id = parent['msg_id']
615 if msg_id in self.pending:
614 if msg_id in self.pending:
616 self.pending.remove(msg_id)
615 self.pending.remove(msg_id)
617 self.all_completed.add(msg_id)
616 self.all_completed.add(msg_id)
618 self.queues[eid].remove(msg_id)
617 self.queues[eid].remove(msg_id)
619 self.completed[eid].append(msg_id)
618 self.completed[eid].append(msg_id)
620 elif msg_id not in self.all_completed:
619 elif msg_id not in self.all_completed:
621 # it could be a result from a dead engine that died before delivering the
620 # it could be a result from a dead engine that died before delivering the
622 # result
621 # result
623 self.log.warn("queue:: unknown msg finished %r"%msg_id)
622 self.log.warn("queue:: unknown msg finished %r"%msg_id)
624 return
623 return
625 # update record anyway, because the unregistration could have been premature
624 # update record anyway, because the unregistration could have been premature
626 rheader = msg['header']
625 rheader = msg['header']
627 completed = rheader['date']
626 completed = rheader['date']
628 started = rheader.get('started', None)
627 started = rheader.get('started', None)
629 result = {
628 result = {
630 'result_header' : rheader,
629 'result_header' : rheader,
631 'result_content': msg['content'],
630 'result_content': msg['content'],
632 'started' : started,
631 'started' : started,
633 'completed' : completed
632 'completed' : completed
634 }
633 }
635
634
636 result['result_buffers'] = msg['buffers']
635 result['result_buffers'] = msg['buffers']
637 try:
636 try:
638 self.db.update_record(msg_id, result)
637 self.db.update_record(msg_id, result)
639 except Exception:
638 except Exception:
640 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
639 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
641
640
642
641
643 #--------------------- Task Queue Traffic ------------------------------
642 #--------------------- Task Queue Traffic ------------------------------
644
643
645 def save_task_request(self, idents, msg):
644 def save_task_request(self, idents, msg):
646 """Save the submission of a task."""
645 """Save the submission of a task."""
647 client_id = idents[0]
646 client_id = idents[0]
648
647
649 try:
648 try:
650 msg = self.session.unpack_message(msg)
649 msg = self.session.unpack_message(msg)
651 except Exception:
650 except Exception:
652 self.log.error("task::client %r sent invalid task message: %r"%(
651 self.log.error("task::client %r sent invalid task message: %r"%(
653 client_id, msg), exc_info=True)
652 client_id, msg), exc_info=True)
654 return
653 return
655 record = init_record(msg)
654 record = init_record(msg)
656
655
657 record['client_uuid'] = client_id
656 record['client_uuid'] = client_id
658 record['queue'] = 'task'
657 record['queue'] = 'task'
659 header = msg['header']
658 header = msg['header']
660 msg_id = header['msg_id']
659 msg_id = header['msg_id']
661 self.pending.add(msg_id)
660 self.pending.add(msg_id)
662 self.unassigned.add(msg_id)
661 self.unassigned.add(msg_id)
663 try:
662 try:
664 # it's posible iopub arrived first:
663 # it's posible iopub arrived first:
665 existing = self.db.get_record(msg_id)
664 existing = self.db.get_record(msg_id)
666 if existing['resubmitted']:
665 if existing['resubmitted']:
667 for key in ('submitted', 'client_uuid', 'buffers'):
666 for key in ('submitted', 'client_uuid', 'buffers'):
668 # don't clobber these keys on resubmit
667 # don't clobber these keys on resubmit
669 # submitted and client_uuid should be different
668 # submitted and client_uuid should be different
670 # and buffers might be big, and shouldn't have changed
669 # and buffers might be big, and shouldn't have changed
671 record.pop(key)
670 record.pop(key)
672 # still check content,header which should not change
671 # still check content,header which should not change
673 # but are not expensive to compare as buffers
672 # but are not expensive to compare as buffers
674
673
675 for key,evalue in existing.iteritems():
674 for key,evalue in existing.iteritems():
676 if key.endswith('buffers'):
675 if key.endswith('buffers'):
677 # don't compare buffers
676 # don't compare buffers
678 continue
677 continue
679 rvalue = record.get(key, None)
678 rvalue = record.get(key, None)
680 if evalue and rvalue and evalue != rvalue:
679 if evalue and rvalue and evalue != rvalue:
681 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
680 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
682 elif evalue and not rvalue:
681 elif evalue and not rvalue:
683 record[key] = evalue
682 record[key] = evalue
684 try:
683 try:
685 self.db.update_record(msg_id, record)
684 self.db.update_record(msg_id, record)
686 except Exception:
685 except Exception:
687 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
686 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
688 except KeyError:
687 except KeyError:
689 try:
688 try:
690 self.db.add_record(msg_id, record)
689 self.db.add_record(msg_id, record)
691 except Exception:
690 except Exception:
692 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
691 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
693 except Exception:
692 except Exception:
694 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
693 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
695
694
696 def save_task_result(self, idents, msg):
695 def save_task_result(self, idents, msg):
697 """save the result of a completed task."""
696 """save the result of a completed task."""
698 client_id = idents[0]
697 client_id = idents[0]
699 try:
698 try:
700 msg = self.session.unpack_message(msg)
699 msg = self.session.unpack_message(msg)
701 except Exception:
700 except Exception:
702 self.log.error("task::invalid task result message send to %r: %r"%(
701 self.log.error("task::invalid task result message send to %r: %r"%(
703 client_id, msg), exc_info=True)
702 client_id, msg), exc_info=True)
704 return
703 return
705
704
706 parent = msg['parent_header']
705 parent = msg['parent_header']
707 if not parent:
706 if not parent:
708 # print msg
707 # print msg
709 self.log.warn("Task %r had no parent!"%msg)
708 self.log.warn("Task %r had no parent!"%msg)
710 return
709 return
711 msg_id = parent['msg_id']
710 msg_id = parent['msg_id']
712 if msg_id in self.unassigned:
711 if msg_id in self.unassigned:
713 self.unassigned.remove(msg_id)
712 self.unassigned.remove(msg_id)
714
713
715 header = msg['header']
714 header = msg['header']
716 engine_uuid = header.get('engine', None)
715 engine_uuid = header.get('engine', None)
717 eid = self.by_ident.get(engine_uuid, None)
716 eid = self.by_ident.get(engine_uuid, None)
718
717
719 if msg_id in self.pending:
718 if msg_id in self.pending:
720 self.pending.remove(msg_id)
719 self.pending.remove(msg_id)
721 self.all_completed.add(msg_id)
720 self.all_completed.add(msg_id)
722 if eid is not None:
721 if eid is not None:
723 self.completed[eid].append(msg_id)
722 self.completed[eid].append(msg_id)
724 if msg_id in self.tasks[eid]:
723 if msg_id in self.tasks[eid]:
725 self.tasks[eid].remove(msg_id)
724 self.tasks[eid].remove(msg_id)
726 completed = header['date']
725 completed = header['date']
727 started = header.get('started', None)
726 started = header.get('started', None)
728 result = {
727 result = {
729 'result_header' : header,
728 'result_header' : header,
730 'result_content': msg['content'],
729 'result_content': msg['content'],
731 'started' : started,
730 'started' : started,
732 'completed' : completed,
731 'completed' : completed,
733 'engine_uuid': engine_uuid
732 'engine_uuid': engine_uuid
734 }
733 }
735
734
736 result['result_buffers'] = msg['buffers']
735 result['result_buffers'] = msg['buffers']
737 try:
736 try:
738 self.db.update_record(msg_id, result)
737 self.db.update_record(msg_id, result)
739 except Exception:
738 except Exception:
740 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
739 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
741
740
742 else:
741 else:
743 self.log.debug("task::unknown task %r finished"%msg_id)
742 self.log.debug("task::unknown task %r finished"%msg_id)
744
743
745 def save_task_destination(self, idents, msg):
744 def save_task_destination(self, idents, msg):
746 try:
745 try:
747 msg = self.session.unpack_message(msg, content=True)
746 msg = self.session.unpack_message(msg, content=True)
748 except Exception:
747 except Exception:
749 self.log.error("task::invalid task tracking message", exc_info=True)
748 self.log.error("task::invalid task tracking message", exc_info=True)
750 return
749 return
751 content = msg['content']
750 content = msg['content']
752 # print (content)
751 # print (content)
753 msg_id = content['msg_id']
752 msg_id = content['msg_id']
754 engine_uuid = content['engine_id']
753 engine_uuid = content['engine_id']
755 eid = self.by_ident[util.asbytes(engine_uuid)]
754 eid = self.by_ident[util.asbytes(engine_uuid)]
756
755
757 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
758 if msg_id in self.unassigned:
757 if msg_id in self.unassigned:
759 self.unassigned.remove(msg_id)
758 self.unassigned.remove(msg_id)
760 # else:
759 # else:
761 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
762
761
763 self.tasks[eid].append(msg_id)
762 self.tasks[eid].append(msg_id)
764 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
765 try:
764 try:
766 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
767 except Exception:
766 except Exception:
768 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
769
768
770
769
771 def mia_task_request(self, idents, msg):
770 def mia_task_request(self, idents, msg):
772 raise NotImplementedError
771 raise NotImplementedError
773 client_id = idents[0]
772 client_id = idents[0]
774 # content = dict(mia=self.mia,status='ok')
773 # content = dict(mia=self.mia,status='ok')
775 # self.session.send('mia_reply', content=content, idents=client_id)
774 # self.session.send('mia_reply', content=content, idents=client_id)
776
775
777
776
778 #--------------------- IOPub Traffic ------------------------------
777 #--------------------- IOPub Traffic ------------------------------
779
778
780 def save_iopub_message(self, topics, msg):
779 def save_iopub_message(self, topics, msg):
781 """save an iopub message into the db"""
780 """save an iopub message into the db"""
782 # print (topics)
781 # print (topics)
783 try:
782 try:
784 msg = self.session.unpack_message(msg, content=True)
783 msg = self.session.unpack_message(msg, content=True)
785 except Exception:
784 except Exception:
786 self.log.error("iopub::invalid IOPub message", exc_info=True)
785 self.log.error("iopub::invalid IOPub message", exc_info=True)
787 return
786 return
788
787
789 parent = msg['parent_header']
788 parent = msg['parent_header']
790 if not parent:
789 if not parent:
791 self.log.error("iopub::invalid IOPub message: %r"%msg)
790 self.log.error("iopub::invalid IOPub message: %r"%msg)
792 return
791 return
793 msg_id = parent['msg_id']
792 msg_id = parent['msg_id']
794 msg_type = msg['msg_type']
793 msg_type = msg['msg_type']
795 content = msg['content']
794 content = msg['content']
796
795
797 # ensure msg_id is in db
796 # ensure msg_id is in db
798 try:
797 try:
799 rec = self.db.get_record(msg_id)
798 rec = self.db.get_record(msg_id)
800 except KeyError:
799 except KeyError:
801 rec = empty_record()
800 rec = empty_record()
802 rec['msg_id'] = msg_id
801 rec['msg_id'] = msg_id
803 self.db.add_record(msg_id, rec)
802 self.db.add_record(msg_id, rec)
804 # stream
803 # stream
805 d = {}
804 d = {}
806 if msg_type == 'stream':
805 if msg_type == 'stream':
807 name = content['name']
806 name = content['name']
808 s = rec[name] or ''
807 s = rec[name] or ''
809 d[name] = s + content['data']
808 d[name] = s + content['data']
810
809
811 elif msg_type == 'pyerr':
810 elif msg_type == 'pyerr':
812 d['pyerr'] = content
811 d['pyerr'] = content
813 elif msg_type == 'pyin':
812 elif msg_type == 'pyin':
814 d['pyin'] = content['code']
813 d['pyin'] = content['code']
815 else:
814 else:
816 d[msg_type] = content.get('data', '')
815 d[msg_type] = content.get('data', '')
817
816
818 try:
817 try:
819 self.db.update_record(msg_id, d)
818 self.db.update_record(msg_id, d)
820 except Exception:
819 except Exception:
821 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
820 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
822
821
823
822
824
823
825 #-------------------------------------------------------------------------
824 #-------------------------------------------------------------------------
826 # Registration requests
825 # Registration requests
827 #-------------------------------------------------------------------------
826 #-------------------------------------------------------------------------
828
827
829 def connection_request(self, client_id, msg):
828 def connection_request(self, client_id, msg):
830 """Reply with connection addresses for clients."""
829 """Reply with connection addresses for clients."""
831 self.log.info("client::client %r connected"%client_id)
830 self.log.info("client::client %r connected"%client_id)
832 content = dict(status='ok')
831 content = dict(status='ok')
833 content.update(self.client_info)
832 content.update(self.client_info)
834 jsonable = {}
833 jsonable = {}
835 for k,v in self.keytable.iteritems():
834 for k,v in self.keytable.iteritems():
836 if v not in self.dead_engines:
835 if v not in self.dead_engines:
837 jsonable[str(k)] = v.decode('ascii')
836 jsonable[str(k)] = v.decode('ascii')
838 content['engines'] = jsonable
837 content['engines'] = jsonable
839 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
840
839
841 def register_engine(self, reg, msg):
840 def register_engine(self, reg, msg):
842 """Register a new engine."""
841 """Register a new engine."""
843 content = msg['content']
842 content = msg['content']
844 try:
843 try:
845 queue = util.asbytes(content['queue'])
844 queue = util.asbytes(content['queue'])
846 except KeyError:
845 except KeyError:
847 self.log.error("registration::queue not specified", exc_info=True)
846 self.log.error("registration::queue not specified", exc_info=True)
848 return
847 return
849 heart = content.get('heartbeat', None)
848 heart = content.get('heartbeat', None)
850 if heart:
849 if heart:
851 heart = util.asbytes(heart)
850 heart = util.asbytes(heart)
852 """register a new engine, and create the socket(s) necessary"""
851 """register a new engine, and create the socket(s) necessary"""
853 eid = self._next_id
852 eid = self._next_id
854 # print (eid, queue, reg, heart)
853 # print (eid, queue, reg, heart)
855
854
856 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
855 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
857
856
858 content = dict(id=eid,status='ok')
857 content = dict(id=eid,status='ok')
859 content.update(self.engine_info)
858 content.update(self.engine_info)
860 # check if requesting available IDs:
859 # check if requesting available IDs:
861 if queue in self.by_ident:
860 if queue in self.by_ident:
862 try:
861 try:
863 raise KeyError("queue_id %r in use"%queue)
862 raise KeyError("queue_id %r in use"%queue)
864 except:
863 except:
865 content = error.wrap_exception()
864 content = error.wrap_exception()
866 self.log.error("queue_id %r in use"%queue, exc_info=True)
865 self.log.error("queue_id %r in use"%queue, exc_info=True)
867 elif heart in self.hearts: # need to check unique hearts?
866 elif heart in self.hearts: # need to check unique hearts?
868 try:
867 try:
869 raise KeyError("heart_id %r in use"%heart)
868 raise KeyError("heart_id %r in use"%heart)
870 except:
869 except:
871 self.log.error("heart_id %r in use"%heart, exc_info=True)
870 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 content = error.wrap_exception()
871 content = error.wrap_exception()
873 else:
872 else:
874 for h, pack in self.incoming_registrations.iteritems():
873 for h, pack in self.incoming_registrations.iteritems():
875 if heart == h:
874 if heart == h:
876 try:
875 try:
877 raise KeyError("heart_id %r in use"%heart)
876 raise KeyError("heart_id %r in use"%heart)
878 except:
877 except:
879 self.log.error("heart_id %r in use"%heart, exc_info=True)
878 self.log.error("heart_id %r in use"%heart, exc_info=True)
880 content = error.wrap_exception()
879 content = error.wrap_exception()
881 break
880 break
882 elif queue == pack[1]:
881 elif queue == pack[1]:
883 try:
882 try:
884 raise KeyError("queue_id %r in use"%queue)
883 raise KeyError("queue_id %r in use"%queue)
885 except:
884 except:
886 self.log.error("queue_id %r in use"%queue, exc_info=True)
885 self.log.error("queue_id %r in use"%queue, exc_info=True)
887 content = error.wrap_exception()
886 content = error.wrap_exception()
888 break
887 break
889
888
890 msg = self.session.send(self.query, "registration_reply",
889 msg = self.session.send(self.query, "registration_reply",
891 content=content,
890 content=content,
892 ident=reg)
891 ident=reg)
893
892
894 if content['status'] == 'ok':
893 if content['status'] == 'ok':
895 if heart in self.heartmonitor.hearts:
894 if heart in self.heartmonitor.hearts:
896 # already beating
895 # already beating
897 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
896 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
898 self.finish_registration(heart)
897 self.finish_registration(heart)
899 else:
898 else:
900 purge = lambda : self._purge_stalled_registration(heart)
899 purge = lambda : self._purge_stalled_registration(heart)
901 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
900 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
902 dc.start()
901 dc.start()
903 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
902 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
904 else:
903 else:
905 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
904 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
906 return eid
905 return eid
907
906
908 def unregister_engine(self, ident, msg):
907 def unregister_engine(self, ident, msg):
909 """Unregister an engine that explicitly requested to leave."""
908 """Unregister an engine that explicitly requested to leave."""
910 try:
909 try:
911 eid = msg['content']['id']
910 eid = msg['content']['id']
912 except:
911 except:
913 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
912 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
914 return
913 return
915 self.log.info("registration::unregister_engine(%r)"%eid)
914 self.log.info("registration::unregister_engine(%r)"%eid)
916 # print (eid)
915 # print (eid)
917 uuid = self.keytable[eid]
916 uuid = self.keytable[eid]
918 content=dict(id=eid, queue=uuid.decode('ascii'))
917 content=dict(id=eid, queue=uuid.decode('ascii'))
919 self.dead_engines.add(uuid)
918 self.dead_engines.add(uuid)
920 # self.ids.remove(eid)
919 # self.ids.remove(eid)
921 # uuid = self.keytable.pop(eid)
920 # uuid = self.keytable.pop(eid)
922 #
921 #
923 # ec = self.engines.pop(eid)
922 # ec = self.engines.pop(eid)
924 # self.hearts.pop(ec.heartbeat)
923 # self.hearts.pop(ec.heartbeat)
925 # self.by_ident.pop(ec.queue)
924 # self.by_ident.pop(ec.queue)
926 # self.completed.pop(eid)
925 # self.completed.pop(eid)
927 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
926 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
928 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
927 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
929 dc.start()
928 dc.start()
930 ############## TODO: HANDLE IT ################
929 ############## TODO: HANDLE IT ################
931
930
932 if self.notifier:
931 if self.notifier:
933 self.session.send(self.notifier, "unregistration_notification", content=content)
932 self.session.send(self.notifier, "unregistration_notification", content=content)
934
933
935 def _handle_stranded_msgs(self, eid, uuid):
934 def _handle_stranded_msgs(self, eid, uuid):
936 """Handle messages known to be on an engine when the engine unregisters.
935 """Handle messages known to be on an engine when the engine unregisters.
937
936
938 It is possible that this will fire prematurely - that is, an engine will
937 It is possible that this will fire prematurely - that is, an engine will
939 go down after completing a result, and the client will be notified
938 go down after completing a result, and the client will be notified
940 that the result failed and later receive the actual result.
939 that the result failed and later receive the actual result.
941 """
940 """
942
941
943 outstanding = self.queues[eid]
942 outstanding = self.queues[eid]
944
943
945 for msg_id in outstanding:
944 for msg_id in outstanding:
946 self.pending.remove(msg_id)
945 self.pending.remove(msg_id)
947 self.all_completed.add(msg_id)
946 self.all_completed.add(msg_id)
948 try:
947 try:
949 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
948 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
950 except:
949 except:
951 content = error.wrap_exception()
950 content = error.wrap_exception()
952 # build a fake header:
951 # build a fake header:
953 header = {}
952 header = {}
954 header['engine'] = uuid
953 header['engine'] = uuid
955 header['date'] = datetime.now()
954 header['date'] = datetime.now()
956 rec = dict(result_content=content, result_header=header, result_buffers=[])
955 rec = dict(result_content=content, result_header=header, result_buffers=[])
957 rec['completed'] = header['date']
956 rec['completed'] = header['date']
958 rec['engine_uuid'] = uuid
957 rec['engine_uuid'] = uuid
959 try:
958 try:
960 self.db.update_record(msg_id, rec)
959 self.db.update_record(msg_id, rec)
961 except Exception:
960 except Exception:
962 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
961 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
963
962
964
963
965 def finish_registration(self, heart):
964 def finish_registration(self, heart):
966 """Second half of engine registration, called after our HeartMonitor
965 """Second half of engine registration, called after our HeartMonitor
967 has received a beat from the Engine's Heart."""
966 has received a beat from the Engine's Heart."""
968 try:
967 try:
969 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
968 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
970 except KeyError:
969 except KeyError:
971 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
970 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
972 return
971 return
973 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
972 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
974 if purge is not None:
973 if purge is not None:
975 purge.stop()
974 purge.stop()
976 control = queue
975 control = queue
977 self.ids.add(eid)
976 self.ids.add(eid)
978 self.keytable[eid] = queue
977 self.keytable[eid] = queue
979 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
978 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
980 control=control, heartbeat=heart)
979 control=control, heartbeat=heart)
981 self.by_ident[queue] = eid
980 self.by_ident[queue] = eid
982 self.queues[eid] = list()
981 self.queues[eid] = list()
983 self.tasks[eid] = list()
982 self.tasks[eid] = list()
984 self.completed[eid] = list()
983 self.completed[eid] = list()
985 self.hearts[heart] = eid
984 self.hearts[heart] = eid
986 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
985 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
987 if self.notifier:
986 if self.notifier:
988 self.session.send(self.notifier, "registration_notification", content=content)
987 self.session.send(self.notifier, "registration_notification", content=content)
989 self.log.info("engine::Engine Connected: %i"%eid)
988 self.log.info("engine::Engine Connected: %i"%eid)
990
989
991 def _purge_stalled_registration(self, heart):
990 def _purge_stalled_registration(self, heart):
992 if heart in self.incoming_registrations:
991 if heart in self.incoming_registrations:
993 eid = self.incoming_registrations.pop(heart)[0]
992 eid = self.incoming_registrations.pop(heart)[0]
994 self.log.info("registration::purging stalled registration: %i"%eid)
993 self.log.info("registration::purging stalled registration: %i"%eid)
995 else:
994 else:
996 pass
995 pass
997
996
998 #-------------------------------------------------------------------------
997 #-------------------------------------------------------------------------
999 # Client Requests
998 # Client Requests
1000 #-------------------------------------------------------------------------
999 #-------------------------------------------------------------------------
1001
1000
1002 def shutdown_request(self, client_id, msg):
1001 def shutdown_request(self, client_id, msg):
1003 """handle shutdown request."""
1002 """handle shutdown request."""
1004 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1003 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1005 # also notify other clients of shutdown
1004 # also notify other clients of shutdown
1006 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1005 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1007 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1006 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1008 dc.start()
1007 dc.start()
1009
1008
1010 def _shutdown(self):
1009 def _shutdown(self):
1011 self.log.info("hub::hub shutting down.")
1010 self.log.info("hub::hub shutting down.")
1012 time.sleep(0.1)
1011 time.sleep(0.1)
1013 sys.exit(0)
1012 sys.exit(0)
1014
1013
1015
1014
1016 def check_load(self, client_id, msg):
1015 def check_load(self, client_id, msg):
1017 content = msg['content']
1016 content = msg['content']
1018 try:
1017 try:
1019 targets = content['targets']
1018 targets = content['targets']
1020 targets = self._validate_targets(targets)
1019 targets = self._validate_targets(targets)
1021 except:
1020 except:
1022 content = error.wrap_exception()
1021 content = error.wrap_exception()
1023 self.session.send(self.query, "hub_error",
1022 self.session.send(self.query, "hub_error",
1024 content=content, ident=client_id)
1023 content=content, ident=client_id)
1025 return
1024 return
1026
1025
1027 content = dict(status='ok')
1026 content = dict(status='ok')
1028 # loads = {}
1027 # loads = {}
1029 for t in targets:
1028 for t in targets:
1030 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1029 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1031 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1030 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1032
1031
1033
1032
1034 def queue_status(self, client_id, msg):
1033 def queue_status(self, client_id, msg):
1035 """Return the Queue status of one or more targets.
1034 """Return the Queue status of one or more targets.
1036 if verbose: return the msg_ids
1035 if verbose: return the msg_ids
1037 else: return len of each type.
1036 else: return len of each type.
1038 keys: queue (pending MUX jobs)
1037 keys: queue (pending MUX jobs)
1039 tasks (pending Task jobs)
1038 tasks (pending Task jobs)
1040 completed (finished jobs from both queues)"""
1039 completed (finished jobs from both queues)"""
1041 content = msg['content']
1040 content = msg['content']
1042 targets = content['targets']
1041 targets = content['targets']
1043 try:
1042 try:
1044 targets = self._validate_targets(targets)
1043 targets = self._validate_targets(targets)
1045 except:
1044 except:
1046 content = error.wrap_exception()
1045 content = error.wrap_exception()
1047 self.session.send(self.query, "hub_error",
1046 self.session.send(self.query, "hub_error",
1048 content=content, ident=client_id)
1047 content=content, ident=client_id)
1049 return
1048 return
1050 verbose = content.get('verbose', False)
1049 verbose = content.get('verbose', False)
1051 content = dict(status='ok')
1050 content = dict(status='ok')
1052 for t in targets:
1051 for t in targets:
1053 queue = self.queues[t]
1052 queue = self.queues[t]
1054 completed = self.completed[t]
1053 completed = self.completed[t]
1055 tasks = self.tasks[t]
1054 tasks = self.tasks[t]
1056 if not verbose:
1055 if not verbose:
1057 queue = len(queue)
1056 queue = len(queue)
1058 completed = len(completed)
1057 completed = len(completed)
1059 tasks = len(tasks)
1058 tasks = len(tasks)
1060 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1059 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1061 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1060 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1062 # print (content)
1061 # print (content)
1063 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1062 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1064
1063
1065 def purge_results(self, client_id, msg):
1064 def purge_results(self, client_id, msg):
1066 """Purge results from memory. This method is more valuable before we move
1065 """Purge results from memory. This method is more valuable before we move
1067 to a DB based message storage mechanism."""
1066 to a DB based message storage mechanism."""
1068 content = msg['content']
1067 content = msg['content']
1069 self.log.info("Dropping records with %s", content)
1068 self.log.info("Dropping records with %s", content)
1070 msg_ids = content.get('msg_ids', [])
1069 msg_ids = content.get('msg_ids', [])
1071 reply = dict(status='ok')
1070 reply = dict(status='ok')
1072 if msg_ids == 'all':
1071 if msg_ids == 'all':
1073 try:
1072 try:
1074 self.db.drop_matching_records(dict(completed={'$ne':None}))
1073 self.db.drop_matching_records(dict(completed={'$ne':None}))
1075 except Exception:
1074 except Exception:
1076 reply = error.wrap_exception()
1075 reply = error.wrap_exception()
1077 else:
1076 else:
1078 pending = filter(lambda m: m in self.pending, msg_ids)
1077 pending = filter(lambda m: m in self.pending, msg_ids)
1079 if pending:
1078 if pending:
1080 try:
1079 try:
1081 raise IndexError("msg pending: %r"%pending[0])
1080 raise IndexError("msg pending: %r"%pending[0])
1082 except:
1081 except:
1083 reply = error.wrap_exception()
1082 reply = error.wrap_exception()
1084 else:
1083 else:
1085 try:
1084 try:
1086 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1085 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1087 except Exception:
1086 except Exception:
1088 reply = error.wrap_exception()
1087 reply = error.wrap_exception()
1089
1088
1090 if reply['status'] == 'ok':
1089 if reply['status'] == 'ok':
1091 eids = content.get('engine_ids', [])
1090 eids = content.get('engine_ids', [])
1092 for eid in eids:
1091 for eid in eids:
1093 if eid not in self.engines:
1092 if eid not in self.engines:
1094 try:
1093 try:
1095 raise IndexError("No such engine: %i"%eid)
1094 raise IndexError("No such engine: %i"%eid)
1096 except:
1095 except:
1097 reply = error.wrap_exception()
1096 reply = error.wrap_exception()
1098 break
1097 break
1099 uid = self.engines[eid].queue
1098 uid = self.engines[eid].queue
1100 try:
1099 try:
1101 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1100 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1102 except Exception:
1101 except Exception:
1103 reply = error.wrap_exception()
1102 reply = error.wrap_exception()
1104 break
1103 break
1105
1104
1106 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1105 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1107
1106
1108 def resubmit_task(self, client_id, msg):
1107 def resubmit_task(self, client_id, msg):
1109 """Resubmit one or more tasks."""
1108 """Resubmit one or more tasks."""
1110 def finish(reply):
1109 def finish(reply):
1111 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1110 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1112
1111
1113 content = msg['content']
1112 content = msg['content']
1114 msg_ids = content['msg_ids']
1113 msg_ids = content['msg_ids']
1115 reply = dict(status='ok')
1114 reply = dict(status='ok')
1116 try:
1115 try:
1117 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1116 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1118 'header', 'content', 'buffers'])
1117 'header', 'content', 'buffers'])
1119 except Exception:
1118 except Exception:
1120 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1119 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1121 return finish(error.wrap_exception())
1120 return finish(error.wrap_exception())
1122
1121
1123 # validate msg_ids
1122 # validate msg_ids
1124 found_ids = [ rec['msg_id'] for rec in records ]
1123 found_ids = [ rec['msg_id'] for rec in records ]
1125 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1124 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1126 if len(records) > len(msg_ids):
1125 if len(records) > len(msg_ids):
1127 try:
1126 try:
1128 raise RuntimeError("DB appears to be in an inconsistent state."
1127 raise RuntimeError("DB appears to be in an inconsistent state."
1129 "More matching records were found than should exist")
1128 "More matching records were found than should exist")
1130 except Exception:
1129 except Exception:
1131 return finish(error.wrap_exception())
1130 return finish(error.wrap_exception())
1132 elif len(records) < len(msg_ids):
1131 elif len(records) < len(msg_ids):
1133 missing = [ m for m in msg_ids if m not in found_ids ]
1132 missing = [ m for m in msg_ids if m not in found_ids ]
1134 try:
1133 try:
1135 raise KeyError("No such msg(s): %r"%missing)
1134 raise KeyError("No such msg(s): %r"%missing)
1136 except KeyError:
1135 except KeyError:
1137 return finish(error.wrap_exception())
1136 return finish(error.wrap_exception())
1138 elif invalid_ids:
1137 elif invalid_ids:
1139 msg_id = invalid_ids[0]
1138 msg_id = invalid_ids[0]
1140 try:
1139 try:
1141 raise ValueError("Task %r appears to be inflight"%(msg_id))
1140 raise ValueError("Task %r appears to be inflight"%(msg_id))
1142 except Exception:
1141 except Exception:
1143 return finish(error.wrap_exception())
1142 return finish(error.wrap_exception())
1144
1143
1145 # clear the existing records
1144 # clear the existing records
1146 now = datetime.now()
1145 now = datetime.now()
1147 rec = empty_record()
1146 rec = empty_record()
1148 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1147 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1149 rec['resubmitted'] = now
1148 rec['resubmitted'] = now
1150 rec['queue'] = 'task'
1149 rec['queue'] = 'task'
1151 rec['client_uuid'] = client_id[0]
1150 rec['client_uuid'] = client_id[0]
1152 try:
1151 try:
1153 for msg_id in msg_ids:
1152 for msg_id in msg_ids:
1154 self.all_completed.discard(msg_id)
1153 self.all_completed.discard(msg_id)
1155 self.db.update_record(msg_id, rec)
1154 self.db.update_record(msg_id, rec)
1156 except Exception:
1155 except Exception:
1157 self.log.error('db::db error upating record', exc_info=True)
1156 self.log.error('db::db error upating record', exc_info=True)
1158 reply = error.wrap_exception()
1157 reply = error.wrap_exception()
1159 else:
1158 else:
1160 # send the messages
1159 # send the messages
1161 for rec in records:
1160 for rec in records:
1162 header = rec['header']
1161 header = rec['header']
1163 # include resubmitted in header to prevent digest collision
1162 # include resubmitted in header to prevent digest collision
1164 header['resubmitted'] = now
1163 header['resubmitted'] = now
1165 msg = self.session.msg(header['msg_type'])
1164 msg = self.session.msg(header['msg_type'])
1166 msg['content'] = rec['content']
1165 msg['content'] = rec['content']
1167 msg['header'] = header
1166 msg['header'] = header
1168 msg['msg_id'] = rec['msg_id']
1167 msg['msg_id'] = rec['msg_id']
1169 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1168 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170
1169
1171 finish(dict(status='ok'))
1170 finish(dict(status='ok'))
1172
1171
1173
1172
1174 def _extract_record(self, rec):
1173 def _extract_record(self, rec):
1175 """decompose a TaskRecord dict into subsection of reply for get_result"""
1174 """decompose a TaskRecord dict into subsection of reply for get_result"""
1176 io_dict = {}
1175 io_dict = {}
1177 for key in 'pyin pyout pyerr stdout stderr'.split():
1176 for key in 'pyin pyout pyerr stdout stderr'.split():
1178 io_dict[key] = rec[key]
1177 io_dict[key] = rec[key]
1179 content = { 'result_content': rec['result_content'],
1178 content = { 'result_content': rec['result_content'],
1180 'header': rec['header'],
1179 'header': rec['header'],
1181 'result_header' : rec['result_header'],
1180 'result_header' : rec['result_header'],
1182 'io' : io_dict,
1181 'io' : io_dict,
1183 }
1182 }
1184 if rec['result_buffers']:
1183 if rec['result_buffers']:
1185 buffers = map(bytes, rec['result_buffers'])
1184 buffers = map(bytes, rec['result_buffers'])
1186 else:
1185 else:
1187 buffers = []
1186 buffers = []
1188
1187
1189 return content, buffers
1188 return content, buffers
1190
1189
1191 def get_results(self, client_id, msg):
1190 def get_results(self, client_id, msg):
1192 """Get the result of 1 or more messages."""
1191 """Get the result of 1 or more messages."""
1193 content = msg['content']
1192 content = msg['content']
1194 msg_ids = sorted(set(content['msg_ids']))
1193 msg_ids = sorted(set(content['msg_ids']))
1195 statusonly = content.get('status_only', False)
1194 statusonly = content.get('status_only', False)
1196 pending = []
1195 pending = []
1197 completed = []
1196 completed = []
1198 content = dict(status='ok')
1197 content = dict(status='ok')
1199 content['pending'] = pending
1198 content['pending'] = pending
1200 content['completed'] = completed
1199 content['completed'] = completed
1201 buffers = []
1200 buffers = []
1202 if not statusonly:
1201 if not statusonly:
1203 try:
1202 try:
1204 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1203 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1205 # turn match list into dict, for faster lookup
1204 # turn match list into dict, for faster lookup
1206 records = {}
1205 records = {}
1207 for rec in matches:
1206 for rec in matches:
1208 records[rec['msg_id']] = rec
1207 records[rec['msg_id']] = rec
1209 except Exception:
1208 except Exception:
1210 content = error.wrap_exception()
1209 content = error.wrap_exception()
1211 self.session.send(self.query, "result_reply", content=content,
1210 self.session.send(self.query, "result_reply", content=content,
1212 parent=msg, ident=client_id)
1211 parent=msg, ident=client_id)
1213 return
1212 return
1214 else:
1213 else:
1215 records = {}
1214 records = {}
1216 for msg_id in msg_ids:
1215 for msg_id in msg_ids:
1217 if msg_id in self.pending:
1216 if msg_id in self.pending:
1218 pending.append(msg_id)
1217 pending.append(msg_id)
1219 elif msg_id in self.all_completed:
1218 elif msg_id in self.all_completed:
1220 completed.append(msg_id)
1219 completed.append(msg_id)
1221 if not statusonly:
1220 if not statusonly:
1222 c,bufs = self._extract_record(records[msg_id])
1221 c,bufs = self._extract_record(records[msg_id])
1223 content[msg_id] = c
1222 content[msg_id] = c
1224 buffers.extend(bufs)
1223 buffers.extend(bufs)
1225 elif msg_id in records:
1224 elif msg_id in records:
1226 if rec['completed']:
1225 if rec['completed']:
1227 completed.append(msg_id)
1226 completed.append(msg_id)
1228 c,bufs = self._extract_record(records[msg_id])
1227 c,bufs = self._extract_record(records[msg_id])
1229 content[msg_id] = c
1228 content[msg_id] = c
1230 buffers.extend(bufs)
1229 buffers.extend(bufs)
1231 else:
1230 else:
1232 pending.append(msg_id)
1231 pending.append(msg_id)
1233 else:
1232 else:
1234 try:
1233 try:
1235 raise KeyError('No such message: '+msg_id)
1234 raise KeyError('No such message: '+msg_id)
1236 except:
1235 except:
1237 content = error.wrap_exception()
1236 content = error.wrap_exception()
1238 break
1237 break
1239 self.session.send(self.query, "result_reply", content=content,
1238 self.session.send(self.query, "result_reply", content=content,
1240 parent=msg, ident=client_id,
1239 parent=msg, ident=client_id,
1241 buffers=buffers)
1240 buffers=buffers)
1242
1241
1243 def get_history(self, client_id, msg):
1242 def get_history(self, client_id, msg):
1244 """Get a list of all msg_ids in our DB records"""
1243 """Get a list of all msg_ids in our DB records"""
1245 try:
1244 try:
1246 msg_ids = self.db.get_history()
1245 msg_ids = self.db.get_history()
1247 except Exception as e:
1246 except Exception as e:
1248 content = error.wrap_exception()
1247 content = error.wrap_exception()
1249 else:
1248 else:
1250 content = dict(status='ok', history=msg_ids)
1249 content = dict(status='ok', history=msg_ids)
1251
1250
1252 self.session.send(self.query, "history_reply", content=content,
1251 self.session.send(self.query, "history_reply", content=content,
1253 parent=msg, ident=client_id)
1252 parent=msg, ident=client_id)
1254
1253
1255 def db_query(self, client_id, msg):
1254 def db_query(self, client_id, msg):
1256 """Perform a raw query on the task record database."""
1255 """Perform a raw query on the task record database."""
1257 content = msg['content']
1256 content = msg['content']
1258 query = content.get('query', {})
1257 query = content.get('query', {})
1259 keys = content.get('keys', None)
1258 keys = content.get('keys', None)
1260 buffers = []
1259 buffers = []
1261 empty = list()
1260 empty = list()
1262 try:
1261 try:
1263 records = self.db.find_records(query, keys)
1262 records = self.db.find_records(query, keys)
1264 except Exception as e:
1263 except Exception as e:
1265 content = error.wrap_exception()
1264 content = error.wrap_exception()
1266 else:
1265 else:
1267 # extract buffers from reply content:
1266 # extract buffers from reply content:
1268 if keys is not None:
1267 if keys is not None:
1269 buffer_lens = [] if 'buffers' in keys else None
1268 buffer_lens = [] if 'buffers' in keys else None
1270 result_buffer_lens = [] if 'result_buffers' in keys else None
1269 result_buffer_lens = [] if 'result_buffers' in keys else None
1271 else:
1270 else:
1272 buffer_lens = []
1271 buffer_lens = []
1273 result_buffer_lens = []
1272 result_buffer_lens = []
1274
1273
1275 for rec in records:
1274 for rec in records:
1276 # buffers may be None, so double check
1275 # buffers may be None, so double check
1277 if buffer_lens is not None:
1276 if buffer_lens is not None:
1278 b = rec.pop('buffers', empty) or empty
1277 b = rec.pop('buffers', empty) or empty
1279 buffer_lens.append(len(b))
1278 buffer_lens.append(len(b))
1280 buffers.extend(b)
1279 buffers.extend(b)
1281 if result_buffer_lens is not None:
1280 if result_buffer_lens is not None:
1282 rb = rec.pop('result_buffers', empty) or empty
1281 rb = rec.pop('result_buffers', empty) or empty
1283 result_buffer_lens.append(len(rb))
1282 result_buffer_lens.append(len(rb))
1284 buffers.extend(rb)
1283 buffers.extend(rb)
1285 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1284 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1286 result_buffer_lens=result_buffer_lens)
1285 result_buffer_lens=result_buffer_lens)
1287 # self.log.debug (content)
1286 # self.log.debug (content)
1288 self.session.send(self.query, "db_reply", content=content,
1287 self.session.send(self.query, "db_reply", content=content,
1289 parent=msg, ident=client_id,
1288 parent=msg, ident=client_id,
1290 buffers=buffers)
1289 buffers=buffers)
1291
1290
@@ -1,174 +1,173 b''
1 #!/usr/bin/env python
2 """A simple engine that talks to a controller over 0MQ.
1 """A simple engine that talks to a controller over 0MQ.
3 it handles registration, etc. and launches a kernel
2 it handles registration, etc. and launches a kernel
4 connected to the Controller's Schedulers.
3 connected to the Controller's Schedulers.
5
4
6 Authors:
5 Authors:
7
6
8 * Min RK
7 * Min RK
9 """
8 """
10 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
12 #
11 #
13 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16
15
17 from __future__ import print_function
16 from __future__ import print_function
18
17
19 import sys
18 import sys
20 import time
19 import time
21
20
22 import zmq
21 import zmq
23 from zmq.eventloop import ioloop, zmqstream
22 from zmq.eventloop import ioloop, zmqstream
24
23
25 # internal
24 # internal
26 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
25 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 # from IPython.utils.localinterfaces import LOCALHOST
26 # from IPython.utils.localinterfaces import LOCALHOST
28
27
29 from IPython.parallel.controller.heartmonitor import Heart
28 from IPython.parallel.controller.heartmonitor import Heart
30 from IPython.parallel.factory import RegistrationFactory
29 from IPython.parallel.factory import RegistrationFactory
31 from IPython.parallel.util import disambiguate_url, asbytes
30 from IPython.parallel.util import disambiguate_url, asbytes
32
31
33 from IPython.zmq.session import Message
32 from IPython.zmq.session import Message
34
33
35 from .streamkernel import Kernel
34 from .streamkernel import Kernel
36
35
37 class EngineFactory(RegistrationFactory):
36 class EngineFactory(RegistrationFactory):
38 """IPython engine"""
37 """IPython engine"""
39
38
40 # configurables:
39 # configurables:
41 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
40 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
42 help="""The OutStream for handling stdout/err.
41 help="""The OutStream for handling stdout/err.
43 Typically 'IPython.zmq.iostream.OutStream'""")
42 Typically 'IPython.zmq.iostream.OutStream'""")
44 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
43 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
45 help="""The class for handling displayhook.
44 help="""The class for handling displayhook.
46 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
45 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
47 location=Unicode(config=True,
46 location=Unicode(config=True,
48 help="""The location (an IP address) of the controller. This is
47 help="""The location (an IP address) of the controller. This is
49 used for disambiguating URLs, to determine whether
48 used for disambiguating URLs, to determine whether
50 loopback should be used to connect or the public address.""")
49 loopback should be used to connect or the public address.""")
51 timeout=CFloat(2,config=True,
50 timeout=CFloat(2,config=True,
52 help="""The time (in seconds) to wait for the Controller to respond
51 help="""The time (in seconds) to wait for the Controller to respond
53 to registration requests before giving up.""")
52 to registration requests before giving up.""")
54
53
55 # not configurable:
54 # not configurable:
56 user_ns=Dict()
55 user_ns=Dict()
57 id=Int(allow_none=True)
56 id=Int(allow_none=True)
58 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
57 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 kernel=Instance(Kernel)
58 kernel=Instance(Kernel)
60
59
61 bident = CBytes()
60 bident = CBytes()
62 ident = Unicode()
61 ident = Unicode()
63 def _ident_changed(self, name, old, new):
62 def _ident_changed(self, name, old, new):
64 self.bident = asbytes(new)
63 self.bident = asbytes(new)
65
64
66
65
67 def __init__(self, **kwargs):
66 def __init__(self, **kwargs):
68 super(EngineFactory, self).__init__(**kwargs)
67 super(EngineFactory, self).__init__(**kwargs)
69 self.ident = self.session.session
68 self.ident = self.session.session
70 ctx = self.context
69 ctx = self.context
71
70
72 reg = ctx.socket(zmq.XREQ)
71 reg = ctx.socket(zmq.XREQ)
73 reg.setsockopt(zmq.IDENTITY, self.bident)
72 reg.setsockopt(zmq.IDENTITY, self.bident)
74 reg.connect(self.url)
73 reg.connect(self.url)
75 self.registrar = zmqstream.ZMQStream(reg, self.loop)
74 self.registrar = zmqstream.ZMQStream(reg, self.loop)
76
75
77 def register(self):
76 def register(self):
78 """send the registration_request"""
77 """send the registration_request"""
79
78
80 self.log.info("Registering with controller at %s"%self.url)
79 self.log.info("Registering with controller at %s"%self.url)
81 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
80 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
82 self.registrar.on_recv(self.complete_registration)
81 self.registrar.on_recv(self.complete_registration)
83 # print (self.session.key)
82 # print (self.session.key)
84 self.session.send(self.registrar, "registration_request",content=content)
83 self.session.send(self.registrar, "registration_request",content=content)
85
84
86 def complete_registration(self, msg):
85 def complete_registration(self, msg):
87 # print msg
86 # print msg
88 self._abort_dc.stop()
87 self._abort_dc.stop()
89 ctx = self.context
88 ctx = self.context
90 loop = self.loop
89 loop = self.loop
91 identity = self.bident
90 identity = self.bident
92 idents,msg = self.session.feed_identities(msg)
91 idents,msg = self.session.feed_identities(msg)
93 msg = Message(self.session.unpack_message(msg))
92 msg = Message(self.session.unpack_message(msg))
94
93
95 if msg.content.status == 'ok':
94 if msg.content.status == 'ok':
96 self.id = int(msg.content.id)
95 self.id = int(msg.content.id)
97
96
98 # create Shell Streams (MUX, Task, etc.):
97 # create Shell Streams (MUX, Task, etc.):
99 queue_addr = msg.content.mux
98 queue_addr = msg.content.mux
100 shell_addrs = [ str(queue_addr) ]
99 shell_addrs = [ str(queue_addr) ]
101 task_addr = msg.content.task
100 task_addr = msg.content.task
102 if task_addr:
101 if task_addr:
103 shell_addrs.append(str(task_addr))
102 shell_addrs.append(str(task_addr))
104
103
105 # Uncomment this to go back to two-socket model
104 # Uncomment this to go back to two-socket model
106 # shell_streams = []
105 # shell_streams = []
107 # for addr in shell_addrs:
106 # for addr in shell_addrs:
108 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
107 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
109 # stream.setsockopt(zmq.IDENTITY, identity)
108 # stream.setsockopt(zmq.IDENTITY, identity)
110 # stream.connect(disambiguate_url(addr, self.location))
109 # stream.connect(disambiguate_url(addr, self.location))
111 # shell_streams.append(stream)
110 # shell_streams.append(stream)
112
111
113 # Now use only one shell stream for mux and tasks
112 # Now use only one shell stream for mux and tasks
114 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
113 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 stream.setsockopt(zmq.IDENTITY, identity)
114 stream.setsockopt(zmq.IDENTITY, identity)
116 shell_streams = [stream]
115 shell_streams = [stream]
117 for addr in shell_addrs:
116 for addr in shell_addrs:
118 stream.connect(disambiguate_url(addr, self.location))
117 stream.connect(disambiguate_url(addr, self.location))
119 # end single stream-socket
118 # end single stream-socket
120
119
121 # control stream:
120 # control stream:
122 control_addr = str(msg.content.control)
121 control_addr = str(msg.content.control)
123 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
122 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
124 control_stream.setsockopt(zmq.IDENTITY, identity)
123 control_stream.setsockopt(zmq.IDENTITY, identity)
125 control_stream.connect(disambiguate_url(control_addr, self.location))
124 control_stream.connect(disambiguate_url(control_addr, self.location))
126
125
127 # create iopub stream:
126 # create iopub stream:
128 iopub_addr = msg.content.iopub
127 iopub_addr = msg.content.iopub
129 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
128 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
130 iopub_stream.setsockopt(zmq.IDENTITY, identity)
129 iopub_stream.setsockopt(zmq.IDENTITY, identity)
131 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
130 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
132
131
133 # launch heartbeat
132 # launch heartbeat
134 hb_addrs = msg.content.heartbeat
133 hb_addrs = msg.content.heartbeat
135 # print (hb_addrs)
134 # print (hb_addrs)
136
135
137 # # Redirect input streams and set a display hook.
136 # # Redirect input streams and set a display hook.
138 if self.out_stream_factory:
137 if self.out_stream_factory:
139 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
138 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
140 sys.stdout.topic = 'engine.%i.stdout'%self.id
139 sys.stdout.topic = 'engine.%i.stdout'%self.id
141 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
140 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
142 sys.stderr.topic = 'engine.%i.stderr'%self.id
141 sys.stderr.topic = 'engine.%i.stderr'%self.id
143 if self.display_hook_factory:
142 if self.display_hook_factory:
144 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
143 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
145 sys.displayhook.topic = 'engine.%i.pyout'%self.id
144 sys.displayhook.topic = 'engine.%i.pyout'%self.id
146
145
147 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
146 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
148 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
147 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
149 loop=loop, user_ns = self.user_ns, log=self.log)
148 loop=loop, user_ns = self.user_ns, log=self.log)
150 self.kernel.start()
149 self.kernel.start()
151 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
150 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
152 heart = Heart(*map(str, hb_addrs), heart_id=identity)
151 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 heart.start()
152 heart.start()
154
153
155
154
156 else:
155 else:
157 self.log.fatal("Registration Failed: %s"%msg)
156 self.log.fatal("Registration Failed: %s"%msg)
158 raise Exception("Registration Failed: %s"%msg)
157 raise Exception("Registration Failed: %s"%msg)
159
158
160 self.log.info("Completed registration with id %i"%self.id)
159 self.log.info("Completed registration with id %i"%self.id)
161
160
162
161
163 def abort(self):
162 def abort(self):
164 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
163 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
165 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
164 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
166 time.sleep(1)
165 time.sleep(1)
167 sys.exit(255)
166 sys.exit(255)
168
167
169 def start(self):
168 def start(self):
170 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
169 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
171 dc.start()
170 dc.start()
172 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
171 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
173 self._abort_dc.start()
172 self._abort_dc.start()
174
173
@@ -1,438 +1,437 b''
1 #!/usr/bin/env python
2 """
1 """
3 Kernel adapted from kernel.py to use ZMQ Streams
2 Kernel adapted from kernel.py to use ZMQ Streams
4
3
5 Authors:
4 Authors:
6
5
7 * Min RK
6 * Min RK
8 * Brian Granger
7 * Brian Granger
9 * Fernando Perez
8 * Fernando Perez
10 * Evan Patterson
9 * Evan Patterson
11 """
10 """
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
14 #
13 #
15 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20 # Imports
19 # Imports
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22
21
23 # Standard library imports.
22 # Standard library imports.
24 from __future__ import print_function
23 from __future__ import print_function
25
24
26 import sys
25 import sys
27 import time
26 import time
28
27
29 from code import CommandCompiler
28 from code import CommandCompiler
30 from datetime import datetime
29 from datetime import datetime
31 from pprint import pprint
30 from pprint import pprint
32
31
33 # System library imports.
32 # System library imports.
34 import zmq
33 import zmq
35 from zmq.eventloop import ioloop, zmqstream
34 from zmq.eventloop import ioloop, zmqstream
36
35
37 # Local imports.
36 # Local imports.
38 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
37 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 from IPython.zmq.completer import KernelCompleter
38 from IPython.zmq.completer import KernelCompleter
40
39
41 from IPython.parallel.error import wrap_exception
40 from IPython.parallel.error import wrap_exception
42 from IPython.parallel.factory import SessionFactory
41 from IPython.parallel.factory import SessionFactory
43 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
42 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
44
43
45 def printer(*args):
44 def printer(*args):
46 pprint(args, stream=sys.__stdout__)
45 pprint(args, stream=sys.__stdout__)
47
46
48
47
49 class _Passer(zmqstream.ZMQStream):
48 class _Passer(zmqstream.ZMQStream):
50 """Empty class that implements `send()` that does nothing.
49 """Empty class that implements `send()` that does nothing.
51
50
52 Subclass ZMQStream for Session typechecking
51 Subclass ZMQStream for Session typechecking
53
52
54 """
53 """
55 def __init__(self, *args, **kwargs):
54 def __init__(self, *args, **kwargs):
56 pass
55 pass
57
56
58 def send(self, *args, **kwargs):
57 def send(self, *args, **kwargs):
59 pass
58 pass
60 send_multipart = send
59 send_multipart = send
61
60
62
61
63 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
64 # Main kernel class
63 # Main kernel class
65 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
66
65
67 class Kernel(SessionFactory):
66 class Kernel(SessionFactory):
68
67
69 #---------------------------------------------------------------------------
68 #---------------------------------------------------------------------------
70 # Kernel interface
69 # Kernel interface
71 #---------------------------------------------------------------------------
70 #---------------------------------------------------------------------------
72
71
73 # kwargs:
72 # kwargs:
74 exec_lines = List(Unicode, config=True,
73 exec_lines = List(Unicode, config=True,
75 help="List of lines to execute")
74 help="List of lines to execute")
76
75
77 # identities:
76 # identities:
78 int_id = Int(-1)
77 int_id = Int(-1)
79 bident = CBytes()
78 bident = CBytes()
80 ident = Unicode()
79 ident = Unicode()
81 def _ident_changed(self, name, old, new):
80 def _ident_changed(self, name, old, new):
82 self.bident = asbytes(new)
81 self.bident = asbytes(new)
83
82
84 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
83 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
85
84
86 control_stream = Instance(zmqstream.ZMQStream)
85 control_stream = Instance(zmqstream.ZMQStream)
87 task_stream = Instance(zmqstream.ZMQStream)
86 task_stream = Instance(zmqstream.ZMQStream)
88 iopub_stream = Instance(zmqstream.ZMQStream)
87 iopub_stream = Instance(zmqstream.ZMQStream)
89 client = Instance('IPython.parallel.Client')
88 client = Instance('IPython.parallel.Client')
90
89
91 # internals
90 # internals
92 shell_streams = List()
91 shell_streams = List()
93 compiler = Instance(CommandCompiler, (), {})
92 compiler = Instance(CommandCompiler, (), {})
94 completer = Instance(KernelCompleter)
93 completer = Instance(KernelCompleter)
95
94
96 aborted = Set()
95 aborted = Set()
97 shell_handlers = Dict()
96 shell_handlers = Dict()
98 control_handlers = Dict()
97 control_handlers = Dict()
99
98
100 def _set_prefix(self):
99 def _set_prefix(self):
101 self.prefix = "engine.%s"%self.int_id
100 self.prefix = "engine.%s"%self.int_id
102
101
103 def _connect_completer(self):
102 def _connect_completer(self):
104 self.completer = KernelCompleter(self.user_ns)
103 self.completer = KernelCompleter(self.user_ns)
105
104
106 def __init__(self, **kwargs):
105 def __init__(self, **kwargs):
107 super(Kernel, self).__init__(**kwargs)
106 super(Kernel, self).__init__(**kwargs)
108 self._set_prefix()
107 self._set_prefix()
109 self._connect_completer()
108 self._connect_completer()
110
109
111 self.on_trait_change(self._set_prefix, 'id')
110 self.on_trait_change(self._set_prefix, 'id')
112 self.on_trait_change(self._connect_completer, 'user_ns')
111 self.on_trait_change(self._connect_completer, 'user_ns')
113
112
114 # Build dict of handlers for message types
113 # Build dict of handlers for message types
115 for msg_type in ['execute_request', 'complete_request', 'apply_request',
114 for msg_type in ['execute_request', 'complete_request', 'apply_request',
116 'clear_request']:
115 'clear_request']:
117 self.shell_handlers[msg_type] = getattr(self, msg_type)
116 self.shell_handlers[msg_type] = getattr(self, msg_type)
118
117
119 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
118 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
120 self.control_handlers[msg_type] = getattr(self, msg_type)
119 self.control_handlers[msg_type] = getattr(self, msg_type)
121
120
122 self._initial_exec_lines()
121 self._initial_exec_lines()
123
122
124 def _wrap_exception(self, method=None):
123 def _wrap_exception(self, method=None):
125 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
124 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
126 content=wrap_exception(e_info)
125 content=wrap_exception(e_info)
127 return content
126 return content
128
127
129 def _initial_exec_lines(self):
128 def _initial_exec_lines(self):
130 s = _Passer()
129 s = _Passer()
131 content = dict(silent=True, user_variable=[],user_expressions=[])
130 content = dict(silent=True, user_variable=[],user_expressions=[])
132 for line in self.exec_lines:
131 for line in self.exec_lines:
133 self.log.debug("executing initialization: %s"%line)
132 self.log.debug("executing initialization: %s"%line)
134 content.update({'code':line})
133 content.update({'code':line})
135 msg = self.session.msg('execute_request', content)
134 msg = self.session.msg('execute_request', content)
136 self.execute_request(s, [], msg)
135 self.execute_request(s, [], msg)
137
136
138
137
139 #-------------------- control handlers -----------------------------
138 #-------------------- control handlers -----------------------------
140 def abort_queues(self):
139 def abort_queues(self):
141 for stream in self.shell_streams:
140 for stream in self.shell_streams:
142 if stream:
141 if stream:
143 self.abort_queue(stream)
142 self.abort_queue(stream)
144
143
145 def abort_queue(self, stream):
144 def abort_queue(self, stream):
146 while True:
145 while True:
147 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
146 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
148 if msg is None:
147 if msg is None:
149 return
148 return
150
149
151 self.log.info("Aborting:")
150 self.log.info("Aborting:")
152 self.log.info(str(msg))
151 self.log.info(str(msg))
153 msg_type = msg['msg_type']
152 msg_type = msg['msg_type']
154 reply_type = msg_type.split('_')[0] + '_reply'
153 reply_type = msg_type.split('_')[0] + '_reply'
155 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
154 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 # self.reply_socket.send(ident,zmq.SNDMORE)
155 # self.reply_socket.send(ident,zmq.SNDMORE)
157 # self.reply_socket.send_json(reply_msg)
156 # self.reply_socket.send_json(reply_msg)
158 reply_msg = self.session.send(stream, reply_type,
157 reply_msg = self.session.send(stream, reply_type,
159 content={'status' : 'aborted'}, parent=msg, ident=idents)
158 content={'status' : 'aborted'}, parent=msg, ident=idents)
160 self.log.debug(str(reply_msg))
159 self.log.debug(str(reply_msg))
161 # We need to wait a bit for requests to come in. This can probably
160 # We need to wait a bit for requests to come in. This can probably
162 # be set shorter for true asynchronous clients.
161 # be set shorter for true asynchronous clients.
163 time.sleep(0.05)
162 time.sleep(0.05)
164
163
165 def abort_request(self, stream, ident, parent):
164 def abort_request(self, stream, ident, parent):
166 """abort a specifig msg by id"""
165 """abort a specifig msg by id"""
167 msg_ids = parent['content'].get('msg_ids', None)
166 msg_ids = parent['content'].get('msg_ids', None)
168 if isinstance(msg_ids, basestring):
167 if isinstance(msg_ids, basestring):
169 msg_ids = [msg_ids]
168 msg_ids = [msg_ids]
170 if not msg_ids:
169 if not msg_ids:
171 self.abort_queues()
170 self.abort_queues()
172 for mid in msg_ids:
171 for mid in msg_ids:
173 self.aborted.add(str(mid))
172 self.aborted.add(str(mid))
174
173
175 content = dict(status='ok')
174 content = dict(status='ok')
176 reply_msg = self.session.send(stream, 'abort_reply', content=content,
175 reply_msg = self.session.send(stream, 'abort_reply', content=content,
177 parent=parent, ident=ident)
176 parent=parent, ident=ident)
178 self.log.debug(str(reply_msg))
177 self.log.debug(str(reply_msg))
179
178
180 def shutdown_request(self, stream, ident, parent):
179 def shutdown_request(self, stream, ident, parent):
181 """kill ourself. This should really be handled in an external process"""
180 """kill ourself. This should really be handled in an external process"""
182 try:
181 try:
183 self.abort_queues()
182 self.abort_queues()
184 except:
183 except:
185 content = self._wrap_exception('shutdown')
184 content = self._wrap_exception('shutdown')
186 else:
185 else:
187 content = dict(parent['content'])
186 content = dict(parent['content'])
188 content['status'] = 'ok'
187 content['status'] = 'ok'
189 msg = self.session.send(stream, 'shutdown_reply',
188 msg = self.session.send(stream, 'shutdown_reply',
190 content=content, parent=parent, ident=ident)
189 content=content, parent=parent, ident=ident)
191 self.log.debug(str(msg))
190 self.log.debug(str(msg))
192 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
191 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
193 dc.start()
192 dc.start()
194
193
195 def dispatch_control(self, msg):
194 def dispatch_control(self, msg):
196 idents,msg = self.session.feed_identities(msg, copy=False)
195 idents,msg = self.session.feed_identities(msg, copy=False)
197 try:
196 try:
198 msg = self.session.unpack_message(msg, content=True, copy=False)
197 msg = self.session.unpack_message(msg, content=True, copy=False)
199 except:
198 except:
200 self.log.error("Invalid Message", exc_info=True)
199 self.log.error("Invalid Message", exc_info=True)
201 return
200 return
202 else:
201 else:
203 self.log.debug("Control received, %s", msg)
202 self.log.debug("Control received, %s", msg)
204
203
205 header = msg['header']
204 header = msg['header']
206 msg_id = header['msg_id']
205 msg_id = header['msg_id']
207
206
208 handler = self.control_handlers.get(msg['msg_type'], None)
207 handler = self.control_handlers.get(msg['msg_type'], None)
209 if handler is None:
208 if handler is None:
210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
209 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
211 else:
210 else:
212 handler(self.control_stream, idents, msg)
211 handler(self.control_stream, idents, msg)
213
212
214
213
215 #-------------------- queue helpers ------------------------------
214 #-------------------- queue helpers ------------------------------
216
215
217 def check_dependencies(self, dependencies):
216 def check_dependencies(self, dependencies):
218 if not dependencies:
217 if not dependencies:
219 return True
218 return True
220 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
219 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
221 anyorall = dependencies[0]
220 anyorall = dependencies[0]
222 dependencies = dependencies[1]
221 dependencies = dependencies[1]
223 else:
222 else:
224 anyorall = 'all'
223 anyorall = 'all'
225 results = self.client.get_results(dependencies,status_only=True)
224 results = self.client.get_results(dependencies,status_only=True)
226 if results['status'] != 'ok':
225 if results['status'] != 'ok':
227 return False
226 return False
228
227
229 if anyorall == 'any':
228 if anyorall == 'any':
230 if not results['completed']:
229 if not results['completed']:
231 return False
230 return False
232 else:
231 else:
233 if results['pending']:
232 if results['pending']:
234 return False
233 return False
235
234
236 return True
235 return True
237
236
238 def check_aborted(self, msg_id):
237 def check_aborted(self, msg_id):
239 return msg_id in self.aborted
238 return msg_id in self.aborted
240
239
241 #-------------------- queue handlers -----------------------------
240 #-------------------- queue handlers -----------------------------
242
241
243 def clear_request(self, stream, idents, parent):
242 def clear_request(self, stream, idents, parent):
244 """Clear our namespace."""
243 """Clear our namespace."""
245 self.user_ns = {}
244 self.user_ns = {}
246 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
245 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
247 content = dict(status='ok'))
246 content = dict(status='ok'))
248 self._initial_exec_lines()
247 self._initial_exec_lines()
249
248
250 def execute_request(self, stream, ident, parent):
249 def execute_request(self, stream, ident, parent):
251 self.log.debug('execute request %s'%parent)
250 self.log.debug('execute request %s'%parent)
252 try:
251 try:
253 code = parent[u'content'][u'code']
252 code = parent[u'content'][u'code']
254 except:
253 except:
255 self.log.error("Got bad msg: %s"%parent, exc_info=True)
254 self.log.error("Got bad msg: %s"%parent, exc_info=True)
256 return
255 return
257 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
256 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
258 ident=asbytes('%s.pyin'%self.prefix))
257 ident=asbytes('%s.pyin'%self.prefix))
259 started = datetime.now()
258 started = datetime.now()
260 try:
259 try:
261 comp_code = self.compiler(code, '<zmq-kernel>')
260 comp_code = self.compiler(code, '<zmq-kernel>')
262 # allow for not overriding displayhook
261 # allow for not overriding displayhook
263 if hasattr(sys.displayhook, 'set_parent'):
262 if hasattr(sys.displayhook, 'set_parent'):
264 sys.displayhook.set_parent(parent)
263 sys.displayhook.set_parent(parent)
265 sys.stdout.set_parent(parent)
264 sys.stdout.set_parent(parent)
266 sys.stderr.set_parent(parent)
265 sys.stderr.set_parent(parent)
267 exec comp_code in self.user_ns, self.user_ns
266 exec comp_code in self.user_ns, self.user_ns
268 except:
267 except:
269 exc_content = self._wrap_exception('execute')
268 exc_content = self._wrap_exception('execute')
270 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
269 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
271 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
270 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
272 ident=asbytes('%s.pyerr'%self.prefix))
271 ident=asbytes('%s.pyerr'%self.prefix))
273 reply_content = exc_content
272 reply_content = exc_content
274 else:
273 else:
275 reply_content = {'status' : 'ok'}
274 reply_content = {'status' : 'ok'}
276
275
277 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
276 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
278 ident=ident, subheader = dict(started=started))
277 ident=ident, subheader = dict(started=started))
279 self.log.debug(str(reply_msg))
278 self.log.debug(str(reply_msg))
280 if reply_msg['content']['status'] == u'error':
279 if reply_msg['content']['status'] == u'error':
281 self.abort_queues()
280 self.abort_queues()
282
281
283 def complete_request(self, stream, ident, parent):
282 def complete_request(self, stream, ident, parent):
284 matches = {'matches' : self.complete(parent),
283 matches = {'matches' : self.complete(parent),
285 'status' : 'ok'}
284 'status' : 'ok'}
286 completion_msg = self.session.send(stream, 'complete_reply',
285 completion_msg = self.session.send(stream, 'complete_reply',
287 matches, parent, ident)
286 matches, parent, ident)
288 # print >> sys.__stdout__, completion_msg
287 # print >> sys.__stdout__, completion_msg
289
288
290 def complete(self, msg):
289 def complete(self, msg):
291 return self.completer.complete(msg.content.line, msg.content.text)
290 return self.completer.complete(msg.content.line, msg.content.text)
292
291
293 def apply_request(self, stream, ident, parent):
292 def apply_request(self, stream, ident, parent):
294 # flush previous reply, so this request won't block it
293 # flush previous reply, so this request won't block it
295 stream.flush(zmq.POLLOUT)
294 stream.flush(zmq.POLLOUT)
296 try:
295 try:
297 content = parent[u'content']
296 content = parent[u'content']
298 bufs = parent[u'buffers']
297 bufs = parent[u'buffers']
299 msg_id = parent['header']['msg_id']
298 msg_id = parent['header']['msg_id']
300 # bound = parent['header'].get('bound', False)
299 # bound = parent['header'].get('bound', False)
301 except:
300 except:
302 self.log.error("Got bad msg: %s"%parent, exc_info=True)
301 self.log.error("Got bad msg: %s"%parent, exc_info=True)
303 return
302 return
304 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
303 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
305 # self.iopub_stream.send(pyin_msg)
304 # self.iopub_stream.send(pyin_msg)
306 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
305 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
307 sub = {'dependencies_met' : True, 'engine' : self.ident,
306 sub = {'dependencies_met' : True, 'engine' : self.ident,
308 'started': datetime.now()}
307 'started': datetime.now()}
309 try:
308 try:
310 # allow for not overriding displayhook
309 # allow for not overriding displayhook
311 if hasattr(sys.displayhook, 'set_parent'):
310 if hasattr(sys.displayhook, 'set_parent'):
312 sys.displayhook.set_parent(parent)
311 sys.displayhook.set_parent(parent)
313 sys.stdout.set_parent(parent)
312 sys.stdout.set_parent(parent)
314 sys.stderr.set_parent(parent)
313 sys.stderr.set_parent(parent)
315 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
314 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
316 working = self.user_ns
315 working = self.user_ns
317 # suffix =
316 # suffix =
318 prefix = "_"+str(msg_id).replace("-","")+"_"
317 prefix = "_"+str(msg_id).replace("-","")+"_"
319
318
320 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
319 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
321 # if bound:
320 # if bound:
322 # bound_ns = Namespace(working)
321 # bound_ns = Namespace(working)
323 # args = [bound_ns]+list(args)
322 # args = [bound_ns]+list(args)
324
323
325 fname = getattr(f, '__name__', 'f')
324 fname = getattr(f, '__name__', 'f')
326
325
327 fname = prefix+"f"
326 fname = prefix+"f"
328 argname = prefix+"args"
327 argname = prefix+"args"
329 kwargname = prefix+"kwargs"
328 kwargname = prefix+"kwargs"
330 resultname = prefix+"result"
329 resultname = prefix+"result"
331
330
332 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
331 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
333 # print ns
332 # print ns
334 working.update(ns)
333 working.update(ns)
335 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
334 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
336 try:
335 try:
337 exec code in working,working
336 exec code in working,working
338 result = working.get(resultname)
337 result = working.get(resultname)
339 finally:
338 finally:
340 for key in ns.iterkeys():
339 for key in ns.iterkeys():
341 working.pop(key)
340 working.pop(key)
342 # if bound:
341 # if bound:
343 # working.update(bound_ns)
342 # working.update(bound_ns)
344
343
345 packed_result,buf = serialize_object(result)
344 packed_result,buf = serialize_object(result)
346 result_buf = [packed_result]+buf
345 result_buf = [packed_result]+buf
347 except:
346 except:
348 exc_content = self._wrap_exception('apply')
347 exc_content = self._wrap_exception('apply')
349 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
348 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
350 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
349 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
351 ident=asbytes('%s.pyerr'%self.prefix))
350 ident=asbytes('%s.pyerr'%self.prefix))
352 reply_content = exc_content
351 reply_content = exc_content
353 result_buf = []
352 result_buf = []
354
353
355 if exc_content['ename'] == 'UnmetDependency':
354 if exc_content['ename'] == 'UnmetDependency':
356 sub['dependencies_met'] = False
355 sub['dependencies_met'] = False
357 else:
356 else:
358 reply_content = {'status' : 'ok'}
357 reply_content = {'status' : 'ok'}
359
358
360 # put 'ok'/'error' status in header, for scheduler introspection:
359 # put 'ok'/'error' status in header, for scheduler introspection:
361 sub['status'] = reply_content['status']
360 sub['status'] = reply_content['status']
362
361
363 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
362 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
364 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
363 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
365
364
366 # flush i/o
365 # flush i/o
367 # should this be before reply_msg is sent, like in the single-kernel code,
366 # should this be before reply_msg is sent, like in the single-kernel code,
368 # or should nothing get in the way of real results?
367 # or should nothing get in the way of real results?
369 sys.stdout.flush()
368 sys.stdout.flush()
370 sys.stderr.flush()
369 sys.stderr.flush()
371
370
372 def dispatch_queue(self, stream, msg):
371 def dispatch_queue(self, stream, msg):
373 self.control_stream.flush()
372 self.control_stream.flush()
374 idents,msg = self.session.feed_identities(msg, copy=False)
373 idents,msg = self.session.feed_identities(msg, copy=False)
375 try:
374 try:
376 msg = self.session.unpack_message(msg, content=True, copy=False)
375 msg = self.session.unpack_message(msg, content=True, copy=False)
377 except:
376 except:
378 self.log.error("Invalid Message", exc_info=True)
377 self.log.error("Invalid Message", exc_info=True)
379 return
378 return
380 else:
379 else:
381 self.log.debug("Message received, %s", msg)
380 self.log.debug("Message received, %s", msg)
382
381
383
382
384 header = msg['header']
383 header = msg['header']
385 msg_id = header['msg_id']
384 msg_id = header['msg_id']
386 if self.check_aborted(msg_id):
385 if self.check_aborted(msg_id):
387 self.aborted.remove(msg_id)
386 self.aborted.remove(msg_id)
388 # is it safe to assume a msg_id will not be resubmitted?
387 # is it safe to assume a msg_id will not be resubmitted?
389 reply_type = msg['msg_type'].split('_')[0] + '_reply'
388 reply_type = msg['msg_type'].split('_')[0] + '_reply'
390 status = {'status' : 'aborted'}
389 status = {'status' : 'aborted'}
391 reply_msg = self.session.send(stream, reply_type, subheader=status,
390 reply_msg = self.session.send(stream, reply_type, subheader=status,
392 content=status, parent=msg, ident=idents)
391 content=status, parent=msg, ident=idents)
393 return
392 return
394 handler = self.shell_handlers.get(msg['msg_type'], None)
393 handler = self.shell_handlers.get(msg['msg_type'], None)
395 if handler is None:
394 if handler is None:
396 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
395 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
397 else:
396 else:
398 handler(stream, idents, msg)
397 handler(stream, idents, msg)
399
398
400 def start(self):
399 def start(self):
401 #### stream mode:
400 #### stream mode:
402 if self.control_stream:
401 if self.control_stream:
403 self.control_stream.on_recv(self.dispatch_control, copy=False)
402 self.control_stream.on_recv(self.dispatch_control, copy=False)
404 self.control_stream.on_err(printer)
403 self.control_stream.on_err(printer)
405
404
406 def make_dispatcher(stream):
405 def make_dispatcher(stream):
407 def dispatcher(msg):
406 def dispatcher(msg):
408 return self.dispatch_queue(stream, msg)
407 return self.dispatch_queue(stream, msg)
409 return dispatcher
408 return dispatcher
410
409
411 for s in self.shell_streams:
410 for s in self.shell_streams:
412 s.on_recv(make_dispatcher(s), copy=False)
411 s.on_recv(make_dispatcher(s), copy=False)
413 s.on_err(printer)
412 s.on_err(printer)
414
413
415 if self.iopub_stream:
414 if self.iopub_stream:
416 self.iopub_stream.on_err(printer)
415 self.iopub_stream.on_err(printer)
417
416
418 #### while True mode:
417 #### while True mode:
419 # while True:
418 # while True:
420 # idle = True
419 # idle = True
421 # try:
420 # try:
422 # msg = self.shell_stream.socket.recv_multipart(
421 # msg = self.shell_stream.socket.recv_multipart(
423 # zmq.NOBLOCK, copy=False)
422 # zmq.NOBLOCK, copy=False)
424 # except zmq.ZMQError, e:
423 # except zmq.ZMQError, e:
425 # if e.errno != zmq.EAGAIN:
424 # if e.errno != zmq.EAGAIN:
426 # raise e
425 # raise e
427 # else:
426 # else:
428 # idle=False
427 # idle=False
429 # self.dispatch_queue(self.shell_stream, msg)
428 # self.dispatch_queue(self.shell_stream, msg)
430 #
429 #
431 # if not self.task_stream.empty():
430 # if not self.task_stream.empty():
432 # idle=False
431 # idle=False
433 # msg = self.task_stream.recv_multipart()
432 # msg = self.task_stream.recv_multipart()
434 # self.dispatch_queue(self.task_stream, msg)
433 # self.dispatch_queue(self.task_stream, msg)
435 # if idle:
434 # if idle:
436 # # don't busywait
435 # # don't busywait
437 # time.sleep(1e-3)
436 # time.sleep(1e-3)
438
437
@@ -1,38 +1,35 b''
1 #!/usr/bin/env python
2
3
4 """
1 """
5 Add %global magic for GNU Global usage.
2 Add %global magic for GNU Global usage.
6
3
7 http://www.gnu.org/software/global/
4 http://www.gnu.org/software/global/
8
5
9 """
6 """
10
7
11 from IPython.core import ipapi
8 from IPython.core import ipapi
12 ip = ipapi.get()
9 ip = ipapi.get()
13 import os
10 import os
14
11
15 # alter to your liking
12 # alter to your liking
16 global_bin = 'd:/opt/global/bin/global'
13 global_bin = 'd:/opt/global/bin/global'
17
14
18 def global_f(self,cmdline):
15 def global_f(self,cmdline):
19 simple = 0
16 simple = 0
20 if '-' not in cmdline:
17 if '-' not in cmdline:
21 cmdline = '-rx ' + cmdline
18 cmdline = '-rx ' + cmdline
22 simple = 1
19 simple = 1
23
20
24 lines = [l.rstrip() for l in os.popen( global_bin + ' ' + cmdline ).readlines()]
21 lines = [l.rstrip() for l in os.popen( global_bin + ' ' + cmdline ).readlines()]
25
22
26 if simple:
23 if simple:
27 parts = [l.split(None,3) for l in lines]
24 parts = [l.split(None,3) for l in lines]
28 lines = ['%s [%s]\n%s' % (p[2].rjust(70),p[1],p[3].rstrip()) for p in parts]
25 lines = ['%s [%s]\n%s' % (p[2].rjust(70),p[1],p[3].rstrip()) for p in parts]
29 print "\n".join(lines)
26 print "\n".join(lines)
30
27
31 ip.define_magic('global', global_f)
28 ip.define_magic('global', global_f)
32
29
33 def global_completer(self,event):
30 def global_completer(self,event):
34 compl = [l.rstrip() for l in os.popen(global_bin + ' -c ' + event.symbol).readlines()]
31 compl = [l.rstrip() for l in os.popen(global_bin + ' -c ' + event.symbol).readlines()]
35 return compl
32 return compl
36
33
37 ip.set_hook('complete_command', global_completer, str_key = '%global')
34 ip.set_hook('complete_command', global_completer, str_key = '%global')
38
35
@@ -1,68 +1,65 b''
1 #!/usr/bin/env python
2
3 """ IPython extension: Render templates from variables and paste to clipbard """
1 """ IPython extension: Render templates from variables and paste to clipbard """
4
2
5 from IPython.core import ipapi
3 from IPython.core import ipapi
6
4
7 ip = ipapi.get()
5 ip = ipapi.get()
8
6
9 from string import Template
7 from string import Template
10 import sys,os
8 import sys,os
11
9
12 from IPython.external.Itpl import itplns
10 from IPython.external.Itpl import itplns
13
11
14 def toclip_w32(s):
12 def toclip_w32(s):
15 """ Places contents of s to clipboard
13 """ Places contents of s to clipboard
16
14
17 Needs pyvin32 to work:
15 Needs pyvin32 to work:
18 http://sourceforge.net/projects/pywin32/
16 http://sourceforge.net/projects/pywin32/
19 """
17 """
20 import win32clipboard as cl
18 import win32clipboard as cl
21 import win32con
19 import win32con
22 cl.OpenClipboard()
20 cl.OpenClipboard()
23 cl.EmptyClipboard()
21 cl.EmptyClipboard()
24 cl.SetClipboardText( s.replace('\n','\r\n' ))
22 cl.SetClipboardText( s.replace('\n','\r\n' ))
25 cl.CloseClipboard()
23 cl.CloseClipboard()
26
24
27 try:
25 try:
28 import win32clipboard
26 import win32clipboard
29 toclip = toclip_w32
27 toclip = toclip_w32
30 except ImportError:
28 except ImportError:
31 def toclip(s): pass
29 def toclip(s): pass
32
30
33
31
34 def render(tmpl):
32 def render(tmpl):
35 """ Render a template (Itpl format) from ipython variables
33 """ Render a template (Itpl format) from ipython variables
36
34
37 Example:
35 Example:
38
36
39 $ import ipy_render
37 $ import ipy_render
40 $ my_name = 'Bob' # %store this for convenience
38 $ my_name = 'Bob' # %store this for convenience
41 $ t_submission_form = "Submission report, author: $my_name" # %store also
39 $ t_submission_form = "Submission report, author: $my_name" # %store also
42 $ render t_submission_form
40 $ render t_submission_form
43
41
44 => returns "Submission report, author: Bob" and copies to clipboard on win32
42 => returns "Submission report, author: Bob" and copies to clipboard on win32
45
43
46 # if template exist as a file, read it. Note: ;f hei vaan => f("hei vaan")
44 # if template exist as a file, read it. Note: ;f hei vaan => f("hei vaan")
47 $ ;render c:/templates/greeting.txt
45 $ ;render c:/templates/greeting.txt
48
46
49 Template examples (Ka-Ping Yee's Itpl library):
47 Template examples (Ka-Ping Yee's Itpl library):
50
48
51 Here is a $string.
49 Here is a $string.
52 Here is a $module.member.
50 Here is a $module.member.
53 Here is an $object.member.
51 Here is an $object.member.
54 Here is a $functioncall(with, arguments).
52 Here is a $functioncall(with, arguments).
55 Here is an ${arbitrary + expression}.
53 Here is an ${arbitrary + expression}.
56 Here is an $array[3] member.
54 Here is an $array[3] member.
57 Here is a $dictionary['member'].
55 Here is a $dictionary['member'].
58 """
56 """
59
57
60 if os.path.isfile(tmpl):
58 if os.path.isfile(tmpl):
61 tmpl = open(tmpl).read()
59 tmpl = open(tmpl).read()
62
60
63 res = itplns(tmpl, ip.user_ns)
61 res = itplns(tmpl, ip.user_ns)
64 toclip(res)
62 toclip(res)
65 return res
63 return res
66
64
67 ip.push('render')
65 ip.push('render')
68 No newline at end of file
@@ -1,43 +1,41 b''
1 #!/usr/bin/env python
2
3 from IPython.core import ipapi
1 from IPython.core import ipapi
4 ip = ipapi.get()
2 ip = ipapi.get()
5
3
6 import os, subprocess
4 import os, subprocess
7
5
8 workdir = None
6 workdir = None
9 def workdir_f(ip,line):
7 def workdir_f(ip,line):
10 """ Exceute commands residing in cwd elsewhere
8 """ Exceute commands residing in cwd elsewhere
11
9
12 Example::
10 Example::
13
11
14 workdir /myfiles
12 workdir /myfiles
15 cd bin
13 cd bin
16 workdir myscript.py
14 workdir myscript.py
17
15
18 executes myscript.py (stored in bin, but not in path) in /myfiles
16 executes myscript.py (stored in bin, but not in path) in /myfiles
19 """
17 """
20 global workdir
18 global workdir
21 dummy,cmd = line.split(None,1)
19 dummy,cmd = line.split(None,1)
22 if os.path.isdir(cmd):
20 if os.path.isdir(cmd):
23 workdir = os.path.abspath(cmd)
21 workdir = os.path.abspath(cmd)
24 print "Set workdir",workdir
22 print "Set workdir",workdir
25 elif workdir is None:
23 elif workdir is None:
26 print "Please set workdir first by doing e.g. 'workdir q:/'"
24 print "Please set workdir first by doing e.g. 'workdir q:/'"
27 else:
25 else:
28 sp = cmd.split(None,1)
26 sp = cmd.split(None,1)
29 if len(sp) == 1:
27 if len(sp) == 1:
30 head, tail = cmd, ''
28 head, tail = cmd, ''
31 else:
29 else:
32 head, tail = sp
30 head, tail = sp
33 if os.path.isfile(head):
31 if os.path.isfile(head):
34 cmd = os.path.abspath(head) + ' ' + tail
32 cmd = os.path.abspath(head) + ' ' + tail
35 print "Execute command '" + cmd+ "' in",workdir
33 print "Execute command '" + cmd+ "' in",workdir
36 olddir = os.getcwdu()
34 olddir = os.getcwdu()
37 os.chdir(workdir)
35 os.chdir(workdir)
38 try:
36 try:
39 os.system(cmd)
37 os.system(cmd)
40 finally:
38 finally:
41 os.chdir(olddir)
39 os.chdir(olddir)
42
40
43 ip.define_alias("workdir",workdir_f)
41 ip.define_alias("workdir",workdir_f)
1 NO CONTENT: modified file chmod 100644 => 100755
NO CONTENT: modified file chmod 100644 => 100755
@@ -1,74 +1,73 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Tests for testing.tools
3 Tests for testing.tools
5 """
4 """
6
5
7 #-----------------------------------------------------------------------------
6 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2009 The IPython Development Team
7 # Copyright (C) 2008-2009 The IPython Development Team
9 #
8 #
10 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
13
12
14 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
15 # Imports
14 # Imports
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17 from __future__ import with_statement
16 from __future__ import with_statement
18
17
19 import os
18 import os
20 import sys
19 import sys
21
20
22 import nose.tools as nt
21 import nose.tools as nt
23
22
24 from IPython.testing import decorators as dec
23 from IPython.testing import decorators as dec
25 from IPython.testing import tools as tt
24 from IPython.testing import tools as tt
26
25
27 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
28 # Tests
27 # Tests
29 #-----------------------------------------------------------------------------
28 #-----------------------------------------------------------------------------
30
29
31 @dec.skip_win32
30 @dec.skip_win32
32 def test_full_path_posix():
31 def test_full_path_posix():
33 spath = '/foo/bar.py'
32 spath = '/foo/bar.py'
34 result = tt.full_path(spath,['a.txt','b.txt'])
33 result = tt.full_path(spath,['a.txt','b.txt'])
35 nt.assert_equal(result, ['/foo/a.txt', '/foo/b.txt'])
34 nt.assert_equal(result, ['/foo/a.txt', '/foo/b.txt'])
36 spath = '/foo'
35 spath = '/foo'
37 result = tt.full_path(spath,['a.txt','b.txt'])
36 result = tt.full_path(spath,['a.txt','b.txt'])
38 nt.assert_equal(result, ['/a.txt', '/b.txt'])
37 nt.assert_equal(result, ['/a.txt', '/b.txt'])
39 result = tt.full_path(spath,'a.txt')
38 result = tt.full_path(spath,'a.txt')
40 nt.assert_equal(result, ['/a.txt'])
39 nt.assert_equal(result, ['/a.txt'])
41
40
42
41
43 @dec.skip_if_not_win32
42 @dec.skip_if_not_win32
44 def test_full_path_win32():
43 def test_full_path_win32():
45 spath = 'c:\\foo\\bar.py'
44 spath = 'c:\\foo\\bar.py'
46 result = tt.full_path(spath,['a.txt','b.txt'])
45 result = tt.full_path(spath,['a.txt','b.txt'])
47 nt.assert_equal(result, ['c:\\foo\\a.txt', 'c:\\foo\\b.txt'])
46 nt.assert_equal(result, ['c:\\foo\\a.txt', 'c:\\foo\\b.txt'])
48 spath = 'c:\\foo'
47 spath = 'c:\\foo'
49 result = tt.full_path(spath,['a.txt','b.txt'])
48 result = tt.full_path(spath,['a.txt','b.txt'])
50 nt.assert_equal(result, ['c:\\a.txt', 'c:\\b.txt'])
49 nt.assert_equal(result, ['c:\\a.txt', 'c:\\b.txt'])
51 result = tt.full_path(spath,'a.txt')
50 result = tt.full_path(spath,'a.txt')
52 nt.assert_equal(result, ['c:\\a.txt'])
51 nt.assert_equal(result, ['c:\\a.txt'])
53
52
54
53
55 @dec.parametric
54 @dec.parametric
56 def test_parser():
55 def test_parser():
57 err = ("FAILED (errors=1)", 1, 0)
56 err = ("FAILED (errors=1)", 1, 0)
58 fail = ("FAILED (failures=1)", 0, 1)
57 fail = ("FAILED (failures=1)", 0, 1)
59 both = ("FAILED (errors=1, failures=1)", 1, 1)
58 both = ("FAILED (errors=1, failures=1)", 1, 1)
60 for txt, nerr, nfail in [err, fail, both]:
59 for txt, nerr, nfail in [err, fail, both]:
61 nerr1, nfail1 = tt.parse_test_output(txt)
60 nerr1, nfail1 = tt.parse_test_output(txt)
62 yield nt.assert_equal(nerr, nerr1)
61 yield nt.assert_equal(nerr, nerr1)
63 yield nt.assert_equal(nfail, nfail1)
62 yield nt.assert_equal(nfail, nfail1)
64
63
65
64
66 @dec.parametric
65 @dec.parametric
67 def test_temp_pyfile():
66 def test_temp_pyfile():
68 src = 'pass\n'
67 src = 'pass\n'
69 fname, fh = tt.temp_pyfile(src)
68 fname, fh = tt.temp_pyfile(src)
70 yield nt.assert_true(os.path.isfile(fname))
69 yield nt.assert_true(os.path.isfile(fname))
71 fh.close()
70 fh.close()
72 with open(fname) as fh2:
71 with open(fname) as fh2:
73 src2 = fh2.read()
72 src2 = fh2.read()
74 yield nt.assert_equal(src2, src)
73 yield nt.assert_equal(src2, src)
@@ -1,396 +1,395 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """A dict subclass that supports attribute style access.
2 """A dict subclass that supports attribute style access.
4
3
5 Authors:
4 Authors:
6
5
7 * Fernando Perez (original)
6 * Fernando Perez (original)
8 * Brian Granger (refactoring to a dict subclass)
7 * Brian Granger (refactoring to a dict subclass)
9 """
8 """
10
9
11 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2009 The IPython Development Team
11 # Copyright (C) 2008-2009 The IPython Development Team
13 #
12 #
14 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17
16
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19 # Imports
18 # Imports
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21
20
22 from IPython.utils.data import list2dict2
21 from IPython.utils.data import list2dict2
23
22
24 __all__ = ['Struct']
23 __all__ = ['Struct']
25
24
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27 # Code
26 # Code
28 #-----------------------------------------------------------------------------
27 #-----------------------------------------------------------------------------
29
28
30
29
31 class Struct(dict):
30 class Struct(dict):
32 """A dict subclass with attribute style access.
31 """A dict subclass with attribute style access.
33
32
34 This dict subclass has a a few extra features:
33 This dict subclass has a a few extra features:
35
34
36 * Attribute style access.
35 * Attribute style access.
37 * Protection of class members (like keys, items) when using attribute
36 * Protection of class members (like keys, items) when using attribute
38 style access.
37 style access.
39 * The ability to restrict assignment to only existing keys.
38 * The ability to restrict assignment to only existing keys.
40 * Intelligent merging.
39 * Intelligent merging.
41 * Overloaded operators.
40 * Overloaded operators.
42 """
41 """
43 _allownew = True
42 _allownew = True
44 def __init__(self, *args, **kw):
43 def __init__(self, *args, **kw):
45 """Initialize with a dictionary, another Struct, or data.
44 """Initialize with a dictionary, another Struct, or data.
46
45
47 Parameters
46 Parameters
48 ----------
47 ----------
49 args : dict, Struct
48 args : dict, Struct
50 Initialize with one dict or Struct
49 Initialize with one dict or Struct
51 kw : dict
50 kw : dict
52 Initialize with key, value pairs.
51 Initialize with key, value pairs.
53
52
54 Examples
53 Examples
55 --------
54 --------
56
55
57 >>> s = Struct(a=10,b=30)
56 >>> s = Struct(a=10,b=30)
58 >>> s.a
57 >>> s.a
59 10
58 10
60 >>> s.b
59 >>> s.b
61 30
60 30
62 >>> s2 = Struct(s,c=30)
61 >>> s2 = Struct(s,c=30)
63 >>> s2.keys()
62 >>> s2.keys()
64 ['a', 'c', 'b']
63 ['a', 'c', 'b']
65 """
64 """
66 object.__setattr__(self, '_allownew', True)
65 object.__setattr__(self, '_allownew', True)
67 dict.__init__(self, *args, **kw)
66 dict.__init__(self, *args, **kw)
68
67
69 def __setitem__(self, key, value):
68 def __setitem__(self, key, value):
70 """Set an item with check for allownew.
69 """Set an item with check for allownew.
71
70
72 Examples
71 Examples
73 --------
72 --------
74
73
75 >>> s = Struct()
74 >>> s = Struct()
76 >>> s['a'] = 10
75 >>> s['a'] = 10
77 >>> s.allow_new_attr(False)
76 >>> s.allow_new_attr(False)
78 >>> s['a'] = 10
77 >>> s['a'] = 10
79 >>> s['a']
78 >>> s['a']
80 10
79 10
81 >>> try:
80 >>> try:
82 ... s['b'] = 20
81 ... s['b'] = 20
83 ... except KeyError:
82 ... except KeyError:
84 ... print 'this is not allowed'
83 ... print 'this is not allowed'
85 ...
84 ...
86 this is not allowed
85 this is not allowed
87 """
86 """
88 if not self._allownew and not self.has_key(key):
87 if not self._allownew and not self.has_key(key):
89 raise KeyError(
88 raise KeyError(
90 "can't create new attribute %s when allow_new_attr(False)" % key)
89 "can't create new attribute %s when allow_new_attr(False)" % key)
91 dict.__setitem__(self, key, value)
90 dict.__setitem__(self, key, value)
92
91
93 def __setattr__(self, key, value):
92 def __setattr__(self, key, value):
94 """Set an attr with protection of class members.
93 """Set an attr with protection of class members.
95
94
96 This calls :meth:`self.__setitem__` but convert :exc:`KeyError` to
95 This calls :meth:`self.__setitem__` but convert :exc:`KeyError` to
97 :exc:`AttributeError`.
96 :exc:`AttributeError`.
98
97
99 Examples
98 Examples
100 --------
99 --------
101
100
102 >>> s = Struct()
101 >>> s = Struct()
103 >>> s.a = 10
102 >>> s.a = 10
104 >>> s.a
103 >>> s.a
105 10
104 10
106 >>> try:
105 >>> try:
107 ... s.get = 10
106 ... s.get = 10
108 ... except AttributeError:
107 ... except AttributeError:
109 ... print "you can't set a class member"
108 ... print "you can't set a class member"
110 ...
109 ...
111 you can't set a class member
110 you can't set a class member
112 """
111 """
113 # If key is an str it might be a class member or instance var
112 # If key is an str it might be a class member or instance var
114 if isinstance(key, str):
113 if isinstance(key, str):
115 # I can't simply call hasattr here because it calls getattr, which
114 # I can't simply call hasattr here because it calls getattr, which
116 # calls self.__getattr__, which returns True for keys in
115 # calls self.__getattr__, which returns True for keys in
117 # self._data. But I only want keys in the class and in
116 # self._data. But I only want keys in the class and in
118 # self.__dict__
117 # self.__dict__
119 if key in self.__dict__ or hasattr(Struct, key):
118 if key in self.__dict__ or hasattr(Struct, key):
120 raise AttributeError(
119 raise AttributeError(
121 'attr %s is a protected member of class Struct.' % key
120 'attr %s is a protected member of class Struct.' % key
122 )
121 )
123 try:
122 try:
124 self.__setitem__(key, value)
123 self.__setitem__(key, value)
125 except KeyError, e:
124 except KeyError, e:
126 raise AttributeError(e)
125 raise AttributeError(e)
127
126
128 def __getattr__(self, key):
127 def __getattr__(self, key):
129 """Get an attr by calling :meth:`dict.__getitem__`.
128 """Get an attr by calling :meth:`dict.__getitem__`.
130
129
131 Like :meth:`__setattr__`, this method converts :exc:`KeyError` to
130 Like :meth:`__setattr__`, this method converts :exc:`KeyError` to
132 :exc:`AttributeError`.
131 :exc:`AttributeError`.
133
132
134 Examples
133 Examples
135 --------
134 --------
136
135
137 >>> s = Struct(a=10)
136 >>> s = Struct(a=10)
138 >>> s.a
137 >>> s.a
139 10
138 10
140 >>> type(s.get)
139 >>> type(s.get)
141 <type 'builtin_function_or_method'>
140 <type 'builtin_function_or_method'>
142 >>> try:
141 >>> try:
143 ... s.b
142 ... s.b
144 ... except AttributeError:
143 ... except AttributeError:
145 ... print "I don't have that key"
144 ... print "I don't have that key"
146 ...
145 ...
147 I don't have that key
146 I don't have that key
148 """
147 """
149 try:
148 try:
150 result = self[key]
149 result = self[key]
151 except KeyError:
150 except KeyError:
152 raise AttributeError(key)
151 raise AttributeError(key)
153 else:
152 else:
154 return result
153 return result
155
154
156 def __iadd__(self, other):
155 def __iadd__(self, other):
157 """s += s2 is a shorthand for s.merge(s2).
156 """s += s2 is a shorthand for s.merge(s2).
158
157
159 Examples
158 Examples
160 --------
159 --------
161
160
162 >>> s = Struct(a=10,b=30)
161 >>> s = Struct(a=10,b=30)
163 >>> s2 = Struct(a=20,c=40)
162 >>> s2 = Struct(a=20,c=40)
164 >>> s += s2
163 >>> s += s2
165 >>> s
164 >>> s
166 {'a': 10, 'c': 40, 'b': 30}
165 {'a': 10, 'c': 40, 'b': 30}
167 """
166 """
168 self.merge(other)
167 self.merge(other)
169 return self
168 return self
170
169
171 def __add__(self,other):
170 def __add__(self,other):
172 """s + s2 -> New Struct made from s.merge(s2).
171 """s + s2 -> New Struct made from s.merge(s2).
173
172
174 Examples
173 Examples
175 --------
174 --------
176
175
177 >>> s1 = Struct(a=10,b=30)
176 >>> s1 = Struct(a=10,b=30)
178 >>> s2 = Struct(a=20,c=40)
177 >>> s2 = Struct(a=20,c=40)
179 >>> s = s1 + s2
178 >>> s = s1 + s2
180 >>> s
179 >>> s
181 {'a': 10, 'c': 40, 'b': 30}
180 {'a': 10, 'c': 40, 'b': 30}
182 """
181 """
183 sout = self.copy()
182 sout = self.copy()
184 sout.merge(other)
183 sout.merge(other)
185 return sout
184 return sout
186
185
187 def __sub__(self,other):
186 def __sub__(self,other):
188 """s1 - s2 -> remove keys in s2 from s1.
187 """s1 - s2 -> remove keys in s2 from s1.
189
188
190 Examples
189 Examples
191 --------
190 --------
192
191
193 >>> s1 = Struct(a=10,b=30)
192 >>> s1 = Struct(a=10,b=30)
194 >>> s2 = Struct(a=40)
193 >>> s2 = Struct(a=40)
195 >>> s = s1 - s2
194 >>> s = s1 - s2
196 >>> s
195 >>> s
197 {'b': 30}
196 {'b': 30}
198 """
197 """
199 sout = self.copy()
198 sout = self.copy()
200 sout -= other
199 sout -= other
201 return sout
200 return sout
202
201
203 def __isub__(self,other):
202 def __isub__(self,other):
204 """Inplace remove keys from self that are in other.
203 """Inplace remove keys from self that are in other.
205
204
206 Examples
205 Examples
207 --------
206 --------
208
207
209 >>> s1 = Struct(a=10,b=30)
208 >>> s1 = Struct(a=10,b=30)
210 >>> s2 = Struct(a=40)
209 >>> s2 = Struct(a=40)
211 >>> s1 -= s2
210 >>> s1 -= s2
212 >>> s1
211 >>> s1
213 {'b': 30}
212 {'b': 30}
214 """
213 """
215 for k in other.keys():
214 for k in other.keys():
216 if self.has_key(k):
215 if self.has_key(k):
217 del self[k]
216 del self[k]
218 return self
217 return self
219
218
220 def __dict_invert(self, data):
219 def __dict_invert(self, data):
221 """Helper function for merge.
220 """Helper function for merge.
222
221
223 Takes a dictionary whose values are lists and returns a dict with
222 Takes a dictionary whose values are lists and returns a dict with
224 the elements of each list as keys and the original keys as values.
223 the elements of each list as keys and the original keys as values.
225 """
224 """
226 outdict = {}
225 outdict = {}
227 for k,lst in data.items():
226 for k,lst in data.items():
228 if isinstance(lst, str):
227 if isinstance(lst, str):
229 lst = lst.split()
228 lst = lst.split()
230 for entry in lst:
229 for entry in lst:
231 outdict[entry] = k
230 outdict[entry] = k
232 return outdict
231 return outdict
233
232
234 def dict(self):
233 def dict(self):
235 return self
234 return self
236
235
237 def copy(self):
236 def copy(self):
238 """Return a copy as a Struct.
237 """Return a copy as a Struct.
239
238
240 Examples
239 Examples
241 --------
240 --------
242
241
243 >>> s = Struct(a=10,b=30)
242 >>> s = Struct(a=10,b=30)
244 >>> s2 = s.copy()
243 >>> s2 = s.copy()
245 >>> s2
244 >>> s2
246 {'a': 10, 'b': 30}
245 {'a': 10, 'b': 30}
247 >>> type(s2).__name__
246 >>> type(s2).__name__
248 'Struct'
247 'Struct'
249 """
248 """
250 return Struct(dict.copy(self))
249 return Struct(dict.copy(self))
251
250
252 def hasattr(self, key):
251 def hasattr(self, key):
253 """hasattr function available as a method.
252 """hasattr function available as a method.
254
253
255 Implemented like has_key.
254 Implemented like has_key.
256
255
257 Examples
256 Examples
258 --------
257 --------
259
258
260 >>> s = Struct(a=10)
259 >>> s = Struct(a=10)
261 >>> s.hasattr('a')
260 >>> s.hasattr('a')
262 True
261 True
263 >>> s.hasattr('b')
262 >>> s.hasattr('b')
264 False
263 False
265 >>> s.hasattr('get')
264 >>> s.hasattr('get')
266 False
265 False
267 """
266 """
268 return self.has_key(key)
267 return self.has_key(key)
269
268
270 def allow_new_attr(self, allow = True):
269 def allow_new_attr(self, allow = True):
271 """Set whether new attributes can be created in this Struct.
270 """Set whether new attributes can be created in this Struct.
272
271
273 This can be used to catch typos by verifying that the attribute user
272 This can be used to catch typos by verifying that the attribute user
274 tries to change already exists in this Struct.
273 tries to change already exists in this Struct.
275 """
274 """
276 object.__setattr__(self, '_allownew', allow)
275 object.__setattr__(self, '_allownew', allow)
277
276
278 def merge(self, __loc_data__=None, __conflict_solve=None, **kw):
277 def merge(self, __loc_data__=None, __conflict_solve=None, **kw):
279 """Merge two Structs with customizable conflict resolution.
278 """Merge two Structs with customizable conflict resolution.
280
279
281 This is similar to :meth:`update`, but much more flexible. First, a
280 This is similar to :meth:`update`, but much more flexible. First, a
282 dict is made from data+key=value pairs. When merging this dict with
281 dict is made from data+key=value pairs. When merging this dict with
283 the Struct S, the optional dictionary 'conflict' is used to decide
282 the Struct S, the optional dictionary 'conflict' is used to decide
284 what to do.
283 what to do.
285
284
286 If conflict is not given, the default behavior is to preserve any keys
285 If conflict is not given, the default behavior is to preserve any keys
287 with their current value (the opposite of the :meth:`update` method's
286 with their current value (the opposite of the :meth:`update` method's
288 behavior).
287 behavior).
289
288
290 Parameters
289 Parameters
291 ----------
290 ----------
292 __loc_data : dict, Struct
291 __loc_data : dict, Struct
293 The data to merge into self
292 The data to merge into self
294 __conflict_solve : dict
293 __conflict_solve : dict
295 The conflict policy dict. The keys are binary functions used to
294 The conflict policy dict. The keys are binary functions used to
296 resolve the conflict and the values are lists of strings naming
295 resolve the conflict and the values are lists of strings naming
297 the keys the conflict resolution function applies to. Instead of
296 the keys the conflict resolution function applies to. Instead of
298 a list of strings a space separated string can be used, like
297 a list of strings a space separated string can be used, like
299 'a b c'.
298 'a b c'.
300 kw : dict
299 kw : dict
301 Additional key, value pairs to merge in
300 Additional key, value pairs to merge in
302
301
303 Notes
302 Notes
304 -----
303 -----
305
304
306 The `__conflict_solve` dict is a dictionary of binary functions which will be used to
305 The `__conflict_solve` dict is a dictionary of binary functions which will be used to
307 solve key conflicts. Here is an example::
306 solve key conflicts. Here is an example::
308
307
309 __conflict_solve = dict(
308 __conflict_solve = dict(
310 func1=['a','b','c'],
309 func1=['a','b','c'],
311 func2=['d','e']
310 func2=['d','e']
312 )
311 )
313
312
314 In this case, the function :func:`func1` will be used to resolve
313 In this case, the function :func:`func1` will be used to resolve
315 keys 'a', 'b' and 'c' and the function :func:`func2` will be used for
314 keys 'a', 'b' and 'c' and the function :func:`func2` will be used for
316 keys 'd' and 'e'. This could also be written as::
315 keys 'd' and 'e'. This could also be written as::
317
316
318 __conflict_solve = dict(func1='a b c',func2='d e')
317 __conflict_solve = dict(func1='a b c',func2='d e')
319
318
320 These functions will be called for each key they apply to with the
319 These functions will be called for each key they apply to with the
321 form::
320 form::
322
321
323 func1(self['a'], other['a'])
322 func1(self['a'], other['a'])
324
323
325 The return value is used as the final merged value.
324 The return value is used as the final merged value.
326
325
327 As a convenience, merge() provides five (the most commonly needed)
326 As a convenience, merge() provides five (the most commonly needed)
328 pre-defined policies: preserve, update, add, add_flip and add_s. The
327 pre-defined policies: preserve, update, add, add_flip and add_s. The
329 easiest explanation is their implementation::
328 easiest explanation is their implementation::
330
329
331 preserve = lambda old,new: old
330 preserve = lambda old,new: old
332 update = lambda old,new: new
331 update = lambda old,new: new
333 add = lambda old,new: old + new
332 add = lambda old,new: old + new
334 add_flip = lambda old,new: new + old # note change of order!
333 add_flip = lambda old,new: new + old # note change of order!
335 add_s = lambda old,new: old + ' ' + new # only for str!
334 add_s = lambda old,new: old + ' ' + new # only for str!
336
335
337 You can use those four words (as strings) as keys instead
336 You can use those four words (as strings) as keys instead
338 of defining them as functions, and the merge method will substitute
337 of defining them as functions, and the merge method will substitute
339 the appropriate functions for you.
338 the appropriate functions for you.
340
339
341 For more complicated conflict resolution policies, you still need to
340 For more complicated conflict resolution policies, you still need to
342 construct your own functions.
341 construct your own functions.
343
342
344 Examples
343 Examples
345 --------
344 --------
346
345
347 This show the default policy:
346 This show the default policy:
348
347
349 >>> s = Struct(a=10,b=30)
348 >>> s = Struct(a=10,b=30)
350 >>> s2 = Struct(a=20,c=40)
349 >>> s2 = Struct(a=20,c=40)
351 >>> s.merge(s2)
350 >>> s.merge(s2)
352 >>> s
351 >>> s
353 {'a': 10, 'c': 40, 'b': 30}
352 {'a': 10, 'c': 40, 'b': 30}
354
353
355 Now, show how to specify a conflict dict:
354 Now, show how to specify a conflict dict:
356
355
357 >>> s = Struct(a=10,b=30)
356 >>> s = Struct(a=10,b=30)
358 >>> s2 = Struct(a=20,b=40)
357 >>> s2 = Struct(a=20,b=40)
359 >>> conflict = {'update':'a','add':'b'}
358 >>> conflict = {'update':'a','add':'b'}
360 >>> s.merge(s2,conflict)
359 >>> s.merge(s2,conflict)
361 >>> s
360 >>> s
362 {'a': 20, 'b': 70}
361 {'a': 20, 'b': 70}
363 """
362 """
364
363
365 data_dict = dict(__loc_data__,**kw)
364 data_dict = dict(__loc_data__,**kw)
366
365
367 # policies for conflict resolution: two argument functions which return
366 # policies for conflict resolution: two argument functions which return
368 # the value that will go in the new struct
367 # the value that will go in the new struct
369 preserve = lambda old,new: old
368 preserve = lambda old,new: old
370 update = lambda old,new: new
369 update = lambda old,new: new
371 add = lambda old,new: old + new
370 add = lambda old,new: old + new
372 add_flip = lambda old,new: new + old # note change of order!
371 add_flip = lambda old,new: new + old # note change of order!
373 add_s = lambda old,new: old + ' ' + new
372 add_s = lambda old,new: old + ' ' + new
374
373
375 # default policy is to keep current keys when there's a conflict
374 # default policy is to keep current keys when there's a conflict
376 conflict_solve = list2dict2(self.keys(), default = preserve)
375 conflict_solve = list2dict2(self.keys(), default = preserve)
377
376
378 # the conflict_solve dictionary is given by the user 'inverted': we
377 # the conflict_solve dictionary is given by the user 'inverted': we
379 # need a name-function mapping, it comes as a function -> names
378 # need a name-function mapping, it comes as a function -> names
380 # dict. Make a local copy (b/c we'll make changes), replace user
379 # dict. Make a local copy (b/c we'll make changes), replace user
381 # strings for the three builtin policies and invert it.
380 # strings for the three builtin policies and invert it.
382 if __conflict_solve:
381 if __conflict_solve:
383 inv_conflict_solve_user = __conflict_solve.copy()
382 inv_conflict_solve_user = __conflict_solve.copy()
384 for name, func in [('preserve',preserve), ('update',update),
383 for name, func in [('preserve',preserve), ('update',update),
385 ('add',add), ('add_flip',add_flip),
384 ('add',add), ('add_flip',add_flip),
386 ('add_s',add_s)]:
385 ('add_s',add_s)]:
387 if name in inv_conflict_solve_user.keys():
386 if name in inv_conflict_solve_user.keys():
388 inv_conflict_solve_user[func] = inv_conflict_solve_user[name]
387 inv_conflict_solve_user[func] = inv_conflict_solve_user[name]
389 del inv_conflict_solve_user[name]
388 del inv_conflict_solve_user[name]
390 conflict_solve.update(self.__dict_invert(inv_conflict_solve_user))
389 conflict_solve.update(self.__dict_invert(inv_conflict_solve_user))
391 for key in data_dict:
390 for key in data_dict:
392 if key not in self:
391 if key not in self:
393 self[key] = data_dict[key]
392 self[key] = data_dict[key]
394 else:
393 else:
395 self[key] = conflict_solve[key](self[key],data_dict[key])
394 self[key] = conflict_solve[key](self[key],data_dict[key])
396
395
@@ -1,143 +1,142 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 The IPython Core Notification Center.
3 The IPython Core Notification Center.
5
4
6 See docs/source/development/notification_blueprint.txt for an overview of the
5 See docs/source/development/notification_blueprint.txt for an overview of the
7 notification module.
6 notification module.
8
7
9 Authors:
8 Authors:
10
9
11 * Barry Wark
10 * Barry Wark
12 * Brian Granger
11 * Brian Granger
13 """
12 """
14
13
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16 # Copyright (C) 2008-2009 The IPython Development Team
15 # Copyright (C) 2008-2009 The IPython Development Team
17 #
16 #
18 # Distributed under the terms of the BSD License. The full license is in
17 # Distributed under the terms of the BSD License. The full license is in
19 # the file COPYING, distributed as part of this software.
18 # the file COPYING, distributed as part of this software.
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21
20
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23 # Code
22 # Code
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25
24
26
25
27 class NotificationError(Exception):
26 class NotificationError(Exception):
28 pass
27 pass
29
28
30
29
31 class NotificationCenter(object):
30 class NotificationCenter(object):
32 """Synchronous notification center.
31 """Synchronous notification center.
33
32
34 Examples
33 Examples
35 --------
34 --------
36 Here is a simple example of how to use this::
35 Here is a simple example of how to use this::
37
36
38 import IPython.util.notification as notification
37 import IPython.util.notification as notification
39 def callback(ntype, theSender, args={}):
38 def callback(ntype, theSender, args={}):
40 print ntype,theSender,args
39 print ntype,theSender,args
41
40
42 notification.sharedCenter.add_observer(callback, 'NOTIFICATION_TYPE', None)
41 notification.sharedCenter.add_observer(callback, 'NOTIFICATION_TYPE', None)
43 notification.sharedCenter.post_notification('NOTIFICATION_TYPE', object()) # doctest:+ELLIPSIS
42 notification.sharedCenter.post_notification('NOTIFICATION_TYPE', object()) # doctest:+ELLIPSIS
44 NOTIFICATION_TYPE ...
43 NOTIFICATION_TYPE ...
45 """
44 """
46 def __init__(self):
45 def __init__(self):
47 super(NotificationCenter, self).__init__()
46 super(NotificationCenter, self).__init__()
48 self._init_observers()
47 self._init_observers()
49
48
50 def _init_observers(self):
49 def _init_observers(self):
51 """Initialize observer storage"""
50 """Initialize observer storage"""
52
51
53 self.registered_types = set() #set of types that are observed
52 self.registered_types = set() #set of types that are observed
54 self.registered_senders = set() #set of senders that are observed
53 self.registered_senders = set() #set of senders that are observed
55 self.observers = {} #map (type,sender) => callback (callable)
54 self.observers = {} #map (type,sender) => callback (callable)
56
55
57 def post_notification(self, ntype, sender, *args, **kwargs):
56 def post_notification(self, ntype, sender, *args, **kwargs):
58 """Post notification to all registered observers.
57 """Post notification to all registered observers.
59
58
60 The registered callback will be called as::
59 The registered callback will be called as::
61
60
62 callback(ntype, sender, *args, **kwargs)
61 callback(ntype, sender, *args, **kwargs)
63
62
64 Parameters
63 Parameters
65 ----------
64 ----------
66 ntype : hashable
65 ntype : hashable
67 The notification type.
66 The notification type.
68 sender : hashable
67 sender : hashable
69 The object sending the notification.
68 The object sending the notification.
70 *args : tuple
69 *args : tuple
71 The positional arguments to be passed to the callback.
70 The positional arguments to be passed to the callback.
72 **kwargs : dict
71 **kwargs : dict
73 The keyword argument to be passed to the callback.
72 The keyword argument to be passed to the callback.
74
73
75 Notes
74 Notes
76 -----
75 -----
77 * If no registered observers, performance is O(1).
76 * If no registered observers, performance is O(1).
78 * Notificaiton order is undefined.
77 * Notificaiton order is undefined.
79 * Notifications are posted synchronously.
78 * Notifications are posted synchronously.
80 """
79 """
81
80
82 if(ntype==None or sender==None):
81 if(ntype==None or sender==None):
83 raise NotificationError(
82 raise NotificationError(
84 "Notification type and sender are required.")
83 "Notification type and sender are required.")
85
84
86 # If there are no registered observers for the type/sender pair
85 # If there are no registered observers for the type/sender pair
87 if((ntype not in self.registered_types and
86 if((ntype not in self.registered_types and
88 None not in self.registered_types) or
87 None not in self.registered_types) or
89 (sender not in self.registered_senders and
88 (sender not in self.registered_senders and
90 None not in self.registered_senders)):
89 None not in self.registered_senders)):
91 return
90 return
92
91
93 for o in self._observers_for_notification(ntype, sender):
92 for o in self._observers_for_notification(ntype, sender):
94 o(ntype, sender, *args, **kwargs)
93 o(ntype, sender, *args, **kwargs)
95
94
96 def _observers_for_notification(self, ntype, sender):
95 def _observers_for_notification(self, ntype, sender):
97 """Find all registered observers that should recieve notification"""
96 """Find all registered observers that should recieve notification"""
98
97
99 keys = (
98 keys = (
100 (ntype,sender),
99 (ntype,sender),
101 (ntype, None),
100 (ntype, None),
102 (None, sender),
101 (None, sender),
103 (None,None)
102 (None,None)
104 )
103 )
105
104
106 obs = set()
105 obs = set()
107 for k in keys:
106 for k in keys:
108 obs.update(self.observers.get(k, set()))
107 obs.update(self.observers.get(k, set()))
109
108
110 return obs
109 return obs
111
110
112 def add_observer(self, callback, ntype, sender):
111 def add_observer(self, callback, ntype, sender):
113 """Add an observer callback to this notification center.
112 """Add an observer callback to this notification center.
114
113
115 The given callback will be called upon posting of notifications of
114 The given callback will be called upon posting of notifications of
116 the given type/sender and will receive any additional arguments passed
115 the given type/sender and will receive any additional arguments passed
117 to post_notification.
116 to post_notification.
118
117
119 Parameters
118 Parameters
120 ----------
119 ----------
121 callback : callable
120 callback : callable
122 The callable that will be called by :meth:`post_notification`
121 The callable that will be called by :meth:`post_notification`
123 as ``callback(ntype, sender, *args, **kwargs)
122 as ``callback(ntype, sender, *args, **kwargs)
124 ntype : hashable
123 ntype : hashable
125 The notification type. If None, all notifications from sender
124 The notification type. If None, all notifications from sender
126 will be posted.
125 will be posted.
127 sender : hashable
126 sender : hashable
128 The notification sender. If None, all notifications of ntype
127 The notification sender. If None, all notifications of ntype
129 will be posted.
128 will be posted.
130 """
129 """
131 assert(callback != None)
130 assert(callback != None)
132 self.registered_types.add(ntype)
131 self.registered_types.add(ntype)
133 self.registered_senders.add(sender)
132 self.registered_senders.add(sender)
134 self.observers.setdefault((ntype,sender), set()).add(callback)
133 self.observers.setdefault((ntype,sender), set()).add(callback)
135
134
136 def remove_all_observers(self):
135 def remove_all_observers(self):
137 """Removes all observers from this notification center"""
136 """Removes all observers from this notification center"""
138
137
139 self._init_observers()
138 self._init_observers()
140
139
141
140
142
141
143 shared_center = NotificationCenter()
142 shared_center = NotificationCenter()
@@ -1,70 +1,69 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Context managers for adding things to sys.path temporarily.
3 Context managers for adding things to sys.path temporarily.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 """
8 """
10
9
11 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2009 The IPython Development Team
11 # Copyright (C) 2008-2009 The IPython Development Team
13 #
12 #
14 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
17
16
18 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
19 # Imports
18 # Imports
20 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
21
20
22 import sys
21 import sys
23
22
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25 # Code
24 # Code
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27
26
28 class appended_to_syspath(object):
27 class appended_to_syspath(object):
29 """A context for appending a directory to sys.path for a second."""
28 """A context for appending a directory to sys.path for a second."""
30
29
31 def __init__(self, dir):
30 def __init__(self, dir):
32 self.dir = dir
31 self.dir = dir
33
32
34 def __enter__(self):
33 def __enter__(self):
35 if self.dir not in sys.path:
34 if self.dir not in sys.path:
36 sys.path.append(self.dir)
35 sys.path.append(self.dir)
37 self.added = True
36 self.added = True
38 else:
37 else:
39 self.added = False
38 self.added = False
40
39
41 def __exit__(self, type, value, traceback):
40 def __exit__(self, type, value, traceback):
42 if self.added:
41 if self.added:
43 try:
42 try:
44 sys.path.remove(self.dir)
43 sys.path.remove(self.dir)
45 except ValueError:
44 except ValueError:
46 pass
45 pass
47 # Returning False causes any exceptions to be re-raised.
46 # Returning False causes any exceptions to be re-raised.
48 return False
47 return False
49
48
50 class prepended_to_syspath(object):
49 class prepended_to_syspath(object):
51 """A context for prepending a directory to sys.path for a second."""
50 """A context for prepending a directory to sys.path for a second."""
52
51
53 def __init__(self, dir):
52 def __init__(self, dir):
54 self.dir = dir
53 self.dir = dir
55
54
56 def __enter__(self):
55 def __enter__(self):
57 if self.dir not in sys.path:
56 if self.dir not in sys.path:
58 sys.path.insert(0,self.dir)
57 sys.path.insert(0,self.dir)
59 self.added = True
58 self.added = True
60 else:
59 else:
61 self.added = False
60 self.added = False
62
61
63 def __exit__(self, type, value, traceback):
62 def __exit__(self, type, value, traceback):
64 if self.added:
63 if self.added:
65 try:
64 try:
66 sys.path.remove(self.dir)
65 sys.path.remove(self.dir)
67 except ValueError:
66 except ValueError:
68 pass
67 pass
69 # Returning False causes any exceptions to be re-raised.
68 # Returning False causes any exceptions to be re-raised.
70 return False
69 return False
@@ -1,847 +1,846 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 Tests for IPython.utils.traitlets.
3 Tests for IPython.utils.traitlets.
5
4
6 Authors:
5 Authors:
7
6
8 * Brian Granger
7 * Brian Granger
9 * Enthought, Inc. Some of the code in this file comes from enthought.traits
8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
10 and is licensed under the BSD license. Also, many of the ideas also come
9 and is licensed under the BSD license. Also, many of the ideas also come
11 from enthought.traits even though our implementation is very different.
10 from enthought.traits even though our implementation is very different.
12 """
11 """
13
12
14 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
15 # Copyright (C) 2008-2009 The IPython Development Team
14 # Copyright (C) 2008-2009 The IPython Development Team
16 #
15 #
17 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
18 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
19 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
20
19
21 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
22 # Imports
21 # Imports
23 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
24
23
25 import sys
24 import sys
26 from unittest import TestCase
25 from unittest import TestCase
27
26
28 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
29 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
28 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
30 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
29 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
31 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
30 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
32 ObjectName, DottedObjectName
31 ObjectName, DottedObjectName
33 )
32 )
34
33
35
34
36 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
37 # Helper classes for testing
36 # Helper classes for testing
38 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
39
38
40
39
41 class HasTraitsStub(HasTraits):
40 class HasTraitsStub(HasTraits):
42
41
43 def _notify_trait(self, name, old, new):
42 def _notify_trait(self, name, old, new):
44 self._notify_name = name
43 self._notify_name = name
45 self._notify_old = old
44 self._notify_old = old
46 self._notify_new = new
45 self._notify_new = new
47
46
48
47
49 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
50 # Test classes
49 # Test classes
51 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
52
51
53
52
54 class TestTraitType(TestCase):
53 class TestTraitType(TestCase):
55
54
56 def test_get_undefined(self):
55 def test_get_undefined(self):
57 class A(HasTraits):
56 class A(HasTraits):
58 a = TraitType
57 a = TraitType
59 a = A()
58 a = A()
60 self.assertEquals(a.a, Undefined)
59 self.assertEquals(a.a, Undefined)
61
60
62 def test_set(self):
61 def test_set(self):
63 class A(HasTraitsStub):
62 class A(HasTraitsStub):
64 a = TraitType
63 a = TraitType
65
64
66 a = A()
65 a = A()
67 a.a = 10
66 a.a = 10
68 self.assertEquals(a.a, 10)
67 self.assertEquals(a.a, 10)
69 self.assertEquals(a._notify_name, 'a')
68 self.assertEquals(a._notify_name, 'a')
70 self.assertEquals(a._notify_old, Undefined)
69 self.assertEquals(a._notify_old, Undefined)
71 self.assertEquals(a._notify_new, 10)
70 self.assertEquals(a._notify_new, 10)
72
71
73 def test_validate(self):
72 def test_validate(self):
74 class MyTT(TraitType):
73 class MyTT(TraitType):
75 def validate(self, inst, value):
74 def validate(self, inst, value):
76 return -1
75 return -1
77 class A(HasTraitsStub):
76 class A(HasTraitsStub):
78 tt = MyTT
77 tt = MyTT
79
78
80 a = A()
79 a = A()
81 a.tt = 10
80 a.tt = 10
82 self.assertEquals(a.tt, -1)
81 self.assertEquals(a.tt, -1)
83
82
84 def test_default_validate(self):
83 def test_default_validate(self):
85 class MyIntTT(TraitType):
84 class MyIntTT(TraitType):
86 def validate(self, obj, value):
85 def validate(self, obj, value):
87 if isinstance(value, int):
86 if isinstance(value, int):
88 return value
87 return value
89 self.error(obj, value)
88 self.error(obj, value)
90 class A(HasTraits):
89 class A(HasTraits):
91 tt = MyIntTT(10)
90 tt = MyIntTT(10)
92 a = A()
91 a = A()
93 self.assertEquals(a.tt, 10)
92 self.assertEquals(a.tt, 10)
94
93
95 # Defaults are validated when the HasTraits is instantiated
94 # Defaults are validated when the HasTraits is instantiated
96 class B(HasTraits):
95 class B(HasTraits):
97 tt = MyIntTT('bad default')
96 tt = MyIntTT('bad default')
98 self.assertRaises(TraitError, B)
97 self.assertRaises(TraitError, B)
99
98
100 def test_is_valid_for(self):
99 def test_is_valid_for(self):
101 class MyTT(TraitType):
100 class MyTT(TraitType):
102 def is_valid_for(self, value):
101 def is_valid_for(self, value):
103 return True
102 return True
104 class A(HasTraits):
103 class A(HasTraits):
105 tt = MyTT
104 tt = MyTT
106
105
107 a = A()
106 a = A()
108 a.tt = 10
107 a.tt = 10
109 self.assertEquals(a.tt, 10)
108 self.assertEquals(a.tt, 10)
110
109
111 def test_value_for(self):
110 def test_value_for(self):
112 class MyTT(TraitType):
111 class MyTT(TraitType):
113 def value_for(self, value):
112 def value_for(self, value):
114 return 20
113 return 20
115 class A(HasTraits):
114 class A(HasTraits):
116 tt = MyTT
115 tt = MyTT
117
116
118 a = A()
117 a = A()
119 a.tt = 10
118 a.tt = 10
120 self.assertEquals(a.tt, 20)
119 self.assertEquals(a.tt, 20)
121
120
122 def test_info(self):
121 def test_info(self):
123 class A(HasTraits):
122 class A(HasTraits):
124 tt = TraitType
123 tt = TraitType
125 a = A()
124 a = A()
126 self.assertEquals(A.tt.info(), 'any value')
125 self.assertEquals(A.tt.info(), 'any value')
127
126
128 def test_error(self):
127 def test_error(self):
129 class A(HasTraits):
128 class A(HasTraits):
130 tt = TraitType
129 tt = TraitType
131 a = A()
130 a = A()
132 self.assertRaises(TraitError, A.tt.error, a, 10)
131 self.assertRaises(TraitError, A.tt.error, a, 10)
133
132
134 def test_dynamic_initializer(self):
133 def test_dynamic_initializer(self):
135 class A(HasTraits):
134 class A(HasTraits):
136 x = Int(10)
135 x = Int(10)
137 def _x_default(self):
136 def _x_default(self):
138 return 11
137 return 11
139 class B(A):
138 class B(A):
140 x = Int(20)
139 x = Int(20)
141 class C(A):
140 class C(A):
142 def _x_default(self):
141 def _x_default(self):
143 return 21
142 return 21
144
143
145 a = A()
144 a = A()
146 self.assertEquals(a._trait_values, {})
145 self.assertEquals(a._trait_values, {})
147 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
146 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
148 self.assertEquals(a.x, 11)
147 self.assertEquals(a.x, 11)
149 self.assertEquals(a._trait_values, {'x': 11})
148 self.assertEquals(a._trait_values, {'x': 11})
150 b = B()
149 b = B()
151 self.assertEquals(b._trait_values, {'x': 20})
150 self.assertEquals(b._trait_values, {'x': 20})
152 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
151 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
153 self.assertEquals(b.x, 20)
152 self.assertEquals(b.x, 20)
154 c = C()
153 c = C()
155 self.assertEquals(c._trait_values, {})
154 self.assertEquals(c._trait_values, {})
156 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
155 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
157 self.assertEquals(c.x, 21)
156 self.assertEquals(c.x, 21)
158 self.assertEquals(c._trait_values, {'x': 21})
157 self.assertEquals(c._trait_values, {'x': 21})
159 # Ensure that the base class remains unmolested when the _default
158 # Ensure that the base class remains unmolested when the _default
160 # initializer gets overridden in a subclass.
159 # initializer gets overridden in a subclass.
161 a = A()
160 a = A()
162 c = C()
161 c = C()
163 self.assertEquals(a._trait_values, {})
162 self.assertEquals(a._trait_values, {})
164 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
163 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
165 self.assertEquals(a.x, 11)
164 self.assertEquals(a.x, 11)
166 self.assertEquals(a._trait_values, {'x': 11})
165 self.assertEquals(a._trait_values, {'x': 11})
167
166
168
167
169
168
170 class TestHasTraitsMeta(TestCase):
169 class TestHasTraitsMeta(TestCase):
171
170
172 def test_metaclass(self):
171 def test_metaclass(self):
173 self.assertEquals(type(HasTraits), MetaHasTraits)
172 self.assertEquals(type(HasTraits), MetaHasTraits)
174
173
175 class A(HasTraits):
174 class A(HasTraits):
176 a = Int
175 a = Int
177
176
178 a = A()
177 a = A()
179 self.assertEquals(type(a.__class__), MetaHasTraits)
178 self.assertEquals(type(a.__class__), MetaHasTraits)
180 self.assertEquals(a.a,0)
179 self.assertEquals(a.a,0)
181 a.a = 10
180 a.a = 10
182 self.assertEquals(a.a,10)
181 self.assertEquals(a.a,10)
183
182
184 class B(HasTraits):
183 class B(HasTraits):
185 b = Int()
184 b = Int()
186
185
187 b = B()
186 b = B()
188 self.assertEquals(b.b,0)
187 self.assertEquals(b.b,0)
189 b.b = 10
188 b.b = 10
190 self.assertEquals(b.b,10)
189 self.assertEquals(b.b,10)
191
190
192 class C(HasTraits):
191 class C(HasTraits):
193 c = Int(30)
192 c = Int(30)
194
193
195 c = C()
194 c = C()
196 self.assertEquals(c.c,30)
195 self.assertEquals(c.c,30)
197 c.c = 10
196 c.c = 10
198 self.assertEquals(c.c,10)
197 self.assertEquals(c.c,10)
199
198
200 def test_this_class(self):
199 def test_this_class(self):
201 class A(HasTraits):
200 class A(HasTraits):
202 t = This()
201 t = This()
203 tt = This()
202 tt = This()
204 class B(A):
203 class B(A):
205 tt = This()
204 tt = This()
206 ttt = This()
205 ttt = This()
207 self.assertEquals(A.t.this_class, A)
206 self.assertEquals(A.t.this_class, A)
208 self.assertEquals(B.t.this_class, A)
207 self.assertEquals(B.t.this_class, A)
209 self.assertEquals(B.tt.this_class, B)
208 self.assertEquals(B.tt.this_class, B)
210 self.assertEquals(B.ttt.this_class, B)
209 self.assertEquals(B.ttt.this_class, B)
211
210
212 class TestHasTraitsNotify(TestCase):
211 class TestHasTraitsNotify(TestCase):
213
212
214 def setUp(self):
213 def setUp(self):
215 self._notify1 = []
214 self._notify1 = []
216 self._notify2 = []
215 self._notify2 = []
217
216
218 def notify1(self, name, old, new):
217 def notify1(self, name, old, new):
219 self._notify1.append((name, old, new))
218 self._notify1.append((name, old, new))
220
219
221 def notify2(self, name, old, new):
220 def notify2(self, name, old, new):
222 self._notify2.append((name, old, new))
221 self._notify2.append((name, old, new))
223
222
224 def test_notify_all(self):
223 def test_notify_all(self):
225
224
226 class A(HasTraits):
225 class A(HasTraits):
227 a = Int
226 a = Int
228 b = Float
227 b = Float
229
228
230 a = A()
229 a = A()
231 a.on_trait_change(self.notify1)
230 a.on_trait_change(self.notify1)
232 a.a = 0
231 a.a = 0
233 self.assertEquals(len(self._notify1),0)
232 self.assertEquals(len(self._notify1),0)
234 a.b = 0.0
233 a.b = 0.0
235 self.assertEquals(len(self._notify1),0)
234 self.assertEquals(len(self._notify1),0)
236 a.a = 10
235 a.a = 10
237 self.assert_(('a',0,10) in self._notify1)
236 self.assert_(('a',0,10) in self._notify1)
238 a.b = 10.0
237 a.b = 10.0
239 self.assert_(('b',0.0,10.0) in self._notify1)
238 self.assert_(('b',0.0,10.0) in self._notify1)
240 self.assertRaises(TraitError,setattr,a,'a','bad string')
239 self.assertRaises(TraitError,setattr,a,'a','bad string')
241 self.assertRaises(TraitError,setattr,a,'b','bad string')
240 self.assertRaises(TraitError,setattr,a,'b','bad string')
242 self._notify1 = []
241 self._notify1 = []
243 a.on_trait_change(self.notify1,remove=True)
242 a.on_trait_change(self.notify1,remove=True)
244 a.a = 20
243 a.a = 20
245 a.b = 20.0
244 a.b = 20.0
246 self.assertEquals(len(self._notify1),0)
245 self.assertEquals(len(self._notify1),0)
247
246
248 def test_notify_one(self):
247 def test_notify_one(self):
249
248
250 class A(HasTraits):
249 class A(HasTraits):
251 a = Int
250 a = Int
252 b = Float
251 b = Float
253
252
254 a = A()
253 a = A()
255 a.on_trait_change(self.notify1, 'a')
254 a.on_trait_change(self.notify1, 'a')
256 a.a = 0
255 a.a = 0
257 self.assertEquals(len(self._notify1),0)
256 self.assertEquals(len(self._notify1),0)
258 a.a = 10
257 a.a = 10
259 self.assert_(('a',0,10) in self._notify1)
258 self.assert_(('a',0,10) in self._notify1)
260 self.assertRaises(TraitError,setattr,a,'a','bad string')
259 self.assertRaises(TraitError,setattr,a,'a','bad string')
261
260
262 def test_subclass(self):
261 def test_subclass(self):
263
262
264 class A(HasTraits):
263 class A(HasTraits):
265 a = Int
264 a = Int
266
265
267 class B(A):
266 class B(A):
268 b = Float
267 b = Float
269
268
270 b = B()
269 b = B()
271 self.assertEquals(b.a,0)
270 self.assertEquals(b.a,0)
272 self.assertEquals(b.b,0.0)
271 self.assertEquals(b.b,0.0)
273 b.a = 100
272 b.a = 100
274 b.b = 100.0
273 b.b = 100.0
275 self.assertEquals(b.a,100)
274 self.assertEquals(b.a,100)
276 self.assertEquals(b.b,100.0)
275 self.assertEquals(b.b,100.0)
277
276
278 def test_notify_subclass(self):
277 def test_notify_subclass(self):
279
278
280 class A(HasTraits):
279 class A(HasTraits):
281 a = Int
280 a = Int
282
281
283 class B(A):
282 class B(A):
284 b = Float
283 b = Float
285
284
286 b = B()
285 b = B()
287 b.on_trait_change(self.notify1, 'a')
286 b.on_trait_change(self.notify1, 'a')
288 b.on_trait_change(self.notify2, 'b')
287 b.on_trait_change(self.notify2, 'b')
289 b.a = 0
288 b.a = 0
290 b.b = 0.0
289 b.b = 0.0
291 self.assertEquals(len(self._notify1),0)
290 self.assertEquals(len(self._notify1),0)
292 self.assertEquals(len(self._notify2),0)
291 self.assertEquals(len(self._notify2),0)
293 b.a = 10
292 b.a = 10
294 b.b = 10.0
293 b.b = 10.0
295 self.assert_(('a',0,10) in self._notify1)
294 self.assert_(('a',0,10) in self._notify1)
296 self.assert_(('b',0.0,10.0) in self._notify2)
295 self.assert_(('b',0.0,10.0) in self._notify2)
297
296
298 def test_static_notify(self):
297 def test_static_notify(self):
299
298
300 class A(HasTraits):
299 class A(HasTraits):
301 a = Int
300 a = Int
302 _notify1 = []
301 _notify1 = []
303 def _a_changed(self, name, old, new):
302 def _a_changed(self, name, old, new):
304 self._notify1.append((name, old, new))
303 self._notify1.append((name, old, new))
305
304
306 a = A()
305 a = A()
307 a.a = 0
306 a.a = 0
308 # This is broken!!!
307 # This is broken!!!
309 self.assertEquals(len(a._notify1),0)
308 self.assertEquals(len(a._notify1),0)
310 a.a = 10
309 a.a = 10
311 self.assert_(('a',0,10) in a._notify1)
310 self.assert_(('a',0,10) in a._notify1)
312
311
313 class B(A):
312 class B(A):
314 b = Float
313 b = Float
315 _notify2 = []
314 _notify2 = []
316 def _b_changed(self, name, old, new):
315 def _b_changed(self, name, old, new):
317 self._notify2.append((name, old, new))
316 self._notify2.append((name, old, new))
318
317
319 b = B()
318 b = B()
320 b.a = 10
319 b.a = 10
321 b.b = 10.0
320 b.b = 10.0
322 self.assert_(('a',0,10) in b._notify1)
321 self.assert_(('a',0,10) in b._notify1)
323 self.assert_(('b',0.0,10.0) in b._notify2)
322 self.assert_(('b',0.0,10.0) in b._notify2)
324
323
325 def test_notify_args(self):
324 def test_notify_args(self):
326
325
327 def callback0():
326 def callback0():
328 self.cb = ()
327 self.cb = ()
329 def callback1(name):
328 def callback1(name):
330 self.cb = (name,)
329 self.cb = (name,)
331 def callback2(name, new):
330 def callback2(name, new):
332 self.cb = (name, new)
331 self.cb = (name, new)
333 def callback3(name, old, new):
332 def callback3(name, old, new):
334 self.cb = (name, old, new)
333 self.cb = (name, old, new)
335
334
336 class A(HasTraits):
335 class A(HasTraits):
337 a = Int
336 a = Int
338
337
339 a = A()
338 a = A()
340 a.on_trait_change(callback0, 'a')
339 a.on_trait_change(callback0, 'a')
341 a.a = 10
340 a.a = 10
342 self.assertEquals(self.cb,())
341 self.assertEquals(self.cb,())
343 a.on_trait_change(callback0, 'a', remove=True)
342 a.on_trait_change(callback0, 'a', remove=True)
344
343
345 a.on_trait_change(callback1, 'a')
344 a.on_trait_change(callback1, 'a')
346 a.a = 100
345 a.a = 100
347 self.assertEquals(self.cb,('a',))
346 self.assertEquals(self.cb,('a',))
348 a.on_trait_change(callback1, 'a', remove=True)
347 a.on_trait_change(callback1, 'a', remove=True)
349
348
350 a.on_trait_change(callback2, 'a')
349 a.on_trait_change(callback2, 'a')
351 a.a = 1000
350 a.a = 1000
352 self.assertEquals(self.cb,('a',1000))
351 self.assertEquals(self.cb,('a',1000))
353 a.on_trait_change(callback2, 'a', remove=True)
352 a.on_trait_change(callback2, 'a', remove=True)
354
353
355 a.on_trait_change(callback3, 'a')
354 a.on_trait_change(callback3, 'a')
356 a.a = 10000
355 a.a = 10000
357 self.assertEquals(self.cb,('a',1000,10000))
356 self.assertEquals(self.cb,('a',1000,10000))
358 a.on_trait_change(callback3, 'a', remove=True)
357 a.on_trait_change(callback3, 'a', remove=True)
359
358
360 self.assertEquals(len(a._trait_notifiers['a']),0)
359 self.assertEquals(len(a._trait_notifiers['a']),0)
361
360
362
361
363 class TestHasTraits(TestCase):
362 class TestHasTraits(TestCase):
364
363
365 def test_trait_names(self):
364 def test_trait_names(self):
366 class A(HasTraits):
365 class A(HasTraits):
367 i = Int
366 i = Int
368 f = Float
367 f = Float
369 a = A()
368 a = A()
370 self.assertEquals(a.trait_names(),['i','f'])
369 self.assertEquals(a.trait_names(),['i','f'])
371 self.assertEquals(A.class_trait_names(),['i','f'])
370 self.assertEquals(A.class_trait_names(),['i','f'])
372
371
373 def test_trait_metadata(self):
372 def test_trait_metadata(self):
374 class A(HasTraits):
373 class A(HasTraits):
375 i = Int(config_key='MY_VALUE')
374 i = Int(config_key='MY_VALUE')
376 a = A()
375 a = A()
377 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
376 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
378
377
379 def test_traits(self):
378 def test_traits(self):
380 class A(HasTraits):
379 class A(HasTraits):
381 i = Int
380 i = Int
382 f = Float
381 f = Float
383 a = A()
382 a = A()
384 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
383 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
385 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
384 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
386
385
387 def test_traits_metadata(self):
386 def test_traits_metadata(self):
388 class A(HasTraits):
387 class A(HasTraits):
389 i = Int(config_key='VALUE1', other_thing='VALUE2')
388 i = Int(config_key='VALUE1', other_thing='VALUE2')
390 f = Float(config_key='VALUE3', other_thing='VALUE2')
389 f = Float(config_key='VALUE3', other_thing='VALUE2')
391 j = Int(0)
390 j = Int(0)
392 a = A()
391 a = A()
393 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
392 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
394 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
393 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
395 self.assertEquals(traits, dict(i=A.i))
394 self.assertEquals(traits, dict(i=A.i))
396
395
397 # This passes, but it shouldn't because I am replicating a bug in
396 # This passes, but it shouldn't because I am replicating a bug in
398 # traits.
397 # traits.
399 traits = a.traits(config_key=lambda v: True)
398 traits = a.traits(config_key=lambda v: True)
400 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
399 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
401
400
402 def test_init(self):
401 def test_init(self):
403 class A(HasTraits):
402 class A(HasTraits):
404 i = Int()
403 i = Int()
405 x = Float()
404 x = Float()
406 a = A(i=1, x=10.0)
405 a = A(i=1, x=10.0)
407 self.assertEquals(a.i, 1)
406 self.assertEquals(a.i, 1)
408 self.assertEquals(a.x, 10.0)
407 self.assertEquals(a.x, 10.0)
409
408
410 #-----------------------------------------------------------------------------
409 #-----------------------------------------------------------------------------
411 # Tests for specific trait types
410 # Tests for specific trait types
412 #-----------------------------------------------------------------------------
411 #-----------------------------------------------------------------------------
413
412
414
413
415 class TestType(TestCase):
414 class TestType(TestCase):
416
415
417 def test_default(self):
416 def test_default(self):
418
417
419 class B(object): pass
418 class B(object): pass
420 class A(HasTraits):
419 class A(HasTraits):
421 klass = Type
420 klass = Type
422
421
423 a = A()
422 a = A()
424 self.assertEquals(a.klass, None)
423 self.assertEquals(a.klass, None)
425
424
426 a.klass = B
425 a.klass = B
427 self.assertEquals(a.klass, B)
426 self.assertEquals(a.klass, B)
428 self.assertRaises(TraitError, setattr, a, 'klass', 10)
427 self.assertRaises(TraitError, setattr, a, 'klass', 10)
429
428
430 def test_value(self):
429 def test_value(self):
431
430
432 class B(object): pass
431 class B(object): pass
433 class C(object): pass
432 class C(object): pass
434 class A(HasTraits):
433 class A(HasTraits):
435 klass = Type(B)
434 klass = Type(B)
436
435
437 a = A()
436 a = A()
438 self.assertEquals(a.klass, B)
437 self.assertEquals(a.klass, B)
439 self.assertRaises(TraitError, setattr, a, 'klass', C)
438 self.assertRaises(TraitError, setattr, a, 'klass', C)
440 self.assertRaises(TraitError, setattr, a, 'klass', object)
439 self.assertRaises(TraitError, setattr, a, 'klass', object)
441 a.klass = B
440 a.klass = B
442
441
443 def test_allow_none(self):
442 def test_allow_none(self):
444
443
445 class B(object): pass
444 class B(object): pass
446 class C(B): pass
445 class C(B): pass
447 class A(HasTraits):
446 class A(HasTraits):
448 klass = Type(B, allow_none=False)
447 klass = Type(B, allow_none=False)
449
448
450 a = A()
449 a = A()
451 self.assertEquals(a.klass, B)
450 self.assertEquals(a.klass, B)
452 self.assertRaises(TraitError, setattr, a, 'klass', None)
451 self.assertRaises(TraitError, setattr, a, 'klass', None)
453 a.klass = C
452 a.klass = C
454 self.assertEquals(a.klass, C)
453 self.assertEquals(a.klass, C)
455
454
456 def test_validate_klass(self):
455 def test_validate_klass(self):
457
456
458 class A(HasTraits):
457 class A(HasTraits):
459 klass = Type('no strings allowed')
458 klass = Type('no strings allowed')
460
459
461 self.assertRaises(ImportError, A)
460 self.assertRaises(ImportError, A)
462
461
463 class A(HasTraits):
462 class A(HasTraits):
464 klass = Type('rub.adub.Duck')
463 klass = Type('rub.adub.Duck')
465
464
466 self.assertRaises(ImportError, A)
465 self.assertRaises(ImportError, A)
467
466
468 def test_validate_default(self):
467 def test_validate_default(self):
469
468
470 class B(object): pass
469 class B(object): pass
471 class A(HasTraits):
470 class A(HasTraits):
472 klass = Type('bad default', B)
471 klass = Type('bad default', B)
473
472
474 self.assertRaises(ImportError, A)
473 self.assertRaises(ImportError, A)
475
474
476 class C(HasTraits):
475 class C(HasTraits):
477 klass = Type(None, B, allow_none=False)
476 klass = Type(None, B, allow_none=False)
478
477
479 self.assertRaises(TraitError, C)
478 self.assertRaises(TraitError, C)
480
479
481 def test_str_klass(self):
480 def test_str_klass(self):
482
481
483 class A(HasTraits):
482 class A(HasTraits):
484 klass = Type('IPython.utils.ipstruct.Struct')
483 klass = Type('IPython.utils.ipstruct.Struct')
485
484
486 from IPython.utils.ipstruct import Struct
485 from IPython.utils.ipstruct import Struct
487 a = A()
486 a = A()
488 a.klass = Struct
487 a.klass = Struct
489 self.assertEquals(a.klass, Struct)
488 self.assertEquals(a.klass, Struct)
490
489
491 self.assertRaises(TraitError, setattr, a, 'klass', 10)
490 self.assertRaises(TraitError, setattr, a, 'klass', 10)
492
491
493 class TestInstance(TestCase):
492 class TestInstance(TestCase):
494
493
495 def test_basic(self):
494 def test_basic(self):
496 class Foo(object): pass
495 class Foo(object): pass
497 class Bar(Foo): pass
496 class Bar(Foo): pass
498 class Bah(object): pass
497 class Bah(object): pass
499
498
500 class A(HasTraits):
499 class A(HasTraits):
501 inst = Instance(Foo)
500 inst = Instance(Foo)
502
501
503 a = A()
502 a = A()
504 self.assert_(a.inst is None)
503 self.assert_(a.inst is None)
505 a.inst = Foo()
504 a.inst = Foo()
506 self.assert_(isinstance(a.inst, Foo))
505 self.assert_(isinstance(a.inst, Foo))
507 a.inst = Bar()
506 a.inst = Bar()
508 self.assert_(isinstance(a.inst, Foo))
507 self.assert_(isinstance(a.inst, Foo))
509 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
508 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
510 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
509 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
511 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
510 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
512
511
513 def test_unique_default_value(self):
512 def test_unique_default_value(self):
514 class Foo(object): pass
513 class Foo(object): pass
515 class A(HasTraits):
514 class A(HasTraits):
516 inst = Instance(Foo,(),{})
515 inst = Instance(Foo,(),{})
517
516
518 a = A()
517 a = A()
519 b = A()
518 b = A()
520 self.assert_(a.inst is not b.inst)
519 self.assert_(a.inst is not b.inst)
521
520
522 def test_args_kw(self):
521 def test_args_kw(self):
523 class Foo(object):
522 class Foo(object):
524 def __init__(self, c): self.c = c
523 def __init__(self, c): self.c = c
525 class Bar(object): pass
524 class Bar(object): pass
526 class Bah(object):
525 class Bah(object):
527 def __init__(self, c, d):
526 def __init__(self, c, d):
528 self.c = c; self.d = d
527 self.c = c; self.d = d
529
528
530 class A(HasTraits):
529 class A(HasTraits):
531 inst = Instance(Foo, (10,))
530 inst = Instance(Foo, (10,))
532 a = A()
531 a = A()
533 self.assertEquals(a.inst.c, 10)
532 self.assertEquals(a.inst.c, 10)
534
533
535 class B(HasTraits):
534 class B(HasTraits):
536 inst = Instance(Bah, args=(10,), kw=dict(d=20))
535 inst = Instance(Bah, args=(10,), kw=dict(d=20))
537 b = B()
536 b = B()
538 self.assertEquals(b.inst.c, 10)
537 self.assertEquals(b.inst.c, 10)
539 self.assertEquals(b.inst.d, 20)
538 self.assertEquals(b.inst.d, 20)
540
539
541 class C(HasTraits):
540 class C(HasTraits):
542 inst = Instance(Foo)
541 inst = Instance(Foo)
543 c = C()
542 c = C()
544 self.assert_(c.inst is None)
543 self.assert_(c.inst is None)
545
544
546 def test_bad_default(self):
545 def test_bad_default(self):
547 class Foo(object): pass
546 class Foo(object): pass
548
547
549 class A(HasTraits):
548 class A(HasTraits):
550 inst = Instance(Foo, allow_none=False)
549 inst = Instance(Foo, allow_none=False)
551
550
552 self.assertRaises(TraitError, A)
551 self.assertRaises(TraitError, A)
553
552
554 def test_instance(self):
553 def test_instance(self):
555 class Foo(object): pass
554 class Foo(object): pass
556
555
557 def inner():
556 def inner():
558 class A(HasTraits):
557 class A(HasTraits):
559 inst = Instance(Foo())
558 inst = Instance(Foo())
560
559
561 self.assertRaises(TraitError, inner)
560 self.assertRaises(TraitError, inner)
562
561
563
562
564 class TestThis(TestCase):
563 class TestThis(TestCase):
565
564
566 def test_this_class(self):
565 def test_this_class(self):
567 class Foo(HasTraits):
566 class Foo(HasTraits):
568 this = This
567 this = This
569
568
570 f = Foo()
569 f = Foo()
571 self.assertEquals(f.this, None)
570 self.assertEquals(f.this, None)
572 g = Foo()
571 g = Foo()
573 f.this = g
572 f.this = g
574 self.assertEquals(f.this, g)
573 self.assertEquals(f.this, g)
575 self.assertRaises(TraitError, setattr, f, 'this', 10)
574 self.assertRaises(TraitError, setattr, f, 'this', 10)
576
575
577 def test_this_inst(self):
576 def test_this_inst(self):
578 class Foo(HasTraits):
577 class Foo(HasTraits):
579 this = This()
578 this = This()
580
579
581 f = Foo()
580 f = Foo()
582 f.this = Foo()
581 f.this = Foo()
583 self.assert_(isinstance(f.this, Foo))
582 self.assert_(isinstance(f.this, Foo))
584
583
585 def test_subclass(self):
584 def test_subclass(self):
586 class Foo(HasTraits):
585 class Foo(HasTraits):
587 t = This()
586 t = This()
588 class Bar(Foo):
587 class Bar(Foo):
589 pass
588 pass
590 f = Foo()
589 f = Foo()
591 b = Bar()
590 b = Bar()
592 f.t = b
591 f.t = b
593 b.t = f
592 b.t = f
594 self.assertEquals(f.t, b)
593 self.assertEquals(f.t, b)
595 self.assertEquals(b.t, f)
594 self.assertEquals(b.t, f)
596
595
597 def test_subclass_override(self):
596 def test_subclass_override(self):
598 class Foo(HasTraits):
597 class Foo(HasTraits):
599 t = This()
598 t = This()
600 class Bar(Foo):
599 class Bar(Foo):
601 t = This()
600 t = This()
602 f = Foo()
601 f = Foo()
603 b = Bar()
602 b = Bar()
604 f.t = b
603 f.t = b
605 self.assertEquals(f.t, b)
604 self.assertEquals(f.t, b)
606 self.assertRaises(TraitError, setattr, b, 't', f)
605 self.assertRaises(TraitError, setattr, b, 't', f)
607
606
608 class TraitTestBase(TestCase):
607 class TraitTestBase(TestCase):
609 """A best testing class for basic trait types."""
608 """A best testing class for basic trait types."""
610
609
611 def assign(self, value):
610 def assign(self, value):
612 self.obj.value = value
611 self.obj.value = value
613
612
614 def coerce(self, value):
613 def coerce(self, value):
615 return value
614 return value
616
615
617 def test_good_values(self):
616 def test_good_values(self):
618 if hasattr(self, '_good_values'):
617 if hasattr(self, '_good_values'):
619 for value in self._good_values:
618 for value in self._good_values:
620 self.assign(value)
619 self.assign(value)
621 self.assertEquals(self.obj.value, self.coerce(value))
620 self.assertEquals(self.obj.value, self.coerce(value))
622
621
623 def test_bad_values(self):
622 def test_bad_values(self):
624 if hasattr(self, '_bad_values'):
623 if hasattr(self, '_bad_values'):
625 for value in self._bad_values:
624 for value in self._bad_values:
626 self.assertRaises(TraitError, self.assign, value)
625 self.assertRaises(TraitError, self.assign, value)
627
626
628 def test_default_value(self):
627 def test_default_value(self):
629 if hasattr(self, '_default_value'):
628 if hasattr(self, '_default_value'):
630 self.assertEquals(self._default_value, self.obj.value)
629 self.assertEquals(self._default_value, self.obj.value)
631
630
632
631
633 class AnyTrait(HasTraits):
632 class AnyTrait(HasTraits):
634
633
635 value = Any
634 value = Any
636
635
637 class AnyTraitTest(TraitTestBase):
636 class AnyTraitTest(TraitTestBase):
638
637
639 obj = AnyTrait()
638 obj = AnyTrait()
640
639
641 _default_value = None
640 _default_value = None
642 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
641 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
643 _bad_values = []
642 _bad_values = []
644
643
645
644
646 class IntTrait(HasTraits):
645 class IntTrait(HasTraits):
647
646
648 value = Int(99)
647 value = Int(99)
649
648
650 class TestInt(TraitTestBase):
649 class TestInt(TraitTestBase):
651
650
652 obj = IntTrait()
651 obj = IntTrait()
653 _default_value = 99
652 _default_value = 99
654 _good_values = [10, -10]
653 _good_values = [10, -10]
655 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j, 10L,
654 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j, 10L,
656 -10L, 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
655 -10L, 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
657 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
656 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
658
657
659
658
660 class LongTrait(HasTraits):
659 class LongTrait(HasTraits):
661
660
662 value = Long(99L)
661 value = Long(99L)
663
662
664 class TestLong(TraitTestBase):
663 class TestLong(TraitTestBase):
665
664
666 obj = LongTrait()
665 obj = LongTrait()
667
666
668 _default_value = 99L
667 _default_value = 99L
669 _good_values = [10, -10, 10L, -10L]
668 _good_values = [10, -10, 10L, -10L]
670 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
669 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
671 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
670 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
672 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
671 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
673 u'-10.1']
672 u'-10.1']
674
673
675
674
676 class FloatTrait(HasTraits):
675 class FloatTrait(HasTraits):
677
676
678 value = Float(99.0)
677 value = Float(99.0)
679
678
680 class TestFloat(TraitTestBase):
679 class TestFloat(TraitTestBase):
681
680
682 obj = FloatTrait()
681 obj = FloatTrait()
683
682
684 _default_value = 99.0
683 _default_value = 99.0
685 _good_values = [10, -10, 10.1, -10.1]
684 _good_values = [10, -10, 10.1, -10.1]
686 _bad_values = [10L, -10L, 'ten', u'ten', [10], {'ten': 10},(10,), None,
685 _bad_values = [10L, -10L, 'ten', u'ten', [10], {'ten': 10},(10,), None,
687 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
686 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
688 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
687 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
689
688
690
689
691 class ComplexTrait(HasTraits):
690 class ComplexTrait(HasTraits):
692
691
693 value = Complex(99.0-99.0j)
692 value = Complex(99.0-99.0j)
694
693
695 class TestComplex(TraitTestBase):
694 class TestComplex(TraitTestBase):
696
695
697 obj = ComplexTrait()
696 obj = ComplexTrait()
698
697
699 _default_value = 99.0-99.0j
698 _default_value = 99.0-99.0j
700 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
699 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
701 10.1j, 10.1+10.1j, 10.1-10.1j]
700 10.1j, 10.1+10.1j, 10.1-10.1j]
702 _bad_values = [10L, -10L, u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
701 _bad_values = [10L, -10L, u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
703
702
704
703
705 class BytesTrait(HasTraits):
704 class BytesTrait(HasTraits):
706
705
707 value = Bytes(b'string')
706 value = Bytes(b'string')
708
707
709 class TestBytes(TraitTestBase):
708 class TestBytes(TraitTestBase):
710
709
711 obj = BytesTrait()
710 obj = BytesTrait()
712
711
713 _default_value = b'string'
712 _default_value = b'string'
714 _good_values = [b'10', b'-10', b'10L',
713 _good_values = [b'10', b'-10', b'10L',
715 b'-10L', b'10.1', b'-10.1', b'string']
714 b'-10L', b'10.1', b'-10.1', b'string']
716 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
715 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
717 ['ten'],{'ten': 10},(10,), None, u'string']
716 ['ten'],{'ten': 10},(10,), None, u'string']
718
717
719
718
720 class UnicodeTrait(HasTraits):
719 class UnicodeTrait(HasTraits):
721
720
722 value = Unicode(u'unicode')
721 value = Unicode(u'unicode')
723
722
724 class TestUnicode(TraitTestBase):
723 class TestUnicode(TraitTestBase):
725
724
726 obj = UnicodeTrait()
725 obj = UnicodeTrait()
727
726
728 _default_value = u'unicode'
727 _default_value = u'unicode'
729 _good_values = ['10', '-10', '10L', '-10L', '10.1',
728 _good_values = ['10', '-10', '10L', '-10L', '10.1',
730 '-10.1', '', u'', 'string', u'string', u"€"]
729 '-10.1', '', u'', 'string', u'string', u"€"]
731 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
730 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
732 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
731 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
733
732
734
733
735 class ObjectNameTrait(HasTraits):
734 class ObjectNameTrait(HasTraits):
736 value = ObjectName("abc")
735 value = ObjectName("abc")
737
736
738 class TestObjectName(TraitTestBase):
737 class TestObjectName(TraitTestBase):
739 obj = ObjectNameTrait()
738 obj = ObjectNameTrait()
740
739
741 _default_value = "abc"
740 _default_value = "abc"
742 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
741 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
743 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
742 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
744 object(), object]
743 object(), object]
745 if sys.version_info[0] < 3:
744 if sys.version_info[0] < 3:
746 _bad_values.append(u"þ")
745 _bad_values.append(u"þ")
747 else:
746 else:
748 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
747 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
749
748
750
749
751 class DottedObjectNameTrait(HasTraits):
750 class DottedObjectNameTrait(HasTraits):
752 value = DottedObjectName("a.b")
751 value = DottedObjectName("a.b")
753
752
754 class TestDottedObjectName(TraitTestBase):
753 class TestDottedObjectName(TraitTestBase):
755 obj = DottedObjectNameTrait()
754 obj = DottedObjectNameTrait()
756
755
757 _default_value = "a.b"
756 _default_value = "a.b"
758 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
757 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
759 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
758 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
760 if sys.version_info[0] < 3:
759 if sys.version_info[0] < 3:
761 _bad_values.append(u"t.þ")
760 _bad_values.append(u"t.þ")
762 else:
761 else:
763 _good_values.append(u"t.þ")
762 _good_values.append(u"t.þ")
764
763
765
764
766 class TCPAddressTrait(HasTraits):
765 class TCPAddressTrait(HasTraits):
767
766
768 value = TCPAddress()
767 value = TCPAddress()
769
768
770 class TestTCPAddress(TraitTestBase):
769 class TestTCPAddress(TraitTestBase):
771
770
772 obj = TCPAddressTrait()
771 obj = TCPAddressTrait()
773
772
774 _default_value = ('127.0.0.1',0)
773 _default_value = ('127.0.0.1',0)
775 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
774 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
776 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
775 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
777
776
778 class ListTrait(HasTraits):
777 class ListTrait(HasTraits):
779
778
780 value = List(Int)
779 value = List(Int)
781
780
782 class TestList(TraitTestBase):
781 class TestList(TraitTestBase):
783
782
784 obj = ListTrait()
783 obj = ListTrait()
785
784
786 _default_value = []
785 _default_value = []
787 _good_values = [[], [1], range(10)]
786 _good_values = [[], [1], range(10)]
788 _bad_values = [10, [1,'a'], 'a', (1,2)]
787 _bad_values = [10, [1,'a'], 'a', (1,2)]
789
788
790 class LenListTrait(HasTraits):
789 class LenListTrait(HasTraits):
791
790
792 value = List(Int, [0], minlen=1, maxlen=2)
791 value = List(Int, [0], minlen=1, maxlen=2)
793
792
794 class TestLenList(TraitTestBase):
793 class TestLenList(TraitTestBase):
795
794
796 obj = LenListTrait()
795 obj = LenListTrait()
797
796
798 _default_value = [0]
797 _default_value = [0]
799 _good_values = [[1], range(2)]
798 _good_values = [[1], range(2)]
800 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
799 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
801
800
802 class TupleTrait(HasTraits):
801 class TupleTrait(HasTraits):
803
802
804 value = Tuple(Int)
803 value = Tuple(Int)
805
804
806 class TestTupleTrait(TraitTestBase):
805 class TestTupleTrait(TraitTestBase):
807
806
808 obj = TupleTrait()
807 obj = TupleTrait()
809
808
810 _default_value = None
809 _default_value = None
811 _good_values = [(1,), None,(0,)]
810 _good_values = [(1,), None,(0,)]
812 _bad_values = [10, (1,2), [1],('a'), ()]
811 _bad_values = [10, (1,2), [1],('a'), ()]
813
812
814 def test_invalid_args(self):
813 def test_invalid_args(self):
815 self.assertRaises(TypeError, Tuple, 5)
814 self.assertRaises(TypeError, Tuple, 5)
816 self.assertRaises(TypeError, Tuple, default_value='hello')
815 self.assertRaises(TypeError, Tuple, default_value='hello')
817 t = Tuple(Int, CBytes, default_value=(1,5))
816 t = Tuple(Int, CBytes, default_value=(1,5))
818
817
819 class LooseTupleTrait(HasTraits):
818 class LooseTupleTrait(HasTraits):
820
819
821 value = Tuple((1,2,3))
820 value = Tuple((1,2,3))
822
821
823 class TestLooseTupleTrait(TraitTestBase):
822 class TestLooseTupleTrait(TraitTestBase):
824
823
825 obj = LooseTupleTrait()
824 obj = LooseTupleTrait()
826
825
827 _default_value = (1,2,3)
826 _default_value = (1,2,3)
828 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
827 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
829 _bad_values = [10, 'hello', [1], []]
828 _bad_values = [10, 'hello', [1], []]
830
829
831 def test_invalid_args(self):
830 def test_invalid_args(self):
832 self.assertRaises(TypeError, Tuple, 5)
831 self.assertRaises(TypeError, Tuple, 5)
833 self.assertRaises(TypeError, Tuple, default_value='hello')
832 self.assertRaises(TypeError, Tuple, default_value='hello')
834 t = Tuple(Int, CBytes, default_value=(1,5))
833 t = Tuple(Int, CBytes, default_value=(1,5))
835
834
836
835
837 class MultiTupleTrait(HasTraits):
836 class MultiTupleTrait(HasTraits):
838
837
839 value = Tuple(Int, Bytes, default_value=[99,'bottles'])
838 value = Tuple(Int, Bytes, default_value=[99,'bottles'])
840
839
841 class TestMultiTuple(TraitTestBase):
840 class TestMultiTuple(TraitTestBase):
842
841
843 obj = MultiTupleTrait()
842 obj = MultiTupleTrait()
844
843
845 _default_value = (99,'bottles')
844 _default_value = (99,'bottles')
846 _good_values = [(1,'a'), (2,'b')]
845 _good_values = [(1,'a'), (2,'b')]
847 _bad_values = ((),10, 'a', (1,'a',3), ('a',1))
846 _bad_values = ((),10, 'a', (1,'a',3), ('a',1))
@@ -1,1397 +1,1396 b''
1 #!/usr/bin/env python
2 # encoding: utf-8
1 # encoding: utf-8
3 """
2 """
4 A lightweight Traits like module.
3 A lightweight Traits like module.
5
4
6 This is designed to provide a lightweight, simple, pure Python version of
5 This is designed to provide a lightweight, simple, pure Python version of
7 many of the capabilities of enthought.traits. This includes:
6 many of the capabilities of enthought.traits. This includes:
8
7
9 * Validation
8 * Validation
10 * Type specification with defaults
9 * Type specification with defaults
11 * Static and dynamic notification
10 * Static and dynamic notification
12 * Basic predefined types
11 * Basic predefined types
13 * An API that is similar to enthought.traits
12 * An API that is similar to enthought.traits
14
13
15 We don't support:
14 We don't support:
16
15
17 * Delegation
16 * Delegation
18 * Automatic GUI generation
17 * Automatic GUI generation
19 * A full set of trait types. Most importantly, we don't provide container
18 * A full set of trait types. Most importantly, we don't provide container
20 traits (list, dict, tuple) that can trigger notifications if their
19 traits (list, dict, tuple) that can trigger notifications if their
21 contents change.
20 contents change.
22 * API compatibility with enthought.traits
21 * API compatibility with enthought.traits
23
22
24 There are also some important difference in our design:
23 There are also some important difference in our design:
25
24
26 * enthought.traits does not validate default values. We do.
25 * enthought.traits does not validate default values. We do.
27
26
28 We choose to create this module because we need these capabilities, but
27 We choose to create this module because we need these capabilities, but
29 we need them to be pure Python so they work in all Python implementations,
28 we need them to be pure Python so they work in all Python implementations,
30 including Jython and IronPython.
29 including Jython and IronPython.
31
30
32 Authors:
31 Authors:
33
32
34 * Brian Granger
33 * Brian Granger
35 * Enthought, Inc. Some of the code in this file comes from enthought.traits
34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
36 and is licensed under the BSD license. Also, many of the ideas also come
35 and is licensed under the BSD license. Also, many of the ideas also come
37 from enthought.traits even though our implementation is very different.
36 from enthought.traits even though our implementation is very different.
38 """
37 """
39
38
40 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
41 # Copyright (C) 2008-2009 The IPython Development Team
40 # Copyright (C) 2008-2009 The IPython Development Team
42 #
41 #
43 # Distributed under the terms of the BSD License. The full license is in
42 # Distributed under the terms of the BSD License. The full license is in
44 # the file COPYING, distributed as part of this software.
43 # the file COPYING, distributed as part of this software.
45 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
46
45
47 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
48 # Imports
47 # Imports
49 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
50
49
51
50
52 import inspect
51 import inspect
53 import re
52 import re
54 import sys
53 import sys
55 import types
54 import types
56 from types import (
55 from types import (
57 InstanceType, ClassType, FunctionType,
56 InstanceType, ClassType, FunctionType,
58 ListType, TupleType
57 ListType, TupleType
59 )
58 )
60 from .importstring import import_item
59 from .importstring import import_item
61
60
62 ClassTypes = (ClassType, type)
61 ClassTypes = (ClassType, type)
63
62
64 SequenceTypes = (ListType, TupleType, set, frozenset)
63 SequenceTypes = (ListType, TupleType, set, frozenset)
65
64
66 #-----------------------------------------------------------------------------
65 #-----------------------------------------------------------------------------
67 # Basic classes
66 # Basic classes
68 #-----------------------------------------------------------------------------
67 #-----------------------------------------------------------------------------
69
68
70
69
71 class NoDefaultSpecified ( object ): pass
70 class NoDefaultSpecified ( object ): pass
72 NoDefaultSpecified = NoDefaultSpecified()
71 NoDefaultSpecified = NoDefaultSpecified()
73
72
74
73
75 class Undefined ( object ): pass
74 class Undefined ( object ): pass
76 Undefined = Undefined()
75 Undefined = Undefined()
77
76
78 class TraitError(Exception):
77 class TraitError(Exception):
79 pass
78 pass
80
79
81 #-----------------------------------------------------------------------------
80 #-----------------------------------------------------------------------------
82 # Utilities
81 # Utilities
83 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
84
83
85
84
86 def class_of ( object ):
85 def class_of ( object ):
87 """ Returns a string containing the class name of an object with the
86 """ Returns a string containing the class name of an object with the
88 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
87 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
89 'a PlotValue').
88 'a PlotValue').
90 """
89 """
91 if isinstance( object, basestring ):
90 if isinstance( object, basestring ):
92 return add_article( object )
91 return add_article( object )
93
92
94 return add_article( object.__class__.__name__ )
93 return add_article( object.__class__.__name__ )
95
94
96
95
97 def add_article ( name ):
96 def add_article ( name ):
98 """ Returns a string containing the correct indefinite article ('a' or 'an')
97 """ Returns a string containing the correct indefinite article ('a' or 'an')
99 prefixed to the specified string.
98 prefixed to the specified string.
100 """
99 """
101 if name[:1].lower() in 'aeiou':
100 if name[:1].lower() in 'aeiou':
102 return 'an ' + name
101 return 'an ' + name
103
102
104 return 'a ' + name
103 return 'a ' + name
105
104
106
105
107 def repr_type(obj):
106 def repr_type(obj):
108 """ Return a string representation of a value and its type for readable
107 """ Return a string representation of a value and its type for readable
109 error messages.
108 error messages.
110 """
109 """
111 the_type = type(obj)
110 the_type = type(obj)
112 if the_type is InstanceType:
111 if the_type is InstanceType:
113 # Old-style class.
112 # Old-style class.
114 the_type = obj.__class__
113 the_type = obj.__class__
115 msg = '%r %r' % (obj, the_type)
114 msg = '%r %r' % (obj, the_type)
116 return msg
115 return msg
117
116
118
117
119 def parse_notifier_name(name):
118 def parse_notifier_name(name):
120 """Convert the name argument to a list of names.
119 """Convert the name argument to a list of names.
121
120
122 Examples
121 Examples
123 --------
122 --------
124
123
125 >>> parse_notifier_name('a')
124 >>> parse_notifier_name('a')
126 ['a']
125 ['a']
127 >>> parse_notifier_name(['a','b'])
126 >>> parse_notifier_name(['a','b'])
128 ['a', 'b']
127 ['a', 'b']
129 >>> parse_notifier_name(None)
128 >>> parse_notifier_name(None)
130 ['anytrait']
129 ['anytrait']
131 """
130 """
132 if isinstance(name, str):
131 if isinstance(name, str):
133 return [name]
132 return [name]
134 elif name is None:
133 elif name is None:
135 return ['anytrait']
134 return ['anytrait']
136 elif isinstance(name, (list, tuple)):
135 elif isinstance(name, (list, tuple)):
137 for n in name:
136 for n in name:
138 assert isinstance(n, str), "names must be strings"
137 assert isinstance(n, str), "names must be strings"
139 return name
138 return name
140
139
141
140
142 class _SimpleTest:
141 class _SimpleTest:
143 def __init__ ( self, value ): self.value = value
142 def __init__ ( self, value ): self.value = value
144 def __call__ ( self, test ):
143 def __call__ ( self, test ):
145 return test == self.value
144 return test == self.value
146 def __repr__(self):
145 def __repr__(self):
147 return "<SimpleTest(%r)" % self.value
146 return "<SimpleTest(%r)" % self.value
148 def __str__(self):
147 def __str__(self):
149 return self.__repr__()
148 return self.__repr__()
150
149
151
150
152 def getmembers(object, predicate=None):
151 def getmembers(object, predicate=None):
153 """A safe version of inspect.getmembers that handles missing attributes.
152 """A safe version of inspect.getmembers that handles missing attributes.
154
153
155 This is useful when there are descriptor based attributes that for
154 This is useful when there are descriptor based attributes that for
156 some reason raise AttributeError even though they exist. This happens
155 some reason raise AttributeError even though they exist. This happens
157 in zope.inteface with the __provides__ attribute.
156 in zope.inteface with the __provides__ attribute.
158 """
157 """
159 results = []
158 results = []
160 for key in dir(object):
159 for key in dir(object):
161 try:
160 try:
162 value = getattr(object, key)
161 value = getattr(object, key)
163 except AttributeError:
162 except AttributeError:
164 pass
163 pass
165 else:
164 else:
166 if not predicate or predicate(value):
165 if not predicate or predicate(value):
167 results.append((key, value))
166 results.append((key, value))
168 results.sort()
167 results.sort()
169 return results
168 return results
170
169
171
170
172 #-----------------------------------------------------------------------------
171 #-----------------------------------------------------------------------------
173 # Base TraitType for all traits
172 # Base TraitType for all traits
174 #-----------------------------------------------------------------------------
173 #-----------------------------------------------------------------------------
175
174
176
175
177 class TraitType(object):
176 class TraitType(object):
178 """A base class for all trait descriptors.
177 """A base class for all trait descriptors.
179
178
180 Notes
179 Notes
181 -----
180 -----
182 Our implementation of traits is based on Python's descriptor
181 Our implementation of traits is based on Python's descriptor
183 prototol. This class is the base class for all such descriptors. The
182 prototol. This class is the base class for all such descriptors. The
184 only magic we use is a custom metaclass for the main :class:`HasTraits`
183 only magic we use is a custom metaclass for the main :class:`HasTraits`
185 class that does the following:
184 class that does the following:
186
185
187 1. Sets the :attr:`name` attribute of every :class:`TraitType`
186 1. Sets the :attr:`name` attribute of every :class:`TraitType`
188 instance in the class dict to the name of the attribute.
187 instance in the class dict to the name of the attribute.
189 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
188 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
190 instance in the class dict to the *class* that declared the trait.
189 instance in the class dict to the *class* that declared the trait.
191 This is used by the :class:`This` trait to allow subclasses to
190 This is used by the :class:`This` trait to allow subclasses to
192 accept superclasses for :class:`This` values.
191 accept superclasses for :class:`This` values.
193 """
192 """
194
193
195
194
196 metadata = {}
195 metadata = {}
197 default_value = Undefined
196 default_value = Undefined
198 info_text = 'any value'
197 info_text = 'any value'
199
198
200 def __init__(self, default_value=NoDefaultSpecified, **metadata):
199 def __init__(self, default_value=NoDefaultSpecified, **metadata):
201 """Create a TraitType.
200 """Create a TraitType.
202 """
201 """
203 if default_value is not NoDefaultSpecified:
202 if default_value is not NoDefaultSpecified:
204 self.default_value = default_value
203 self.default_value = default_value
205
204
206 if len(metadata) > 0:
205 if len(metadata) > 0:
207 if len(self.metadata) > 0:
206 if len(self.metadata) > 0:
208 self._metadata = self.metadata.copy()
207 self._metadata = self.metadata.copy()
209 self._metadata.update(metadata)
208 self._metadata.update(metadata)
210 else:
209 else:
211 self._metadata = metadata
210 self._metadata = metadata
212 else:
211 else:
213 self._metadata = self.metadata
212 self._metadata = self.metadata
214
213
215 self.init()
214 self.init()
216
215
217 def init(self):
216 def init(self):
218 pass
217 pass
219
218
220 def get_default_value(self):
219 def get_default_value(self):
221 """Create a new instance of the default value."""
220 """Create a new instance of the default value."""
222 return self.default_value
221 return self.default_value
223
222
224 def instance_init(self, obj):
223 def instance_init(self, obj):
225 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
224 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
226
225
227 Some stages of initialization must be delayed until the parent
226 Some stages of initialization must be delayed until the parent
228 :class:`HasTraits` instance has been created. This method is
227 :class:`HasTraits` instance has been created. This method is
229 called in :meth:`HasTraits.__new__` after the instance has been
228 called in :meth:`HasTraits.__new__` after the instance has been
230 created.
229 created.
231
230
232 This method trigger the creation and validation of default values
231 This method trigger the creation and validation of default values
233 and also things like the resolution of str given class names in
232 and also things like the resolution of str given class names in
234 :class:`Type` and :class`Instance`.
233 :class:`Type` and :class`Instance`.
235
234
236 Parameters
235 Parameters
237 ----------
236 ----------
238 obj : :class:`HasTraits` instance
237 obj : :class:`HasTraits` instance
239 The parent :class:`HasTraits` instance that has just been
238 The parent :class:`HasTraits` instance that has just been
240 created.
239 created.
241 """
240 """
242 self.set_default_value(obj)
241 self.set_default_value(obj)
243
242
244 def set_default_value(self, obj):
243 def set_default_value(self, obj):
245 """Set the default value on a per instance basis.
244 """Set the default value on a per instance basis.
246
245
247 This method is called by :meth:`instance_init` to create and
246 This method is called by :meth:`instance_init` to create and
248 validate the default value. The creation and validation of
247 validate the default value. The creation and validation of
249 default values must be delayed until the parent :class:`HasTraits`
248 default values must be delayed until the parent :class:`HasTraits`
250 class has been instantiated.
249 class has been instantiated.
251 """
250 """
252 # Check for a deferred initializer defined in the same class as the
251 # Check for a deferred initializer defined in the same class as the
253 # trait declaration or above.
252 # trait declaration or above.
254 mro = type(obj).mro()
253 mro = type(obj).mro()
255 meth_name = '_%s_default' % self.name
254 meth_name = '_%s_default' % self.name
256 for cls in mro[:mro.index(self.this_class)+1]:
255 for cls in mro[:mro.index(self.this_class)+1]:
257 if meth_name in cls.__dict__:
256 if meth_name in cls.__dict__:
258 break
257 break
259 else:
258 else:
260 # We didn't find one. Do static initialization.
259 # We didn't find one. Do static initialization.
261 dv = self.get_default_value()
260 dv = self.get_default_value()
262 newdv = self._validate(obj, dv)
261 newdv = self._validate(obj, dv)
263 obj._trait_values[self.name] = newdv
262 obj._trait_values[self.name] = newdv
264 return
263 return
265 # Complete the dynamic initialization.
264 # Complete the dynamic initialization.
266 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
265 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
267
266
268 def __get__(self, obj, cls=None):
267 def __get__(self, obj, cls=None):
269 """Get the value of the trait by self.name for the instance.
268 """Get the value of the trait by self.name for the instance.
270
269
271 Default values are instantiated when :meth:`HasTraits.__new__`
270 Default values are instantiated when :meth:`HasTraits.__new__`
272 is called. Thus by the time this method gets called either the
271 is called. Thus by the time this method gets called either the
273 default value or a user defined value (they called :meth:`__set__`)
272 default value or a user defined value (they called :meth:`__set__`)
274 is in the :class:`HasTraits` instance.
273 is in the :class:`HasTraits` instance.
275 """
274 """
276 if obj is None:
275 if obj is None:
277 return self
276 return self
278 else:
277 else:
279 try:
278 try:
280 value = obj._trait_values[self.name]
279 value = obj._trait_values[self.name]
281 except KeyError:
280 except KeyError:
282 # Check for a dynamic initializer.
281 # Check for a dynamic initializer.
283 if self.name in obj._trait_dyn_inits:
282 if self.name in obj._trait_dyn_inits:
284 value = obj._trait_dyn_inits[self.name](obj)
283 value = obj._trait_dyn_inits[self.name](obj)
285 # FIXME: Do we really validate here?
284 # FIXME: Do we really validate here?
286 value = self._validate(obj, value)
285 value = self._validate(obj, value)
287 obj._trait_values[self.name] = value
286 obj._trait_values[self.name] = value
288 return value
287 return value
289 else:
288 else:
290 raise TraitError('Unexpected error in TraitType: '
289 raise TraitError('Unexpected error in TraitType: '
291 'both default value and dynamic initializer are '
290 'both default value and dynamic initializer are '
292 'absent.')
291 'absent.')
293 except Exception:
292 except Exception:
294 # HasTraits should call set_default_value to populate
293 # HasTraits should call set_default_value to populate
295 # this. So this should never be reached.
294 # this. So this should never be reached.
296 raise TraitError('Unexpected error in TraitType: '
295 raise TraitError('Unexpected error in TraitType: '
297 'default value not set properly')
296 'default value not set properly')
298 else:
297 else:
299 return value
298 return value
300
299
301 def __set__(self, obj, value):
300 def __set__(self, obj, value):
302 new_value = self._validate(obj, value)
301 new_value = self._validate(obj, value)
303 old_value = self.__get__(obj)
302 old_value = self.__get__(obj)
304 if old_value != new_value:
303 if old_value != new_value:
305 obj._trait_values[self.name] = new_value
304 obj._trait_values[self.name] = new_value
306 obj._notify_trait(self.name, old_value, new_value)
305 obj._notify_trait(self.name, old_value, new_value)
307
306
308 def _validate(self, obj, value):
307 def _validate(self, obj, value):
309 if hasattr(self, 'validate'):
308 if hasattr(self, 'validate'):
310 return self.validate(obj, value)
309 return self.validate(obj, value)
311 elif hasattr(self, 'is_valid_for'):
310 elif hasattr(self, 'is_valid_for'):
312 valid = self.is_valid_for(value)
311 valid = self.is_valid_for(value)
313 if valid:
312 if valid:
314 return value
313 return value
315 else:
314 else:
316 raise TraitError('invalid value for type: %r' % value)
315 raise TraitError('invalid value for type: %r' % value)
317 elif hasattr(self, 'value_for'):
316 elif hasattr(self, 'value_for'):
318 return self.value_for(value)
317 return self.value_for(value)
319 else:
318 else:
320 return value
319 return value
321
320
322 def info(self):
321 def info(self):
323 return self.info_text
322 return self.info_text
324
323
325 def error(self, obj, value):
324 def error(self, obj, value):
326 if obj is not None:
325 if obj is not None:
327 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
326 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
328 % (self.name, class_of(obj),
327 % (self.name, class_of(obj),
329 self.info(), repr_type(value))
328 self.info(), repr_type(value))
330 else:
329 else:
331 e = "The '%s' trait must be %s, but a value of %r was specified." \
330 e = "The '%s' trait must be %s, but a value of %r was specified." \
332 % (self.name, self.info(), repr_type(value))
331 % (self.name, self.info(), repr_type(value))
333 raise TraitError(e)
332 raise TraitError(e)
334
333
335 def get_metadata(self, key):
334 def get_metadata(self, key):
336 return getattr(self, '_metadata', {}).get(key, None)
335 return getattr(self, '_metadata', {}).get(key, None)
337
336
338 def set_metadata(self, key, value):
337 def set_metadata(self, key, value):
339 getattr(self, '_metadata', {})[key] = value
338 getattr(self, '_metadata', {})[key] = value
340
339
341
340
342 #-----------------------------------------------------------------------------
341 #-----------------------------------------------------------------------------
343 # The HasTraits implementation
342 # The HasTraits implementation
344 #-----------------------------------------------------------------------------
343 #-----------------------------------------------------------------------------
345
344
346
345
347 class MetaHasTraits(type):
346 class MetaHasTraits(type):
348 """A metaclass for HasTraits.
347 """A metaclass for HasTraits.
349
348
350 This metaclass makes sure that any TraitType class attributes are
349 This metaclass makes sure that any TraitType class attributes are
351 instantiated and sets their name attribute.
350 instantiated and sets their name attribute.
352 """
351 """
353
352
354 def __new__(mcls, name, bases, classdict):
353 def __new__(mcls, name, bases, classdict):
355 """Create the HasTraits class.
354 """Create the HasTraits class.
356
355
357 This instantiates all TraitTypes in the class dict and sets their
356 This instantiates all TraitTypes in the class dict and sets their
358 :attr:`name` attribute.
357 :attr:`name` attribute.
359 """
358 """
360 # print "MetaHasTraitlets (mcls, name): ", mcls, name
359 # print "MetaHasTraitlets (mcls, name): ", mcls, name
361 # print "MetaHasTraitlets (bases): ", bases
360 # print "MetaHasTraitlets (bases): ", bases
362 # print "MetaHasTraitlets (classdict): ", classdict
361 # print "MetaHasTraitlets (classdict): ", classdict
363 for k,v in classdict.iteritems():
362 for k,v in classdict.iteritems():
364 if isinstance(v, TraitType):
363 if isinstance(v, TraitType):
365 v.name = k
364 v.name = k
366 elif inspect.isclass(v):
365 elif inspect.isclass(v):
367 if issubclass(v, TraitType):
366 if issubclass(v, TraitType):
368 vinst = v()
367 vinst = v()
369 vinst.name = k
368 vinst.name = k
370 classdict[k] = vinst
369 classdict[k] = vinst
371 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
370 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
372
371
373 def __init__(cls, name, bases, classdict):
372 def __init__(cls, name, bases, classdict):
374 """Finish initializing the HasTraits class.
373 """Finish initializing the HasTraits class.
375
374
376 This sets the :attr:`this_class` attribute of each TraitType in the
375 This sets the :attr:`this_class` attribute of each TraitType in the
377 class dict to the newly created class ``cls``.
376 class dict to the newly created class ``cls``.
378 """
377 """
379 for k, v in classdict.iteritems():
378 for k, v in classdict.iteritems():
380 if isinstance(v, TraitType):
379 if isinstance(v, TraitType):
381 v.this_class = cls
380 v.this_class = cls
382 super(MetaHasTraits, cls).__init__(name, bases, classdict)
381 super(MetaHasTraits, cls).__init__(name, bases, classdict)
383
382
384 class HasTraits(object):
383 class HasTraits(object):
385
384
386 __metaclass__ = MetaHasTraits
385 __metaclass__ = MetaHasTraits
387
386
388 def __new__(cls, **kw):
387 def __new__(cls, **kw):
389 # This is needed because in Python 2.6 object.__new__ only accepts
388 # This is needed because in Python 2.6 object.__new__ only accepts
390 # the cls argument.
389 # the cls argument.
391 new_meth = super(HasTraits, cls).__new__
390 new_meth = super(HasTraits, cls).__new__
392 if new_meth is object.__new__:
391 if new_meth is object.__new__:
393 inst = new_meth(cls)
392 inst = new_meth(cls)
394 else:
393 else:
395 inst = new_meth(cls, **kw)
394 inst = new_meth(cls, **kw)
396 inst._trait_values = {}
395 inst._trait_values = {}
397 inst._trait_notifiers = {}
396 inst._trait_notifiers = {}
398 inst._trait_dyn_inits = {}
397 inst._trait_dyn_inits = {}
399 # Here we tell all the TraitType instances to set their default
398 # Here we tell all the TraitType instances to set their default
400 # values on the instance.
399 # values on the instance.
401 for key in dir(cls):
400 for key in dir(cls):
402 # Some descriptors raise AttributeError like zope.interface's
401 # Some descriptors raise AttributeError like zope.interface's
403 # __provides__ attributes even though they exist. This causes
402 # __provides__ attributes even though they exist. This causes
404 # AttributeErrors even though they are listed in dir(cls).
403 # AttributeErrors even though they are listed in dir(cls).
405 try:
404 try:
406 value = getattr(cls, key)
405 value = getattr(cls, key)
407 except AttributeError:
406 except AttributeError:
408 pass
407 pass
409 else:
408 else:
410 if isinstance(value, TraitType):
409 if isinstance(value, TraitType):
411 value.instance_init(inst)
410 value.instance_init(inst)
412
411
413 return inst
412 return inst
414
413
415 def __init__(self, **kw):
414 def __init__(self, **kw):
416 # Allow trait values to be set using keyword arguments.
415 # Allow trait values to be set using keyword arguments.
417 # We need to use setattr for this to trigger validation and
416 # We need to use setattr for this to trigger validation and
418 # notifications.
417 # notifications.
419 for key, value in kw.iteritems():
418 for key, value in kw.iteritems():
420 setattr(self, key, value)
419 setattr(self, key, value)
421
420
422 def _notify_trait(self, name, old_value, new_value):
421 def _notify_trait(self, name, old_value, new_value):
423
422
424 # First dynamic ones
423 # First dynamic ones
425 callables = self._trait_notifiers.get(name,[])
424 callables = self._trait_notifiers.get(name,[])
426 more_callables = self._trait_notifiers.get('anytrait',[])
425 more_callables = self._trait_notifiers.get('anytrait',[])
427 callables.extend(more_callables)
426 callables.extend(more_callables)
428
427
429 # Now static ones
428 # Now static ones
430 try:
429 try:
431 cb = getattr(self, '_%s_changed' % name)
430 cb = getattr(self, '_%s_changed' % name)
432 except:
431 except:
433 pass
432 pass
434 else:
433 else:
435 callables.append(cb)
434 callables.append(cb)
436
435
437 # Call them all now
436 # Call them all now
438 for c in callables:
437 for c in callables:
439 # Traits catches and logs errors here. I allow them to raise
438 # Traits catches and logs errors here. I allow them to raise
440 if callable(c):
439 if callable(c):
441 argspec = inspect.getargspec(c)
440 argspec = inspect.getargspec(c)
442 nargs = len(argspec[0])
441 nargs = len(argspec[0])
443 # Bound methods have an additional 'self' argument
442 # Bound methods have an additional 'self' argument
444 # I don't know how to treat unbound methods, but they
443 # I don't know how to treat unbound methods, but they
445 # can't really be used for callbacks.
444 # can't really be used for callbacks.
446 if isinstance(c, types.MethodType):
445 if isinstance(c, types.MethodType):
447 offset = -1
446 offset = -1
448 else:
447 else:
449 offset = 0
448 offset = 0
450 if nargs + offset == 0:
449 if nargs + offset == 0:
451 c()
450 c()
452 elif nargs + offset == 1:
451 elif nargs + offset == 1:
453 c(name)
452 c(name)
454 elif nargs + offset == 2:
453 elif nargs + offset == 2:
455 c(name, new_value)
454 c(name, new_value)
456 elif nargs + offset == 3:
455 elif nargs + offset == 3:
457 c(name, old_value, new_value)
456 c(name, old_value, new_value)
458 else:
457 else:
459 raise TraitError('a trait changed callback '
458 raise TraitError('a trait changed callback '
460 'must have 0-3 arguments.')
459 'must have 0-3 arguments.')
461 else:
460 else:
462 raise TraitError('a trait changed callback '
461 raise TraitError('a trait changed callback '
463 'must be callable.')
462 'must be callable.')
464
463
465
464
466 def _add_notifiers(self, handler, name):
465 def _add_notifiers(self, handler, name):
467 if not self._trait_notifiers.has_key(name):
466 if not self._trait_notifiers.has_key(name):
468 nlist = []
467 nlist = []
469 self._trait_notifiers[name] = nlist
468 self._trait_notifiers[name] = nlist
470 else:
469 else:
471 nlist = self._trait_notifiers[name]
470 nlist = self._trait_notifiers[name]
472 if handler not in nlist:
471 if handler not in nlist:
473 nlist.append(handler)
472 nlist.append(handler)
474
473
475 def _remove_notifiers(self, handler, name):
474 def _remove_notifiers(self, handler, name):
476 if self._trait_notifiers.has_key(name):
475 if self._trait_notifiers.has_key(name):
477 nlist = self._trait_notifiers[name]
476 nlist = self._trait_notifiers[name]
478 try:
477 try:
479 index = nlist.index(handler)
478 index = nlist.index(handler)
480 except ValueError:
479 except ValueError:
481 pass
480 pass
482 else:
481 else:
483 del nlist[index]
482 del nlist[index]
484
483
485 def on_trait_change(self, handler, name=None, remove=False):
484 def on_trait_change(self, handler, name=None, remove=False):
486 """Setup a handler to be called when a trait changes.
485 """Setup a handler to be called when a trait changes.
487
486
488 This is used to setup dynamic notifications of trait changes.
487 This is used to setup dynamic notifications of trait changes.
489
488
490 Static handlers can be created by creating methods on a HasTraits
489 Static handlers can be created by creating methods on a HasTraits
491 subclass with the naming convention '_[traitname]_changed'. Thus,
490 subclass with the naming convention '_[traitname]_changed'. Thus,
492 to create static handler for the trait 'a', create the method
491 to create static handler for the trait 'a', create the method
493 _a_changed(self, name, old, new) (fewer arguments can be used, see
492 _a_changed(self, name, old, new) (fewer arguments can be used, see
494 below).
493 below).
495
494
496 Parameters
495 Parameters
497 ----------
496 ----------
498 handler : callable
497 handler : callable
499 A callable that is called when a trait changes. Its
498 A callable that is called when a trait changes. Its
500 signature can be handler(), handler(name), handler(name, new)
499 signature can be handler(), handler(name), handler(name, new)
501 or handler(name, old, new).
500 or handler(name, old, new).
502 name : list, str, None
501 name : list, str, None
503 If None, the handler will apply to all traits. If a list
502 If None, the handler will apply to all traits. If a list
504 of str, handler will apply to all names in the list. If a
503 of str, handler will apply to all names in the list. If a
505 str, the handler will apply just to that name.
504 str, the handler will apply just to that name.
506 remove : bool
505 remove : bool
507 If False (the default), then install the handler. If True
506 If False (the default), then install the handler. If True
508 then unintall it.
507 then unintall it.
509 """
508 """
510 if remove:
509 if remove:
511 names = parse_notifier_name(name)
510 names = parse_notifier_name(name)
512 for n in names:
511 for n in names:
513 self._remove_notifiers(handler, n)
512 self._remove_notifiers(handler, n)
514 else:
513 else:
515 names = parse_notifier_name(name)
514 names = parse_notifier_name(name)
516 for n in names:
515 for n in names:
517 self._add_notifiers(handler, n)
516 self._add_notifiers(handler, n)
518
517
519 @classmethod
518 @classmethod
520 def class_trait_names(cls, **metadata):
519 def class_trait_names(cls, **metadata):
521 """Get a list of all the names of this classes traits.
520 """Get a list of all the names of this classes traits.
522
521
523 This method is just like the :meth:`trait_names` method, but is unbound.
522 This method is just like the :meth:`trait_names` method, but is unbound.
524 """
523 """
525 return cls.class_traits(**metadata).keys()
524 return cls.class_traits(**metadata).keys()
526
525
527 @classmethod
526 @classmethod
528 def class_traits(cls, **metadata):
527 def class_traits(cls, **metadata):
529 """Get a list of all the traits of this class.
528 """Get a list of all the traits of this class.
530
529
531 This method is just like the :meth:`traits` method, but is unbound.
530 This method is just like the :meth:`traits` method, but is unbound.
532
531
533 The TraitTypes returned don't know anything about the values
532 The TraitTypes returned don't know anything about the values
534 that the various HasTrait's instances are holding.
533 that the various HasTrait's instances are holding.
535
534
536 This follows the same algorithm as traits does and does not allow
535 This follows the same algorithm as traits does and does not allow
537 for any simple way of specifying merely that a metadata name
536 for any simple way of specifying merely that a metadata name
538 exists, but has any value. This is because get_metadata returns
537 exists, but has any value. This is because get_metadata returns
539 None if a metadata key doesn't exist.
538 None if a metadata key doesn't exist.
540 """
539 """
541 traits = dict([memb for memb in getmembers(cls) if \
540 traits = dict([memb for memb in getmembers(cls) if \
542 isinstance(memb[1], TraitType)])
541 isinstance(memb[1], TraitType)])
543
542
544 if len(metadata) == 0:
543 if len(metadata) == 0:
545 return traits
544 return traits
546
545
547 for meta_name, meta_eval in metadata.items():
546 for meta_name, meta_eval in metadata.items():
548 if type(meta_eval) is not FunctionType:
547 if type(meta_eval) is not FunctionType:
549 metadata[meta_name] = _SimpleTest(meta_eval)
548 metadata[meta_name] = _SimpleTest(meta_eval)
550
549
551 result = {}
550 result = {}
552 for name, trait in traits.items():
551 for name, trait in traits.items():
553 for meta_name, meta_eval in metadata.items():
552 for meta_name, meta_eval in metadata.items():
554 if not meta_eval(trait.get_metadata(meta_name)):
553 if not meta_eval(trait.get_metadata(meta_name)):
555 break
554 break
556 else:
555 else:
557 result[name] = trait
556 result[name] = trait
558
557
559 return result
558 return result
560
559
561 def trait_names(self, **metadata):
560 def trait_names(self, **metadata):
562 """Get a list of all the names of this classes traits."""
561 """Get a list of all the names of this classes traits."""
563 return self.traits(**metadata).keys()
562 return self.traits(**metadata).keys()
564
563
565 def traits(self, **metadata):
564 def traits(self, **metadata):
566 """Get a list of all the traits of this class.
565 """Get a list of all the traits of this class.
567
566
568 The TraitTypes returned don't know anything about the values
567 The TraitTypes returned don't know anything about the values
569 that the various HasTrait's instances are holding.
568 that the various HasTrait's instances are holding.
570
569
571 This follows the same algorithm as traits does and does not allow
570 This follows the same algorithm as traits does and does not allow
572 for any simple way of specifying merely that a metadata name
571 for any simple way of specifying merely that a metadata name
573 exists, but has any value. This is because get_metadata returns
572 exists, but has any value. This is because get_metadata returns
574 None if a metadata key doesn't exist.
573 None if a metadata key doesn't exist.
575 """
574 """
576 traits = dict([memb for memb in getmembers(self.__class__) if \
575 traits = dict([memb for memb in getmembers(self.__class__) if \
577 isinstance(memb[1], TraitType)])
576 isinstance(memb[1], TraitType)])
578
577
579 if len(metadata) == 0:
578 if len(metadata) == 0:
580 return traits
579 return traits
581
580
582 for meta_name, meta_eval in metadata.items():
581 for meta_name, meta_eval in metadata.items():
583 if type(meta_eval) is not FunctionType:
582 if type(meta_eval) is not FunctionType:
584 metadata[meta_name] = _SimpleTest(meta_eval)
583 metadata[meta_name] = _SimpleTest(meta_eval)
585
584
586 result = {}
585 result = {}
587 for name, trait in traits.items():
586 for name, trait in traits.items():
588 for meta_name, meta_eval in metadata.items():
587 for meta_name, meta_eval in metadata.items():
589 if not meta_eval(trait.get_metadata(meta_name)):
588 if not meta_eval(trait.get_metadata(meta_name)):
590 break
589 break
591 else:
590 else:
592 result[name] = trait
591 result[name] = trait
593
592
594 return result
593 return result
595
594
596 def trait_metadata(self, traitname, key):
595 def trait_metadata(self, traitname, key):
597 """Get metadata values for trait by key."""
596 """Get metadata values for trait by key."""
598 try:
597 try:
599 trait = getattr(self.__class__, traitname)
598 trait = getattr(self.__class__, traitname)
600 except AttributeError:
599 except AttributeError:
601 raise TraitError("Class %s does not have a trait named %s" %
600 raise TraitError("Class %s does not have a trait named %s" %
602 (self.__class__.__name__, traitname))
601 (self.__class__.__name__, traitname))
603 else:
602 else:
604 return trait.get_metadata(key)
603 return trait.get_metadata(key)
605
604
606 #-----------------------------------------------------------------------------
605 #-----------------------------------------------------------------------------
607 # Actual TraitTypes implementations/subclasses
606 # Actual TraitTypes implementations/subclasses
608 #-----------------------------------------------------------------------------
607 #-----------------------------------------------------------------------------
609
608
610 #-----------------------------------------------------------------------------
609 #-----------------------------------------------------------------------------
611 # TraitTypes subclasses for handling classes and instances of classes
610 # TraitTypes subclasses for handling classes and instances of classes
612 #-----------------------------------------------------------------------------
611 #-----------------------------------------------------------------------------
613
612
614
613
615 class ClassBasedTraitType(TraitType):
614 class ClassBasedTraitType(TraitType):
616 """A trait with error reporting for Type, Instance and This."""
615 """A trait with error reporting for Type, Instance and This."""
617
616
618 def error(self, obj, value):
617 def error(self, obj, value):
619 kind = type(value)
618 kind = type(value)
620 if kind is InstanceType:
619 if kind is InstanceType:
621 msg = 'class %s' % value.__class__.__name__
620 msg = 'class %s' % value.__class__.__name__
622 else:
621 else:
623 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
622 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
624
623
625 if obj is not None:
624 if obj is not None:
626 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
625 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
627 % (self.name, class_of(obj),
626 % (self.name, class_of(obj),
628 self.info(), msg)
627 self.info(), msg)
629 else:
628 else:
630 e = "The '%s' trait must be %s, but a value of %r was specified." \
629 e = "The '%s' trait must be %s, but a value of %r was specified." \
631 % (self.name, self.info(), msg)
630 % (self.name, self.info(), msg)
632
631
633 raise TraitError(e)
632 raise TraitError(e)
634
633
635
634
636 class Type(ClassBasedTraitType):
635 class Type(ClassBasedTraitType):
637 """A trait whose value must be a subclass of a specified class."""
636 """A trait whose value must be a subclass of a specified class."""
638
637
639 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
638 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
640 """Construct a Type trait
639 """Construct a Type trait
641
640
642 A Type trait specifies that its values must be subclasses of
641 A Type trait specifies that its values must be subclasses of
643 a particular class.
642 a particular class.
644
643
645 If only ``default_value`` is given, it is used for the ``klass`` as
644 If only ``default_value`` is given, it is used for the ``klass`` as
646 well.
645 well.
647
646
648 Parameters
647 Parameters
649 ----------
648 ----------
650 default_value : class, str or None
649 default_value : class, str or None
651 The default value must be a subclass of klass. If an str,
650 The default value must be a subclass of klass. If an str,
652 the str must be a fully specified class name, like 'foo.bar.Bah'.
651 the str must be a fully specified class name, like 'foo.bar.Bah'.
653 The string is resolved into real class, when the parent
652 The string is resolved into real class, when the parent
654 :class:`HasTraits` class is instantiated.
653 :class:`HasTraits` class is instantiated.
655 klass : class, str, None
654 klass : class, str, None
656 Values of this trait must be a subclass of klass. The klass
655 Values of this trait must be a subclass of klass. The klass
657 may be specified in a string like: 'foo.bar.MyClass'.
656 may be specified in a string like: 'foo.bar.MyClass'.
658 The string is resolved into real class, when the parent
657 The string is resolved into real class, when the parent
659 :class:`HasTraits` class is instantiated.
658 :class:`HasTraits` class is instantiated.
660 allow_none : boolean
659 allow_none : boolean
661 Indicates whether None is allowed as an assignable value. Even if
660 Indicates whether None is allowed as an assignable value. Even if
662 ``False``, the default value may be ``None``.
661 ``False``, the default value may be ``None``.
663 """
662 """
664 if default_value is None:
663 if default_value is None:
665 if klass is None:
664 if klass is None:
666 klass = object
665 klass = object
667 elif klass is None:
666 elif klass is None:
668 klass = default_value
667 klass = default_value
669
668
670 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
669 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
671 raise TraitError("A Type trait must specify a class.")
670 raise TraitError("A Type trait must specify a class.")
672
671
673 self.klass = klass
672 self.klass = klass
674 self._allow_none = allow_none
673 self._allow_none = allow_none
675
674
676 super(Type, self).__init__(default_value, **metadata)
675 super(Type, self).__init__(default_value, **metadata)
677
676
678 def validate(self, obj, value):
677 def validate(self, obj, value):
679 """Validates that the value is a valid object instance."""
678 """Validates that the value is a valid object instance."""
680 try:
679 try:
681 if issubclass(value, self.klass):
680 if issubclass(value, self.klass):
682 return value
681 return value
683 except:
682 except:
684 if (value is None) and (self._allow_none):
683 if (value is None) and (self._allow_none):
685 return value
684 return value
686
685
687 self.error(obj, value)
686 self.error(obj, value)
688
687
689 def info(self):
688 def info(self):
690 """ Returns a description of the trait."""
689 """ Returns a description of the trait."""
691 if isinstance(self.klass, basestring):
690 if isinstance(self.klass, basestring):
692 klass = self.klass
691 klass = self.klass
693 else:
692 else:
694 klass = self.klass.__name__
693 klass = self.klass.__name__
695 result = 'a subclass of ' + klass
694 result = 'a subclass of ' + klass
696 if self._allow_none:
695 if self._allow_none:
697 return result + ' or None'
696 return result + ' or None'
698 return result
697 return result
699
698
700 def instance_init(self, obj):
699 def instance_init(self, obj):
701 self._resolve_classes()
700 self._resolve_classes()
702 super(Type, self).instance_init(obj)
701 super(Type, self).instance_init(obj)
703
702
704 def _resolve_classes(self):
703 def _resolve_classes(self):
705 if isinstance(self.klass, basestring):
704 if isinstance(self.klass, basestring):
706 self.klass = import_item(self.klass)
705 self.klass = import_item(self.klass)
707 if isinstance(self.default_value, basestring):
706 if isinstance(self.default_value, basestring):
708 self.default_value = import_item(self.default_value)
707 self.default_value = import_item(self.default_value)
709
708
710 def get_default_value(self):
709 def get_default_value(self):
711 return self.default_value
710 return self.default_value
712
711
713
712
714 class DefaultValueGenerator(object):
713 class DefaultValueGenerator(object):
715 """A class for generating new default value instances."""
714 """A class for generating new default value instances."""
716
715
717 def __init__(self, *args, **kw):
716 def __init__(self, *args, **kw):
718 self.args = args
717 self.args = args
719 self.kw = kw
718 self.kw = kw
720
719
721 def generate(self, klass):
720 def generate(self, klass):
722 return klass(*self.args, **self.kw)
721 return klass(*self.args, **self.kw)
723
722
724
723
725 class Instance(ClassBasedTraitType):
724 class Instance(ClassBasedTraitType):
726 """A trait whose value must be an instance of a specified class.
725 """A trait whose value must be an instance of a specified class.
727
726
728 The value can also be an instance of a subclass of the specified class.
727 The value can also be an instance of a subclass of the specified class.
729 """
728 """
730
729
731 def __init__(self, klass=None, args=None, kw=None,
730 def __init__(self, klass=None, args=None, kw=None,
732 allow_none=True, **metadata ):
731 allow_none=True, **metadata ):
733 """Construct an Instance trait.
732 """Construct an Instance trait.
734
733
735 This trait allows values that are instances of a particular
734 This trait allows values that are instances of a particular
736 class or its sublclasses. Our implementation is quite different
735 class or its sublclasses. Our implementation is quite different
737 from that of enthough.traits as we don't allow instances to be used
736 from that of enthough.traits as we don't allow instances to be used
738 for klass and we handle the ``args`` and ``kw`` arguments differently.
737 for klass and we handle the ``args`` and ``kw`` arguments differently.
739
738
740 Parameters
739 Parameters
741 ----------
740 ----------
742 klass : class, str
741 klass : class, str
743 The class that forms the basis for the trait. Class names
742 The class that forms the basis for the trait. Class names
744 can also be specified as strings, like 'foo.bar.Bar'.
743 can also be specified as strings, like 'foo.bar.Bar'.
745 args : tuple
744 args : tuple
746 Positional arguments for generating the default value.
745 Positional arguments for generating the default value.
747 kw : dict
746 kw : dict
748 Keyword arguments for generating the default value.
747 Keyword arguments for generating the default value.
749 allow_none : bool
748 allow_none : bool
750 Indicates whether None is allowed as a value.
749 Indicates whether None is allowed as a value.
751
750
752 Default Value
751 Default Value
753 -------------
752 -------------
754 If both ``args`` and ``kw`` are None, then the default value is None.
753 If both ``args`` and ``kw`` are None, then the default value is None.
755 If ``args`` is a tuple and ``kw`` is a dict, then the default is
754 If ``args`` is a tuple and ``kw`` is a dict, then the default is
756 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
755 created as ``klass(*args, **kw)``. If either ``args`` or ``kw`` is
757 not (but not both), None is replace by ``()`` or ``{}``.
756 not (but not both), None is replace by ``()`` or ``{}``.
758 """
757 """
759
758
760 self._allow_none = allow_none
759 self._allow_none = allow_none
761
760
762 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
761 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
763 raise TraitError('The klass argument must be a class'
762 raise TraitError('The klass argument must be a class'
764 ' you gave: %r' % klass)
763 ' you gave: %r' % klass)
765 self.klass = klass
764 self.klass = klass
766
765
767 # self.klass is a class, so handle default_value
766 # self.klass is a class, so handle default_value
768 if args is None and kw is None:
767 if args is None and kw is None:
769 default_value = None
768 default_value = None
770 else:
769 else:
771 if args is None:
770 if args is None:
772 # kw is not None
771 # kw is not None
773 args = ()
772 args = ()
774 elif kw is None:
773 elif kw is None:
775 # args is not None
774 # args is not None
776 kw = {}
775 kw = {}
777
776
778 if not isinstance(kw, dict):
777 if not isinstance(kw, dict):
779 raise TraitError("The 'kw' argument must be a dict or None.")
778 raise TraitError("The 'kw' argument must be a dict or None.")
780 if not isinstance(args, tuple):
779 if not isinstance(args, tuple):
781 raise TraitError("The 'args' argument must be a tuple or None.")
780 raise TraitError("The 'args' argument must be a tuple or None.")
782
781
783 default_value = DefaultValueGenerator(*args, **kw)
782 default_value = DefaultValueGenerator(*args, **kw)
784
783
785 super(Instance, self).__init__(default_value, **metadata)
784 super(Instance, self).__init__(default_value, **metadata)
786
785
787 def validate(self, obj, value):
786 def validate(self, obj, value):
788 if value is None:
787 if value is None:
789 if self._allow_none:
788 if self._allow_none:
790 return value
789 return value
791 self.error(obj, value)
790 self.error(obj, value)
792
791
793 if isinstance(value, self.klass):
792 if isinstance(value, self.klass):
794 return value
793 return value
795 else:
794 else:
796 self.error(obj, value)
795 self.error(obj, value)
797
796
798 def info(self):
797 def info(self):
799 if isinstance(self.klass, basestring):
798 if isinstance(self.klass, basestring):
800 klass = self.klass
799 klass = self.klass
801 else:
800 else:
802 klass = self.klass.__name__
801 klass = self.klass.__name__
803 result = class_of(klass)
802 result = class_of(klass)
804 if self._allow_none:
803 if self._allow_none:
805 return result + ' or None'
804 return result + ' or None'
806
805
807 return result
806 return result
808
807
809 def instance_init(self, obj):
808 def instance_init(self, obj):
810 self._resolve_classes()
809 self._resolve_classes()
811 super(Instance, self).instance_init(obj)
810 super(Instance, self).instance_init(obj)
812
811
813 def _resolve_classes(self):
812 def _resolve_classes(self):
814 if isinstance(self.klass, basestring):
813 if isinstance(self.klass, basestring):
815 self.klass = import_item(self.klass)
814 self.klass = import_item(self.klass)
816
815
817 def get_default_value(self):
816 def get_default_value(self):
818 """Instantiate a default value instance.
817 """Instantiate a default value instance.
819
818
820 This is called when the containing HasTraits classes'
819 This is called when the containing HasTraits classes'
821 :meth:`__new__` method is called to ensure that a unique instance
820 :meth:`__new__` method is called to ensure that a unique instance
822 is created for each HasTraits instance.
821 is created for each HasTraits instance.
823 """
822 """
824 dv = self.default_value
823 dv = self.default_value
825 if isinstance(dv, DefaultValueGenerator):
824 if isinstance(dv, DefaultValueGenerator):
826 return dv.generate(self.klass)
825 return dv.generate(self.klass)
827 else:
826 else:
828 return dv
827 return dv
829
828
830
829
831 class This(ClassBasedTraitType):
830 class This(ClassBasedTraitType):
832 """A trait for instances of the class containing this trait.
831 """A trait for instances of the class containing this trait.
833
832
834 Because how how and when class bodies are executed, the ``This``
833 Because how how and when class bodies are executed, the ``This``
835 trait can only have a default value of None. This, and because we
834 trait can only have a default value of None. This, and because we
836 always validate default values, ``allow_none`` is *always* true.
835 always validate default values, ``allow_none`` is *always* true.
837 """
836 """
838
837
839 info_text = 'an instance of the same type as the receiver or None'
838 info_text = 'an instance of the same type as the receiver or None'
840
839
841 def __init__(self, **metadata):
840 def __init__(self, **metadata):
842 super(This, self).__init__(None, **metadata)
841 super(This, self).__init__(None, **metadata)
843
842
844 def validate(self, obj, value):
843 def validate(self, obj, value):
845 # What if value is a superclass of obj.__class__? This is
844 # What if value is a superclass of obj.__class__? This is
846 # complicated if it was the superclass that defined the This
845 # complicated if it was the superclass that defined the This
847 # trait.
846 # trait.
848 if isinstance(value, self.this_class) or (value is None):
847 if isinstance(value, self.this_class) or (value is None):
849 return value
848 return value
850 else:
849 else:
851 self.error(obj, value)
850 self.error(obj, value)
852
851
853
852
854 #-----------------------------------------------------------------------------
853 #-----------------------------------------------------------------------------
855 # Basic TraitTypes implementations/subclasses
854 # Basic TraitTypes implementations/subclasses
856 #-----------------------------------------------------------------------------
855 #-----------------------------------------------------------------------------
857
856
858
857
859 class Any(TraitType):
858 class Any(TraitType):
860 default_value = None
859 default_value = None
861 info_text = 'any value'
860 info_text = 'any value'
862
861
863
862
864 class Int(TraitType):
863 class Int(TraitType):
865 """A integer trait."""
864 """A integer trait."""
866
865
867 default_value = 0
866 default_value = 0
868 info_text = 'an integer'
867 info_text = 'an integer'
869
868
870 def validate(self, obj, value):
869 def validate(self, obj, value):
871 if isinstance(value, int):
870 if isinstance(value, int):
872 return value
871 return value
873 self.error(obj, value)
872 self.error(obj, value)
874
873
875 class CInt(Int):
874 class CInt(Int):
876 """A casting version of the int trait."""
875 """A casting version of the int trait."""
877
876
878 def validate(self, obj, value):
877 def validate(self, obj, value):
879 try:
878 try:
880 return int(value)
879 return int(value)
881 except:
880 except:
882 self.error(obj, value)
881 self.error(obj, value)
883
882
884
883
885 class Long(TraitType):
884 class Long(TraitType):
886 """A long integer trait."""
885 """A long integer trait."""
887
886
888 default_value = 0L
887 default_value = 0L
889 info_text = 'a long'
888 info_text = 'a long'
890
889
891 def validate(self, obj, value):
890 def validate(self, obj, value):
892 if isinstance(value, long):
891 if isinstance(value, long):
893 return value
892 return value
894 if isinstance(value, int):
893 if isinstance(value, int):
895 return long(value)
894 return long(value)
896 self.error(obj, value)
895 self.error(obj, value)
897
896
898
897
899 class CLong(Long):
898 class CLong(Long):
900 """A casting version of the long integer trait."""
899 """A casting version of the long integer trait."""
901
900
902 def validate(self, obj, value):
901 def validate(self, obj, value):
903 try:
902 try:
904 return long(value)
903 return long(value)
905 except:
904 except:
906 self.error(obj, value)
905 self.error(obj, value)
907
906
908
907
909 class Float(TraitType):
908 class Float(TraitType):
910 """A float trait."""
909 """A float trait."""
911
910
912 default_value = 0.0
911 default_value = 0.0
913 info_text = 'a float'
912 info_text = 'a float'
914
913
915 def validate(self, obj, value):
914 def validate(self, obj, value):
916 if isinstance(value, float):
915 if isinstance(value, float):
917 return value
916 return value
918 if isinstance(value, int):
917 if isinstance(value, int):
919 return float(value)
918 return float(value)
920 self.error(obj, value)
919 self.error(obj, value)
921
920
922
921
923 class CFloat(Float):
922 class CFloat(Float):
924 """A casting version of the float trait."""
923 """A casting version of the float trait."""
925
924
926 def validate(self, obj, value):
925 def validate(self, obj, value):
927 try:
926 try:
928 return float(value)
927 return float(value)
929 except:
928 except:
930 self.error(obj, value)
929 self.error(obj, value)
931
930
932 class Complex(TraitType):
931 class Complex(TraitType):
933 """A trait for complex numbers."""
932 """A trait for complex numbers."""
934
933
935 default_value = 0.0 + 0.0j
934 default_value = 0.0 + 0.0j
936 info_text = 'a complex number'
935 info_text = 'a complex number'
937
936
938 def validate(self, obj, value):
937 def validate(self, obj, value):
939 if isinstance(value, complex):
938 if isinstance(value, complex):
940 return value
939 return value
941 if isinstance(value, (float, int)):
940 if isinstance(value, (float, int)):
942 return complex(value)
941 return complex(value)
943 self.error(obj, value)
942 self.error(obj, value)
944
943
945
944
946 class CComplex(Complex):
945 class CComplex(Complex):
947 """A casting version of the complex number trait."""
946 """A casting version of the complex number trait."""
948
947
949 def validate (self, obj, value):
948 def validate (self, obj, value):
950 try:
949 try:
951 return complex(value)
950 return complex(value)
952 except:
951 except:
953 self.error(obj, value)
952 self.error(obj, value)
954
953
955 # We should always be explicit about whether we're using bytes or unicode, both
954 # We should always be explicit about whether we're using bytes or unicode, both
956 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
955 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
957 # we don't have a Str type.
956 # we don't have a Str type.
958 class Bytes(TraitType):
957 class Bytes(TraitType):
959 """A trait for strings."""
958 """A trait for strings."""
960
959
961 default_value = ''
960 default_value = ''
962 info_text = 'a string'
961 info_text = 'a string'
963
962
964 def validate(self, obj, value):
963 def validate(self, obj, value):
965 if isinstance(value, bytes):
964 if isinstance(value, bytes):
966 return value
965 return value
967 self.error(obj, value)
966 self.error(obj, value)
968
967
969
968
970 class CBytes(Bytes):
969 class CBytes(Bytes):
971 """A casting version of the string trait."""
970 """A casting version of the string trait."""
972
971
973 def validate(self, obj, value):
972 def validate(self, obj, value):
974 try:
973 try:
975 return bytes(value)
974 return bytes(value)
976 except:
975 except:
977 self.error(obj, value)
976 self.error(obj, value)
978
977
979
978
980 class Unicode(TraitType):
979 class Unicode(TraitType):
981 """A trait for unicode strings."""
980 """A trait for unicode strings."""
982
981
983 default_value = u''
982 default_value = u''
984 info_text = 'a unicode string'
983 info_text = 'a unicode string'
985
984
986 def validate(self, obj, value):
985 def validate(self, obj, value):
987 if isinstance(value, unicode):
986 if isinstance(value, unicode):
988 return value
987 return value
989 if isinstance(value, bytes):
988 if isinstance(value, bytes):
990 return unicode(value)
989 return unicode(value)
991 self.error(obj, value)
990 self.error(obj, value)
992
991
993
992
994 class CUnicode(Unicode):
993 class CUnicode(Unicode):
995 """A casting version of the unicode trait."""
994 """A casting version of the unicode trait."""
996
995
997 def validate(self, obj, value):
996 def validate(self, obj, value):
998 try:
997 try:
999 return unicode(value)
998 return unicode(value)
1000 except:
999 except:
1001 self.error(obj, value)
1000 self.error(obj, value)
1002
1001
1003
1002
1004 class ObjectName(TraitType):
1003 class ObjectName(TraitType):
1005 """A string holding a valid object name in this version of Python.
1004 """A string holding a valid object name in this version of Python.
1006
1005
1007 This does not check that the name exists in any scope."""
1006 This does not check that the name exists in any scope."""
1008 info_text = "a valid object identifier in Python"
1007 info_text = "a valid object identifier in Python"
1009
1008
1010 if sys.version_info[0] < 3:
1009 if sys.version_info[0] < 3:
1011 # Python 2:
1010 # Python 2:
1012 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
1011 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
1013 def isidentifier(self, s):
1012 def isidentifier(self, s):
1014 return bool(self._name_re.match(s))
1013 return bool(self._name_re.match(s))
1015
1014
1016 def coerce_str(self, obj, value):
1015 def coerce_str(self, obj, value):
1017 "In Python 2, coerce ascii-only unicode to str"
1016 "In Python 2, coerce ascii-only unicode to str"
1018 if isinstance(value, unicode):
1017 if isinstance(value, unicode):
1019 try:
1018 try:
1020 return str(value)
1019 return str(value)
1021 except UnicodeEncodeError:
1020 except UnicodeEncodeError:
1022 self.error(obj, value)
1021 self.error(obj, value)
1023 return value
1022 return value
1024
1023
1025 else:
1024 else:
1026 # Python 3:
1025 # Python 3:
1027 isidentifier = staticmethod(lambda s: s.isidentifier())
1026 isidentifier = staticmethod(lambda s: s.isidentifier())
1028 coerce_str = staticmethod(lambda _,s: s)
1027 coerce_str = staticmethod(lambda _,s: s)
1029
1028
1030 def validate(self, obj, value):
1029 def validate(self, obj, value):
1031 value = self.coerce_str(obj, value)
1030 value = self.coerce_str(obj, value)
1032
1031
1033 if isinstance(value, str) and self.isidentifier(value):
1032 if isinstance(value, str) and self.isidentifier(value):
1034 return value
1033 return value
1035 self.error(obj, value)
1034 self.error(obj, value)
1036
1035
1037 class DottedObjectName(ObjectName):
1036 class DottedObjectName(ObjectName):
1038 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1037 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1039 def validate(self, obj, value):
1038 def validate(self, obj, value):
1040 value = self.coerce_str(obj, value)
1039 value = self.coerce_str(obj, value)
1041
1040
1042 if isinstance(value, str) and all(self.isidentifier(x) \
1041 if isinstance(value, str) and all(self.isidentifier(x) \
1043 for x in value.split('.')):
1042 for x in value.split('.')):
1044 return value
1043 return value
1045 self.error(obj, value)
1044 self.error(obj, value)
1046
1045
1047
1046
1048 class Bool(TraitType):
1047 class Bool(TraitType):
1049 """A boolean (True, False) trait."""
1048 """A boolean (True, False) trait."""
1050
1049
1051 default_value = False
1050 default_value = False
1052 info_text = 'a boolean'
1051 info_text = 'a boolean'
1053
1052
1054 def validate(self, obj, value):
1053 def validate(self, obj, value):
1055 if isinstance(value, bool):
1054 if isinstance(value, bool):
1056 return value
1055 return value
1057 self.error(obj, value)
1056 self.error(obj, value)
1058
1057
1059
1058
1060 class CBool(Bool):
1059 class CBool(Bool):
1061 """A casting version of the boolean trait."""
1060 """A casting version of the boolean trait."""
1062
1061
1063 def validate(self, obj, value):
1062 def validate(self, obj, value):
1064 try:
1063 try:
1065 return bool(value)
1064 return bool(value)
1066 except:
1065 except:
1067 self.error(obj, value)
1066 self.error(obj, value)
1068
1067
1069
1068
1070 class Enum(TraitType):
1069 class Enum(TraitType):
1071 """An enum that whose value must be in a given sequence."""
1070 """An enum that whose value must be in a given sequence."""
1072
1071
1073 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1072 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1074 self.values = values
1073 self.values = values
1075 self._allow_none = allow_none
1074 self._allow_none = allow_none
1076 super(Enum, self).__init__(default_value, **metadata)
1075 super(Enum, self).__init__(default_value, **metadata)
1077
1076
1078 def validate(self, obj, value):
1077 def validate(self, obj, value):
1079 if value is None:
1078 if value is None:
1080 if self._allow_none:
1079 if self._allow_none:
1081 return value
1080 return value
1082
1081
1083 if value in self.values:
1082 if value in self.values:
1084 return value
1083 return value
1085 self.error(obj, value)
1084 self.error(obj, value)
1086
1085
1087 def info(self):
1086 def info(self):
1088 """ Returns a description of the trait."""
1087 """ Returns a description of the trait."""
1089 result = 'any of ' + repr(self.values)
1088 result = 'any of ' + repr(self.values)
1090 if self._allow_none:
1089 if self._allow_none:
1091 return result + ' or None'
1090 return result + ' or None'
1092 return result
1091 return result
1093
1092
1094 class CaselessStrEnum(Enum):
1093 class CaselessStrEnum(Enum):
1095 """An enum of strings that are caseless in validate."""
1094 """An enum of strings that are caseless in validate."""
1096
1095
1097 def validate(self, obj, value):
1096 def validate(self, obj, value):
1098 if value is None:
1097 if value is None:
1099 if self._allow_none:
1098 if self._allow_none:
1100 return value
1099 return value
1101
1100
1102 if not isinstance(value, basestring):
1101 if not isinstance(value, basestring):
1103 self.error(obj, value)
1102 self.error(obj, value)
1104
1103
1105 for v in self.values:
1104 for v in self.values:
1106 if v.lower() == value.lower():
1105 if v.lower() == value.lower():
1107 return v
1106 return v
1108 self.error(obj, value)
1107 self.error(obj, value)
1109
1108
1110 class Container(Instance):
1109 class Container(Instance):
1111 """An instance of a container (list, set, etc.)
1110 """An instance of a container (list, set, etc.)
1112
1111
1113 To be subclassed by overriding klass.
1112 To be subclassed by overriding klass.
1114 """
1113 """
1115 klass = None
1114 klass = None
1116 _valid_defaults = SequenceTypes
1115 _valid_defaults = SequenceTypes
1117 _trait = None
1116 _trait = None
1118
1117
1119 def __init__(self, trait=None, default_value=None, allow_none=True,
1118 def __init__(self, trait=None, default_value=None, allow_none=True,
1120 **metadata):
1119 **metadata):
1121 """Create a container trait type from a list, set, or tuple.
1120 """Create a container trait type from a list, set, or tuple.
1122
1121
1123 The default value is created by doing ``List(default_value)``,
1122 The default value is created by doing ``List(default_value)``,
1124 which creates a copy of the ``default_value``.
1123 which creates a copy of the ``default_value``.
1125
1124
1126 ``trait`` can be specified, which restricts the type of elements
1125 ``trait`` can be specified, which restricts the type of elements
1127 in the container to that TraitType.
1126 in the container to that TraitType.
1128
1127
1129 If only one arg is given and it is not a Trait, it is taken as
1128 If only one arg is given and it is not a Trait, it is taken as
1130 ``default_value``:
1129 ``default_value``:
1131
1130
1132 ``c = List([1,2,3])``
1131 ``c = List([1,2,3])``
1133
1132
1134 Parameters
1133 Parameters
1135 ----------
1134 ----------
1136
1135
1137 trait : TraitType [ optional ]
1136 trait : TraitType [ optional ]
1138 the type for restricting the contents of the Container. If unspecified,
1137 the type for restricting the contents of the Container. If unspecified,
1139 types are not checked.
1138 types are not checked.
1140
1139
1141 default_value : SequenceType [ optional ]
1140 default_value : SequenceType [ optional ]
1142 The default value for the Trait. Must be list/tuple/set, and
1141 The default value for the Trait. Must be list/tuple/set, and
1143 will be cast to the container type.
1142 will be cast to the container type.
1144
1143
1145 allow_none : Bool [ default True ]
1144 allow_none : Bool [ default True ]
1146 Whether to allow the value to be None
1145 Whether to allow the value to be None
1147
1146
1148 **metadata : any
1147 **metadata : any
1149 further keys for extensions to the Trait (e.g. config)
1148 further keys for extensions to the Trait (e.g. config)
1150
1149
1151 """
1150 """
1152 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1151 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1153
1152
1154 # allow List([values]):
1153 # allow List([values]):
1155 if default_value is None and not istrait(trait):
1154 if default_value is None and not istrait(trait):
1156 default_value = trait
1155 default_value = trait
1157 trait = None
1156 trait = None
1158
1157
1159 if default_value is None:
1158 if default_value is None:
1160 args = ()
1159 args = ()
1161 elif isinstance(default_value, self._valid_defaults):
1160 elif isinstance(default_value, self._valid_defaults):
1162 args = (default_value,)
1161 args = (default_value,)
1163 else:
1162 else:
1164 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1163 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1165
1164
1166 if istrait(trait):
1165 if istrait(trait):
1167 self._trait = trait()
1166 self._trait = trait()
1168 self._trait.name = 'element'
1167 self._trait.name = 'element'
1169 elif trait is not None:
1168 elif trait is not None:
1170 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1169 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1171
1170
1172 super(Container,self).__init__(klass=self.klass, args=args,
1171 super(Container,self).__init__(klass=self.klass, args=args,
1173 allow_none=allow_none, **metadata)
1172 allow_none=allow_none, **metadata)
1174
1173
1175 def element_error(self, obj, element, validator):
1174 def element_error(self, obj, element, validator):
1176 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1175 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1177 % (self.name, class_of(obj), validator.info(), repr_type(element))
1176 % (self.name, class_of(obj), validator.info(), repr_type(element))
1178 raise TraitError(e)
1177 raise TraitError(e)
1179
1178
1180 def validate(self, obj, value):
1179 def validate(self, obj, value):
1181 value = super(Container, self).validate(obj, value)
1180 value = super(Container, self).validate(obj, value)
1182 if value is None:
1181 if value is None:
1183 return value
1182 return value
1184
1183
1185 value = self.validate_elements(obj, value)
1184 value = self.validate_elements(obj, value)
1186
1185
1187 return value
1186 return value
1188
1187
1189 def validate_elements(self, obj, value):
1188 def validate_elements(self, obj, value):
1190 validated = []
1189 validated = []
1191 if self._trait is None or isinstance(self._trait, Any):
1190 if self._trait is None or isinstance(self._trait, Any):
1192 return value
1191 return value
1193 for v in value:
1192 for v in value:
1194 try:
1193 try:
1195 v = self._trait.validate(obj, v)
1194 v = self._trait.validate(obj, v)
1196 except TraitError:
1195 except TraitError:
1197 self.element_error(obj, v, self._trait)
1196 self.element_error(obj, v, self._trait)
1198 else:
1197 else:
1199 validated.append(v)
1198 validated.append(v)
1200 return self.klass(validated)
1199 return self.klass(validated)
1201
1200
1202
1201
1203 class List(Container):
1202 class List(Container):
1204 """An instance of a Python list."""
1203 """An instance of a Python list."""
1205 klass = list
1204 klass = list
1206
1205
1207 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1206 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1208 allow_none=True, **metadata):
1207 allow_none=True, **metadata):
1209 """Create a List trait type from a list, set, or tuple.
1208 """Create a List trait type from a list, set, or tuple.
1210
1209
1211 The default value is created by doing ``List(default_value)``,
1210 The default value is created by doing ``List(default_value)``,
1212 which creates a copy of the ``default_value``.
1211 which creates a copy of the ``default_value``.
1213
1212
1214 ``trait`` can be specified, which restricts the type of elements
1213 ``trait`` can be specified, which restricts the type of elements
1215 in the container to that TraitType.
1214 in the container to that TraitType.
1216
1215
1217 If only one arg is given and it is not a Trait, it is taken as
1216 If only one arg is given and it is not a Trait, it is taken as
1218 ``default_value``:
1217 ``default_value``:
1219
1218
1220 ``c = List([1,2,3])``
1219 ``c = List([1,2,3])``
1221
1220
1222 Parameters
1221 Parameters
1223 ----------
1222 ----------
1224
1223
1225 trait : TraitType [ optional ]
1224 trait : TraitType [ optional ]
1226 the type for restricting the contents of the Container. If unspecified,
1225 the type for restricting the contents of the Container. If unspecified,
1227 types are not checked.
1226 types are not checked.
1228
1227
1229 default_value : SequenceType [ optional ]
1228 default_value : SequenceType [ optional ]
1230 The default value for the Trait. Must be list/tuple/set, and
1229 The default value for the Trait. Must be list/tuple/set, and
1231 will be cast to the container type.
1230 will be cast to the container type.
1232
1231
1233 minlen : Int [ default 0 ]
1232 minlen : Int [ default 0 ]
1234 The minimum length of the input list
1233 The minimum length of the input list
1235
1234
1236 maxlen : Int [ default sys.maxint ]
1235 maxlen : Int [ default sys.maxint ]
1237 The maximum length of the input list
1236 The maximum length of the input list
1238
1237
1239 allow_none : Bool [ default True ]
1238 allow_none : Bool [ default True ]
1240 Whether to allow the value to be None
1239 Whether to allow the value to be None
1241
1240
1242 **metadata : any
1241 **metadata : any
1243 further keys for extensions to the Trait (e.g. config)
1242 further keys for extensions to the Trait (e.g. config)
1244
1243
1245 """
1244 """
1246 self._minlen = minlen
1245 self._minlen = minlen
1247 self._maxlen = maxlen
1246 self._maxlen = maxlen
1248 super(List, self).__init__(trait=trait, default_value=default_value,
1247 super(List, self).__init__(trait=trait, default_value=default_value,
1249 allow_none=allow_none, **metadata)
1248 allow_none=allow_none, **metadata)
1250
1249
1251 def length_error(self, obj, value):
1250 def length_error(self, obj, value):
1252 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1251 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1253 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1252 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1254 raise TraitError(e)
1253 raise TraitError(e)
1255
1254
1256 def validate_elements(self, obj, value):
1255 def validate_elements(self, obj, value):
1257 length = len(value)
1256 length = len(value)
1258 if length < self._minlen or length > self._maxlen:
1257 if length < self._minlen or length > self._maxlen:
1259 self.length_error(obj, value)
1258 self.length_error(obj, value)
1260
1259
1261 return super(List, self).validate_elements(obj, value)
1260 return super(List, self).validate_elements(obj, value)
1262
1261
1263
1262
1264 class Set(Container):
1263 class Set(Container):
1265 """An instance of a Python set."""
1264 """An instance of a Python set."""
1266 klass = set
1265 klass = set
1267
1266
1268 class Tuple(Container):
1267 class Tuple(Container):
1269 """An instance of a Python tuple."""
1268 """An instance of a Python tuple."""
1270 klass = tuple
1269 klass = tuple
1271
1270
1272 def __init__(self, *traits, **metadata):
1271 def __init__(self, *traits, **metadata):
1273 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1272 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1274
1273
1275 Create a tuple from a list, set, or tuple.
1274 Create a tuple from a list, set, or tuple.
1276
1275
1277 Create a fixed-type tuple with Traits:
1276 Create a fixed-type tuple with Traits:
1278
1277
1279 ``t = Tuple(Int, Str, CStr)``
1278 ``t = Tuple(Int, Str, CStr)``
1280
1279
1281 would be length 3, with Int,Str,CStr for each element.
1280 would be length 3, with Int,Str,CStr for each element.
1282
1281
1283 If only one arg is given and it is not a Trait, it is taken as
1282 If only one arg is given and it is not a Trait, it is taken as
1284 default_value:
1283 default_value:
1285
1284
1286 ``t = Tuple((1,2,3))``
1285 ``t = Tuple((1,2,3))``
1287
1286
1288 Otherwise, ``default_value`` *must* be specified by keyword.
1287 Otherwise, ``default_value`` *must* be specified by keyword.
1289
1288
1290 Parameters
1289 Parameters
1291 ----------
1290 ----------
1292
1291
1293 *traits : TraitTypes [ optional ]
1292 *traits : TraitTypes [ optional ]
1294 the tsype for restricting the contents of the Tuple. If unspecified,
1293 the tsype for restricting the contents of the Tuple. If unspecified,
1295 types are not checked. If specified, then each positional argument
1294 types are not checked. If specified, then each positional argument
1296 corresponds to an element of the tuple. Tuples defined with traits
1295 corresponds to an element of the tuple. Tuples defined with traits
1297 are of fixed length.
1296 are of fixed length.
1298
1297
1299 default_value : SequenceType [ optional ]
1298 default_value : SequenceType [ optional ]
1300 The default value for the Tuple. Must be list/tuple/set, and
1299 The default value for the Tuple. Must be list/tuple/set, and
1301 will be cast to a tuple. If `traits` are specified, the
1300 will be cast to a tuple. If `traits` are specified, the
1302 `default_value` must conform to the shape and type they specify.
1301 `default_value` must conform to the shape and type they specify.
1303
1302
1304 allow_none : Bool [ default True ]
1303 allow_none : Bool [ default True ]
1305 Whether to allow the value to be None
1304 Whether to allow the value to be None
1306
1305
1307 **metadata : any
1306 **metadata : any
1308 further keys for extensions to the Trait (e.g. config)
1307 further keys for extensions to the Trait (e.g. config)
1309
1308
1310 """
1309 """
1311 default_value = metadata.pop('default_value', None)
1310 default_value = metadata.pop('default_value', None)
1312 allow_none = metadata.pop('allow_none', True)
1311 allow_none = metadata.pop('allow_none', True)
1313
1312
1314 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1313 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1315
1314
1316 # allow Tuple((values,)):
1315 # allow Tuple((values,)):
1317 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1316 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1318 default_value = traits[0]
1317 default_value = traits[0]
1319 traits = ()
1318 traits = ()
1320
1319
1321 if default_value is None:
1320 if default_value is None:
1322 args = ()
1321 args = ()
1323 elif isinstance(default_value, self._valid_defaults):
1322 elif isinstance(default_value, self._valid_defaults):
1324 args = (default_value,)
1323 args = (default_value,)
1325 else:
1324 else:
1326 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1325 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1327
1326
1328 self._traits = []
1327 self._traits = []
1329 for trait in traits:
1328 for trait in traits:
1330 t = trait()
1329 t = trait()
1331 t.name = 'element'
1330 t.name = 'element'
1332 self._traits.append(t)
1331 self._traits.append(t)
1333
1332
1334 if self._traits and default_value is None:
1333 if self._traits and default_value is None:
1335 # don't allow default to be an empty container if length is specified
1334 # don't allow default to be an empty container if length is specified
1336 args = None
1335 args = None
1337 super(Container,self).__init__(klass=self.klass, args=args,
1336 super(Container,self).__init__(klass=self.klass, args=args,
1338 allow_none=allow_none, **metadata)
1337 allow_none=allow_none, **metadata)
1339
1338
1340 def validate_elements(self, obj, value):
1339 def validate_elements(self, obj, value):
1341 if not self._traits:
1340 if not self._traits:
1342 # nothing to validate
1341 # nothing to validate
1343 return value
1342 return value
1344 if len(value) != len(self._traits):
1343 if len(value) != len(self._traits):
1345 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1344 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1346 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1345 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1347 raise TraitError(e)
1346 raise TraitError(e)
1348
1347
1349 validated = []
1348 validated = []
1350 for t,v in zip(self._traits, value):
1349 for t,v in zip(self._traits, value):
1351 try:
1350 try:
1352 v = t.validate(obj, v)
1351 v = t.validate(obj, v)
1353 except TraitError:
1352 except TraitError:
1354 self.element_error(obj, v, t)
1353 self.element_error(obj, v, t)
1355 else:
1354 else:
1356 validated.append(v)
1355 validated.append(v)
1357 return tuple(validated)
1356 return tuple(validated)
1358
1357
1359
1358
1360 class Dict(Instance):
1359 class Dict(Instance):
1361 """An instance of a Python dict."""
1360 """An instance of a Python dict."""
1362
1361
1363 def __init__(self, default_value=None, allow_none=True, **metadata):
1362 def __init__(self, default_value=None, allow_none=True, **metadata):
1364 """Create a dict trait type from a dict.
1363 """Create a dict trait type from a dict.
1365
1364
1366 The default value is created by doing ``dict(default_value)``,
1365 The default value is created by doing ``dict(default_value)``,
1367 which creates a copy of the ``default_value``.
1366 which creates a copy of the ``default_value``.
1368 """
1367 """
1369 if default_value is None:
1368 if default_value is None:
1370 args = ((),)
1369 args = ((),)
1371 elif isinstance(default_value, dict):
1370 elif isinstance(default_value, dict):
1372 args = (default_value,)
1371 args = (default_value,)
1373 elif isinstance(default_value, SequenceTypes):
1372 elif isinstance(default_value, SequenceTypes):
1374 args = (default_value,)
1373 args = (default_value,)
1375 else:
1374 else:
1376 raise TypeError('default value of Dict was %s' % default_value)
1375 raise TypeError('default value of Dict was %s' % default_value)
1377
1376
1378 super(Dict,self).__init__(klass=dict, args=args,
1377 super(Dict,self).__init__(klass=dict, args=args,
1379 allow_none=allow_none, **metadata)
1378 allow_none=allow_none, **metadata)
1380
1379
1381 class TCPAddress(TraitType):
1380 class TCPAddress(TraitType):
1382 """A trait for an (ip, port) tuple.
1381 """A trait for an (ip, port) tuple.
1383
1382
1384 This allows for both IPv4 IP addresses as well as hostnames.
1383 This allows for both IPv4 IP addresses as well as hostnames.
1385 """
1384 """
1386
1385
1387 default_value = ('127.0.0.1', 0)
1386 default_value = ('127.0.0.1', 0)
1388 info_text = 'an (ip, port) tuple'
1387 info_text = 'an (ip, port) tuple'
1389
1388
1390 def validate(self, obj, value):
1389 def validate(self, obj, value):
1391 if isinstance(value, tuple):
1390 if isinstance(value, tuple):
1392 if len(value) == 2:
1391 if len(value) == 2:
1393 if isinstance(value[0], basestring) and isinstance(value[1], int):
1392 if isinstance(value[0], basestring) and isinstance(value[1], int):
1394 port = value[1]
1393 port = value[1]
1395 if port >= 0 and port <= 65535:
1394 if port >= 0 and port <= 65535:
1396 return value
1395 return value
1397 self.error(obj, value)
1396 self.error(obj, value)
1 NO CONTENT: modified file chmod 100644 => 100755
NO CONTENT: modified file chmod 100644 => 100755
@@ -1,227 +1,226 b''
1 #!/usr/bin/env python
2 """An Application for launching a kernel
1 """An Application for launching a kernel
3
2
4 Authors
3 Authors
5 -------
4 -------
6 * MinRK
5 * MinRK
7 """
6 """
8 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
10 #
9 #
11 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING.txt, distributed as part of this software.
11 # the file COPYING.txt, distributed as part of this software.
13 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
14
13
15 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
16 # Imports
15 # Imports
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18
17
19 # Standard library imports.
18 # Standard library imports.
20 import os
19 import os
21 import sys
20 import sys
22
21
23 # System library imports.
22 # System library imports.
24 import zmq
23 import zmq
25
24
26 # IPython imports.
25 # IPython imports.
27 from IPython.core.ultratb import FormattedTB
26 from IPython.core.ultratb import FormattedTB
28 from IPython.core.application import (
27 from IPython.core.application import (
29 BaseIPythonApplication, base_flags, base_aliases
28 BaseIPythonApplication, base_flags, base_aliases
30 )
29 )
31 from IPython.utils import io
30 from IPython.utils import io
32 from IPython.utils.localinterfaces import LOCALHOST
31 from IPython.utils.localinterfaces import LOCALHOST
33 from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Int, Bool,
32 from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Int, Bool,
34 DottedObjectName)
33 DottedObjectName)
35 from IPython.utils.importstring import import_item
34 from IPython.utils.importstring import import_item
36 # local imports
35 # local imports
37 from IPython.zmq.heartbeat import Heartbeat
36 from IPython.zmq.heartbeat import Heartbeat
38 from IPython.zmq.parentpoller import ParentPollerUnix, ParentPollerWindows
37 from IPython.zmq.parentpoller import ParentPollerUnix, ParentPollerWindows
39 from IPython.zmq.session import Session
38 from IPython.zmq.session import Session
40
39
41
40
42 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
43 # Flags and Aliases
42 # Flags and Aliases
44 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
45
44
46 kernel_aliases = dict(base_aliases)
45 kernel_aliases = dict(base_aliases)
47 kernel_aliases.update({
46 kernel_aliases.update({
48 'ip' : 'KernelApp.ip',
47 'ip' : 'KernelApp.ip',
49 'hb' : 'KernelApp.hb_port',
48 'hb' : 'KernelApp.hb_port',
50 'shell' : 'KernelApp.shell_port',
49 'shell' : 'KernelApp.shell_port',
51 'iopub' : 'KernelApp.iopub_port',
50 'iopub' : 'KernelApp.iopub_port',
52 'stdin' : 'KernelApp.stdin_port',
51 'stdin' : 'KernelApp.stdin_port',
53 'parent': 'KernelApp.parent',
52 'parent': 'KernelApp.parent',
54 })
53 })
55 if sys.platform.startswith('win'):
54 if sys.platform.startswith('win'):
56 kernel_aliases['interrupt'] = 'KernelApp.interrupt'
55 kernel_aliases['interrupt'] = 'KernelApp.interrupt'
57
56
58 kernel_flags = dict(base_flags)
57 kernel_flags = dict(base_flags)
59 kernel_flags.update({
58 kernel_flags.update({
60 'no-stdout' : (
59 'no-stdout' : (
61 {'KernelApp' : {'no_stdout' : True}},
60 {'KernelApp' : {'no_stdout' : True}},
62 "redirect stdout to the null device"),
61 "redirect stdout to the null device"),
63 'no-stderr' : (
62 'no-stderr' : (
64 {'KernelApp' : {'no_stderr' : True}},
63 {'KernelApp' : {'no_stderr' : True}},
65 "redirect stderr to the null device"),
64 "redirect stderr to the null device"),
66 })
65 })
67
66
68
67
69 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
70 # Application class for starting a Kernel
69 # Application class for starting a Kernel
71 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
72
71
73 class KernelApp(BaseIPythonApplication):
72 class KernelApp(BaseIPythonApplication):
74 name='pykernel'
73 name='pykernel'
75 aliases = Dict(kernel_aliases)
74 aliases = Dict(kernel_aliases)
76 flags = Dict(kernel_flags)
75 flags = Dict(kernel_flags)
77 classes = [Session]
76 classes = [Session]
78 # the kernel class, as an importstring
77 # the kernel class, as an importstring
79 kernel_class = DottedObjectName('IPython.zmq.pykernel.Kernel')
78 kernel_class = DottedObjectName('IPython.zmq.pykernel.Kernel')
80 kernel = Any()
79 kernel = Any()
81 poller = Any() # don't restrict this even though current pollers are all Threads
80 poller = Any() # don't restrict this even though current pollers are all Threads
82 heartbeat = Instance(Heartbeat)
81 heartbeat = Instance(Heartbeat)
83 session = Instance('IPython.zmq.session.Session')
82 session = Instance('IPython.zmq.session.Session')
84 ports = Dict()
83 ports = Dict()
85
84
86 # inherit config file name from parent:
85 # inherit config file name from parent:
87 parent_appname = Unicode(config=True)
86 parent_appname = Unicode(config=True)
88 def _parent_appname_changed(self, name, old, new):
87 def _parent_appname_changed(self, name, old, new):
89 if self.config_file_specified:
88 if self.config_file_specified:
90 # it was manually specified, ignore
89 # it was manually specified, ignore
91 return
90 return
92 self.config_file_name = new.replace('-','_') + u'_config.py'
91 self.config_file_name = new.replace('-','_') + u'_config.py'
93 # don't let this count as specifying the config file
92 # don't let this count as specifying the config file
94 self.config_file_specified = False
93 self.config_file_specified = False
95
94
96 # connection info:
95 # connection info:
97 ip = Unicode(LOCALHOST, config=True,
96 ip = Unicode(LOCALHOST, config=True,
98 help="Set the IP or interface on which the kernel will listen.")
97 help="Set the IP or interface on which the kernel will listen.")
99 hb_port = Int(0, config=True, help="set the heartbeat port [default: random]")
98 hb_port = Int(0, config=True, help="set the heartbeat port [default: random]")
100 shell_port = Int(0, config=True, help="set the shell (XREP) port [default: random]")
99 shell_port = Int(0, config=True, help="set the shell (XREP) port [default: random]")
101 iopub_port = Int(0, config=True, help="set the iopub (PUB) port [default: random]")
100 iopub_port = Int(0, config=True, help="set the iopub (PUB) port [default: random]")
102 stdin_port = Int(0, config=True, help="set the stdin (XREQ) port [default: random]")
101 stdin_port = Int(0, config=True, help="set the stdin (XREQ) port [default: random]")
103
102
104 # streams, etc.
103 # streams, etc.
105 no_stdout = Bool(False, config=True, help="redirect stdout to the null device")
104 no_stdout = Bool(False, config=True, help="redirect stdout to the null device")
106 no_stderr = Bool(False, config=True, help="redirect stderr to the null device")
105 no_stderr = Bool(False, config=True, help="redirect stderr to the null device")
107 outstream_class = DottedObjectName('IPython.zmq.iostream.OutStream',
106 outstream_class = DottedObjectName('IPython.zmq.iostream.OutStream',
108 config=True, help="The importstring for the OutStream factory")
107 config=True, help="The importstring for the OutStream factory")
109 displayhook_class = DottedObjectName('IPython.zmq.displayhook.ZMQDisplayHook',
108 displayhook_class = DottedObjectName('IPython.zmq.displayhook.ZMQDisplayHook',
110 config=True, help="The importstring for the DisplayHook factory")
109 config=True, help="The importstring for the DisplayHook factory")
111
110
112 # polling
111 # polling
113 parent = Int(0, config=True,
112 parent = Int(0, config=True,
114 help="""kill this process if its parent dies. On Windows, the argument
113 help="""kill this process if its parent dies. On Windows, the argument
115 specifies the HANDLE of the parent process, otherwise it is simply boolean.
114 specifies the HANDLE of the parent process, otherwise it is simply boolean.
116 """)
115 """)
117 interrupt = Int(0, config=True,
116 interrupt = Int(0, config=True,
118 help="""ONLY USED ON WINDOWS
117 help="""ONLY USED ON WINDOWS
119 Interrupt this process when the parent is signalled.
118 Interrupt this process when the parent is signalled.
120 """)
119 """)
121
120
122 def init_crash_handler(self):
121 def init_crash_handler(self):
123 # Install minimal exception handling
122 # Install minimal exception handling
124 sys.excepthook = FormattedTB(mode='Verbose', color_scheme='NoColor',
123 sys.excepthook = FormattedTB(mode='Verbose', color_scheme='NoColor',
125 ostream=sys.__stdout__)
124 ostream=sys.__stdout__)
126
125
127 def init_poller(self):
126 def init_poller(self):
128 if sys.platform == 'win32':
127 if sys.platform == 'win32':
129 if self.interrupt or self.parent:
128 if self.interrupt or self.parent:
130 self.poller = ParentPollerWindows(self.interrupt, self.parent)
129 self.poller = ParentPollerWindows(self.interrupt, self.parent)
131 elif self.parent:
130 elif self.parent:
132 self.poller = ParentPollerUnix()
131 self.poller = ParentPollerUnix()
133
132
134 def _bind_socket(self, s, port):
133 def _bind_socket(self, s, port):
135 iface = 'tcp://%s' % self.ip
134 iface = 'tcp://%s' % self.ip
136 if port <= 0:
135 if port <= 0:
137 port = s.bind_to_random_port(iface)
136 port = s.bind_to_random_port(iface)
138 else:
137 else:
139 s.bind(iface + ':%i'%port)
138 s.bind(iface + ':%i'%port)
140 return port
139 return port
141
140
142 def init_sockets(self):
141 def init_sockets(self):
143 # Create a context, a session, and the kernel sockets.
142 # Create a context, a session, and the kernel sockets.
144 self.log.info("Starting the kernel at pid: %i", os.getpid())
143 self.log.info("Starting the kernel at pid: %i", os.getpid())
145 context = zmq.Context.instance()
144 context = zmq.Context.instance()
146 # Uncomment this to try closing the context.
145 # Uncomment this to try closing the context.
147 # atexit.register(context.term)
146 # atexit.register(context.term)
148
147
149 self.shell_socket = context.socket(zmq.XREP)
148 self.shell_socket = context.socket(zmq.XREP)
150 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
149 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
151 self.log.debug("shell XREP Channel on port: %i"%self.shell_port)
150 self.log.debug("shell XREP Channel on port: %i"%self.shell_port)
152
151
153 self.iopub_socket = context.socket(zmq.PUB)
152 self.iopub_socket = context.socket(zmq.PUB)
154 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
153 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
155 self.log.debug("iopub PUB Channel on port: %i"%self.iopub_port)
154 self.log.debug("iopub PUB Channel on port: %i"%self.iopub_port)
156
155
157 self.stdin_socket = context.socket(zmq.XREQ)
156 self.stdin_socket = context.socket(zmq.XREQ)
158 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
157 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
159 self.log.debug("stdin XREQ Channel on port: %i"%self.stdin_port)
158 self.log.debug("stdin XREQ Channel on port: %i"%self.stdin_port)
160
159
161 self.heartbeat = Heartbeat(context, (self.ip, self.hb_port))
160 self.heartbeat = Heartbeat(context, (self.ip, self.hb_port))
162 self.hb_port = self.heartbeat.port
161 self.hb_port = self.heartbeat.port
163 self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
162 self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
164
163
165 # Helper to make it easier to connect to an existing kernel, until we have
164 # Helper to make it easier to connect to an existing kernel, until we have
166 # single-port connection negotiation fully implemented.
165 # single-port connection negotiation fully implemented.
167 # set log-level to critical, to make sure it is output
166 # set log-level to critical, to make sure it is output
168 self.log.critical("To connect another client to this kernel, use:")
167 self.log.critical("To connect another client to this kernel, use:")
169 self.log.critical("--existing --shell={0} --iopub={1} --stdin={2} --hb={3}".format(
168 self.log.critical("--existing --shell={0} --iopub={1} --stdin={2} --hb={3}".format(
170 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port))
169 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port))
171
170
172
171
173 self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
172 self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
174 stdin=self.stdin_port, hb=self.hb_port)
173 stdin=self.stdin_port, hb=self.hb_port)
175
174
176 def init_session(self):
175 def init_session(self):
177 """create our session object"""
176 """create our session object"""
178 self.session = Session(config=self.config, username=u'kernel')
177 self.session = Session(config=self.config, username=u'kernel')
179
178
180 def init_blackhole(self):
179 def init_blackhole(self):
181 """redirects stdout/stderr to devnull if necessary"""
180 """redirects stdout/stderr to devnull if necessary"""
182 if self.no_stdout or self.no_stderr:
181 if self.no_stdout or self.no_stderr:
183 blackhole = file(os.devnull, 'w')
182 blackhole = file(os.devnull, 'w')
184 if self.no_stdout:
183 if self.no_stdout:
185 sys.stdout = sys.__stdout__ = blackhole
184 sys.stdout = sys.__stdout__ = blackhole
186 if self.no_stderr:
185 if self.no_stderr:
187 sys.stderr = sys.__stderr__ = blackhole
186 sys.stderr = sys.__stderr__ = blackhole
188
187
189 def init_io(self):
188 def init_io(self):
190 """Redirect input streams and set a display hook."""
189 """Redirect input streams and set a display hook."""
191 if self.outstream_class:
190 if self.outstream_class:
192 outstream_factory = import_item(str(self.outstream_class))
191 outstream_factory = import_item(str(self.outstream_class))
193 sys.stdout = outstream_factory(self.session, self.iopub_socket, u'stdout')
192 sys.stdout = outstream_factory(self.session, self.iopub_socket, u'stdout')
194 sys.stderr = outstream_factory(self.session, self.iopub_socket, u'stderr')
193 sys.stderr = outstream_factory(self.session, self.iopub_socket, u'stderr')
195 if self.displayhook_class:
194 if self.displayhook_class:
196 displayhook_factory = import_item(str(self.displayhook_class))
195 displayhook_factory = import_item(str(self.displayhook_class))
197 sys.displayhook = displayhook_factory(self.session, self.iopub_socket)
196 sys.displayhook = displayhook_factory(self.session, self.iopub_socket)
198
197
199 def init_kernel(self):
198 def init_kernel(self):
200 """Create the Kernel object itself"""
199 """Create the Kernel object itself"""
201 kernel_factory = import_item(str(self.kernel_class))
200 kernel_factory = import_item(str(self.kernel_class))
202 self.kernel = kernel_factory(config=self.config, session=self.session,
201 self.kernel = kernel_factory(config=self.config, session=self.session,
203 shell_socket=self.shell_socket,
202 shell_socket=self.shell_socket,
204 iopub_socket=self.iopub_socket,
203 iopub_socket=self.iopub_socket,
205 stdin_socket=self.stdin_socket,
204 stdin_socket=self.stdin_socket,
206 log=self.log
205 log=self.log
207 )
206 )
208 self.kernel.record_ports(self.ports)
207 self.kernel.record_ports(self.ports)
209
208
210 def initialize(self, argv=None):
209 def initialize(self, argv=None):
211 super(KernelApp, self).initialize(argv)
210 super(KernelApp, self).initialize(argv)
212 self.init_blackhole()
211 self.init_blackhole()
213 self.init_session()
212 self.init_session()
214 self.init_poller()
213 self.init_poller()
215 self.init_sockets()
214 self.init_sockets()
216 self.init_io()
215 self.init_io()
217 self.init_kernel()
216 self.init_kernel()
218
217
219 def start(self):
218 def start(self):
220 self.heartbeat.start()
219 self.heartbeat.start()
221 if self.poller is not None:
220 if self.poller is not None:
222 self.poller.start()
221 self.poller.start()
223 try:
222 try:
224 self.kernel.start()
223 self.kernel.start()
225 except KeyboardInterrupt:
224 except KeyboardInterrupt:
226 pass
225 pass
227
226
@@ -1,679 +1,678 b''
1 #!/usr/bin/env python
2 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
3 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
4 metadata on messages.
3 metadata on messages.
5
4
6 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
7 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
8 Sessions.
7 Sessions.
9 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
10
9
11 Authors:
10 Authors:
12
11
13 * Min RK
12 * Min RK
14 * Brian Granger
13 * Brian Granger
15 * Fernando Perez
14 * Fernando Perez
16 """
15 """
17 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
18 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
19 #
18 #
20 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
21 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
22 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
23
22
24 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
25 # Imports
24 # Imports
26 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
27
26
28 import hmac
27 import hmac
29 import logging
28 import logging
30 import os
29 import os
31 import pprint
30 import pprint
32 import uuid
31 import uuid
33 from datetime import datetime
32 from datetime import datetime
34
33
35 try:
34 try:
36 import cPickle
35 import cPickle
37 pickle = cPickle
36 pickle = cPickle
38 except:
37 except:
39 cPickle = None
38 cPickle = None
40 import pickle
39 import pickle
41
40
42 import zmq
41 import zmq
43 from zmq.utils import jsonapi
42 from zmq.utils import jsonapi
44 from zmq.eventloop.ioloop import IOLoop
43 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.zmqstream import ZMQStream
44 from zmq.eventloop.zmqstream import ZMQStream
46
45
47 from IPython.config.configurable import Configurable, LoggingConfigurable
46 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 DottedObjectName)
50 DottedObjectName)
52
51
53 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
54 # utility functions
53 # utility functions
55 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
56
55
57 def squash_unicode(obj):
56 def squash_unicode(obj):
58 """coerce unicode back to bytestrings."""
57 """coerce unicode back to bytestrings."""
59 if isinstance(obj,dict):
58 if isinstance(obj,dict):
60 for key in obj.keys():
59 for key in obj.keys():
61 obj[key] = squash_unicode(obj[key])
60 obj[key] = squash_unicode(obj[key])
62 if isinstance(key, unicode):
61 if isinstance(key, unicode):
63 obj[squash_unicode(key)] = obj.pop(key)
62 obj[squash_unicode(key)] = obj.pop(key)
64 elif isinstance(obj, list):
63 elif isinstance(obj, list):
65 for i,v in enumerate(obj):
64 for i,v in enumerate(obj):
66 obj[i] = squash_unicode(v)
65 obj[i] = squash_unicode(v)
67 elif isinstance(obj, unicode):
66 elif isinstance(obj, unicode):
68 obj = obj.encode('utf8')
67 obj = obj.encode('utf8')
69 return obj
68 return obj
70
69
71 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
72 # globals and defaults
71 # globals and defaults
73 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
74 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
73 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
74 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
75 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77
76
78 pickle_packer = lambda o: pickle.dumps(o,-1)
77 pickle_packer = lambda o: pickle.dumps(o,-1)
79 pickle_unpacker = pickle.loads
78 pickle_unpacker = pickle.loads
80
79
81 default_packer = json_packer
80 default_packer = json_packer
82 default_unpacker = json_unpacker
81 default_unpacker = json_unpacker
83
82
84
83
85 DELIM=b"<IDS|MSG>"
84 DELIM=b"<IDS|MSG>"
86
85
87 #-----------------------------------------------------------------------------
86 #-----------------------------------------------------------------------------
88 # Classes
87 # Classes
89 #-----------------------------------------------------------------------------
88 #-----------------------------------------------------------------------------
90
89
91 class SessionFactory(LoggingConfigurable):
90 class SessionFactory(LoggingConfigurable):
92 """The Base class for configurables that have a Session, Context, logger,
91 """The Base class for configurables that have a Session, Context, logger,
93 and IOLoop.
92 and IOLoop.
94 """
93 """
95
94
96 logname = Unicode('')
95 logname = Unicode('')
97 def _logname_changed(self, name, old, new):
96 def _logname_changed(self, name, old, new):
98 self.log = logging.getLogger(new)
97 self.log = logging.getLogger(new)
99
98
100 # not configurable:
99 # not configurable:
101 context = Instance('zmq.Context')
100 context = Instance('zmq.Context')
102 def _context_default(self):
101 def _context_default(self):
103 return zmq.Context.instance()
102 return zmq.Context.instance()
104
103
105 session = Instance('IPython.zmq.session.Session')
104 session = Instance('IPython.zmq.session.Session')
106
105
107 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
106 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 def _loop_default(self):
107 def _loop_default(self):
109 return IOLoop.instance()
108 return IOLoop.instance()
110
109
111 def __init__(self, **kwargs):
110 def __init__(self, **kwargs):
112 super(SessionFactory, self).__init__(**kwargs)
111 super(SessionFactory, self).__init__(**kwargs)
113
112
114 if self.session is None:
113 if self.session is None:
115 # construct the session
114 # construct the session
116 self.session = Session(**kwargs)
115 self.session = Session(**kwargs)
117
116
118
117
119 class Message(object):
118 class Message(object):
120 """A simple message object that maps dict keys to attributes.
119 """A simple message object that maps dict keys to attributes.
121
120
122 A Message can be created from a dict and a dict from a Message instance
121 A Message can be created from a dict and a dict from a Message instance
123 simply by calling dict(msg_obj)."""
122 simply by calling dict(msg_obj)."""
124
123
125 def __init__(self, msg_dict):
124 def __init__(self, msg_dict):
126 dct = self.__dict__
125 dct = self.__dict__
127 for k, v in dict(msg_dict).iteritems():
126 for k, v in dict(msg_dict).iteritems():
128 if isinstance(v, dict):
127 if isinstance(v, dict):
129 v = Message(v)
128 v = Message(v)
130 dct[k] = v
129 dct[k] = v
131
130
132 # Having this iterator lets dict(msg_obj) work out of the box.
131 # Having this iterator lets dict(msg_obj) work out of the box.
133 def __iter__(self):
132 def __iter__(self):
134 return iter(self.__dict__.iteritems())
133 return iter(self.__dict__.iteritems())
135
134
136 def __repr__(self):
135 def __repr__(self):
137 return repr(self.__dict__)
136 return repr(self.__dict__)
138
137
139 def __str__(self):
138 def __str__(self):
140 return pprint.pformat(self.__dict__)
139 return pprint.pformat(self.__dict__)
141
140
142 def __contains__(self, k):
141 def __contains__(self, k):
143 return k in self.__dict__
142 return k in self.__dict__
144
143
145 def __getitem__(self, k):
144 def __getitem__(self, k):
146 return self.__dict__[k]
145 return self.__dict__[k]
147
146
148
147
149 def msg_header(msg_id, msg_type, username, session):
148 def msg_header(msg_id, msg_type, username, session):
150 date = datetime.now()
149 date = datetime.now()
151 return locals()
150 return locals()
152
151
153 def extract_header(msg_or_header):
152 def extract_header(msg_or_header):
154 """Given a message or header, return the header."""
153 """Given a message or header, return the header."""
155 if not msg_or_header:
154 if not msg_or_header:
156 return {}
155 return {}
157 try:
156 try:
158 # See if msg_or_header is the entire message.
157 # See if msg_or_header is the entire message.
159 h = msg_or_header['header']
158 h = msg_or_header['header']
160 except KeyError:
159 except KeyError:
161 try:
160 try:
162 # See if msg_or_header is just the header
161 # See if msg_or_header is just the header
163 h = msg_or_header['msg_id']
162 h = msg_or_header['msg_id']
164 except KeyError:
163 except KeyError:
165 raise
164 raise
166 else:
165 else:
167 h = msg_or_header
166 h = msg_or_header
168 if not isinstance(h, dict):
167 if not isinstance(h, dict):
169 h = dict(h)
168 h = dict(h)
170 return h
169 return h
171
170
172 class Session(Configurable):
171 class Session(Configurable):
173 """Object for handling serialization and sending of messages.
172 """Object for handling serialization and sending of messages.
174
173
175 The Session object handles building messages and sending them
174 The Session object handles building messages and sending them
176 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
175 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 other over the network via Session objects, and only need to work with the
176 other over the network via Session objects, and only need to work with the
178 dict-based IPython message spec. The Session will handle
177 dict-based IPython message spec. The Session will handle
179 serialization/deserialization, security, and metadata.
178 serialization/deserialization, security, and metadata.
180
179
181 Sessions support configurable serialiization via packer/unpacker traits,
180 Sessions support configurable serialiization via packer/unpacker traits,
182 and signing with HMAC digests via the key/keyfile traits.
181 and signing with HMAC digests via the key/keyfile traits.
183
182
184 Parameters
183 Parameters
185 ----------
184 ----------
186
185
187 debug : bool
186 debug : bool
188 whether to trigger extra debugging statements
187 whether to trigger extra debugging statements
189 packer/unpacker : str : 'json', 'pickle' or import_string
188 packer/unpacker : str : 'json', 'pickle' or import_string
190 importstrings for methods to serialize message parts. If just
189 importstrings for methods to serialize message parts. If just
191 'json' or 'pickle', predefined JSON and pickle packers will be used.
190 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 Otherwise, the entire importstring must be used.
191 Otherwise, the entire importstring must be used.
193
192
194 The functions must accept at least valid JSON input, and output *bytes*.
193 The functions must accept at least valid JSON input, and output *bytes*.
195
194
196 For example, to use msgpack:
195 For example, to use msgpack:
197 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
196 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 pack/unpack : callables
197 pack/unpack : callables
199 You can also set the pack/unpack callables for serialization directly.
198 You can also set the pack/unpack callables for serialization directly.
200 session : bytes
199 session : bytes
201 the ID of this Session object. The default is to generate a new UUID.
200 the ID of this Session object. The default is to generate a new UUID.
202 username : unicode
201 username : unicode
203 username added to message headers. The default is to ask the OS.
202 username added to message headers. The default is to ask the OS.
204 key : bytes
203 key : bytes
205 The key used to initialize an HMAC signature. If unset, messages
204 The key used to initialize an HMAC signature. If unset, messages
206 will not be signed or checked.
205 will not be signed or checked.
207 keyfile : filepath
206 keyfile : filepath
208 The file containing a key. If this is set, `key` will be initialized
207 The file containing a key. If this is set, `key` will be initialized
209 to the contents of the file.
208 to the contents of the file.
210
209
211 """
210 """
212
211
213 debug=Bool(False, config=True, help="""Debug output in the Session""")
212 debug=Bool(False, config=True, help="""Debug output in the Session""")
214
213
215 packer = DottedObjectName('json',config=True,
214 packer = DottedObjectName('json',config=True,
216 help="""The name of the packer for serializing messages.
215 help="""The name of the packer for serializing messages.
217 Should be one of 'json', 'pickle', or an import name
216 Should be one of 'json', 'pickle', or an import name
218 for a custom callable serializer.""")
217 for a custom callable serializer.""")
219 def _packer_changed(self, name, old, new):
218 def _packer_changed(self, name, old, new):
220 if new.lower() == 'json':
219 if new.lower() == 'json':
221 self.pack = json_packer
220 self.pack = json_packer
222 self.unpack = json_unpacker
221 self.unpack = json_unpacker
223 elif new.lower() == 'pickle':
222 elif new.lower() == 'pickle':
224 self.pack = pickle_packer
223 self.pack = pickle_packer
225 self.unpack = pickle_unpacker
224 self.unpack = pickle_unpacker
226 else:
225 else:
227 self.pack = import_item(str(new))
226 self.pack = import_item(str(new))
228
227
229 unpacker = DottedObjectName('json', config=True,
228 unpacker = DottedObjectName('json', config=True,
230 help="""The name of the unpacker for unserializing messages.
229 help="""The name of the unpacker for unserializing messages.
231 Only used with custom functions for `packer`.""")
230 Only used with custom functions for `packer`.""")
232 def _unpacker_changed(self, name, old, new):
231 def _unpacker_changed(self, name, old, new):
233 if new.lower() == 'json':
232 if new.lower() == 'json':
234 self.pack = json_packer
233 self.pack = json_packer
235 self.unpack = json_unpacker
234 self.unpack = json_unpacker
236 elif new.lower() == 'pickle':
235 elif new.lower() == 'pickle':
237 self.pack = pickle_packer
236 self.pack = pickle_packer
238 self.unpack = pickle_unpacker
237 self.unpack = pickle_unpacker
239 else:
238 else:
240 self.unpack = import_item(str(new))
239 self.unpack = import_item(str(new))
241
240
242 session = CBytes(b'', config=True,
241 session = CBytes(b'', config=True,
243 help="""The UUID identifying this session.""")
242 help="""The UUID identifying this session.""")
244 def _session_default(self):
243 def _session_default(self):
245 return bytes(uuid.uuid4())
244 return bytes(uuid.uuid4())
246
245
247 username = Unicode(os.environ.get('USER','username'), config=True,
246 username = Unicode(os.environ.get('USER','username'), config=True,
248 help="""Username for the Session. Default is your system username.""")
247 help="""Username for the Session. Default is your system username.""")
249
248
250 # message signature related traits:
249 # message signature related traits:
251 key = CBytes(b'', config=True,
250 key = CBytes(b'', config=True,
252 help="""execution key, for extra authentication.""")
251 help="""execution key, for extra authentication.""")
253 def _key_changed(self, name, old, new):
252 def _key_changed(self, name, old, new):
254 if new:
253 if new:
255 self.auth = hmac.HMAC(new)
254 self.auth = hmac.HMAC(new)
256 else:
255 else:
257 self.auth = None
256 self.auth = None
258 auth = Instance(hmac.HMAC)
257 auth = Instance(hmac.HMAC)
259 digest_history = Set()
258 digest_history = Set()
260
259
261 keyfile = Unicode('', config=True,
260 keyfile = Unicode('', config=True,
262 help="""path to file containing execution key.""")
261 help="""path to file containing execution key.""")
263 def _keyfile_changed(self, name, old, new):
262 def _keyfile_changed(self, name, old, new):
264 with open(new, 'rb') as f:
263 with open(new, 'rb') as f:
265 self.key = f.read().strip()
264 self.key = f.read().strip()
266
265
267 pack = Any(default_packer) # the actual packer function
266 pack = Any(default_packer) # the actual packer function
268 def _pack_changed(self, name, old, new):
267 def _pack_changed(self, name, old, new):
269 if not callable(new):
268 if not callable(new):
270 raise TypeError("packer must be callable, not %s"%type(new))
269 raise TypeError("packer must be callable, not %s"%type(new))
271
270
272 unpack = Any(default_unpacker) # the actual packer function
271 unpack = Any(default_unpacker) # the actual packer function
273 def _unpack_changed(self, name, old, new):
272 def _unpack_changed(self, name, old, new):
274 # unpacker is not checked - it is assumed to be
273 # unpacker is not checked - it is assumed to be
275 if not callable(new):
274 if not callable(new):
276 raise TypeError("unpacker must be callable, not %s"%type(new))
275 raise TypeError("unpacker must be callable, not %s"%type(new))
277
276
278 def __init__(self, **kwargs):
277 def __init__(self, **kwargs):
279 """create a Session object
278 """create a Session object
280
279
281 Parameters
280 Parameters
282 ----------
281 ----------
283
282
284 debug : bool
283 debug : bool
285 whether to trigger extra debugging statements
284 whether to trigger extra debugging statements
286 packer/unpacker : str : 'json', 'pickle' or import_string
285 packer/unpacker : str : 'json', 'pickle' or import_string
287 importstrings for methods to serialize message parts. If just
286 importstrings for methods to serialize message parts. If just
288 'json' or 'pickle', predefined JSON and pickle packers will be used.
287 'json' or 'pickle', predefined JSON and pickle packers will be used.
289 Otherwise, the entire importstring must be used.
288 Otherwise, the entire importstring must be used.
290
289
291 The functions must accept at least valid JSON input, and output
290 The functions must accept at least valid JSON input, and output
292 *bytes*.
291 *bytes*.
293
292
294 For example, to use msgpack:
293 For example, to use msgpack:
295 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
294 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
296 pack/unpack : callables
295 pack/unpack : callables
297 You can also set the pack/unpack callables for serialization
296 You can also set the pack/unpack callables for serialization
298 directly.
297 directly.
299 session : bytes
298 session : bytes
300 the ID of this Session object. The default is to generate a new
299 the ID of this Session object. The default is to generate a new
301 UUID.
300 UUID.
302 username : unicode
301 username : unicode
303 username added to message headers. The default is to ask the OS.
302 username added to message headers. The default is to ask the OS.
304 key : bytes
303 key : bytes
305 The key used to initialize an HMAC signature. If unset, messages
304 The key used to initialize an HMAC signature. If unset, messages
306 will not be signed or checked.
305 will not be signed or checked.
307 keyfile : filepath
306 keyfile : filepath
308 The file containing a key. If this is set, `key` will be
307 The file containing a key. If this is set, `key` will be
309 initialized to the contents of the file.
308 initialized to the contents of the file.
310 """
309 """
311 super(Session, self).__init__(**kwargs)
310 super(Session, self).__init__(**kwargs)
312 self._check_packers()
311 self._check_packers()
313 self.none = self.pack({})
312 self.none = self.pack({})
314
313
315 @property
314 @property
316 def msg_id(self):
315 def msg_id(self):
317 """always return new uuid"""
316 """always return new uuid"""
318 return str(uuid.uuid4())
317 return str(uuid.uuid4())
319
318
320 def _check_packers(self):
319 def _check_packers(self):
321 """check packers for binary data and datetime support."""
320 """check packers for binary data and datetime support."""
322 pack = self.pack
321 pack = self.pack
323 unpack = self.unpack
322 unpack = self.unpack
324
323
325 # check simple serialization
324 # check simple serialization
326 msg = dict(a=[1,'hi'])
325 msg = dict(a=[1,'hi'])
327 try:
326 try:
328 packed = pack(msg)
327 packed = pack(msg)
329 except Exception:
328 except Exception:
330 raise ValueError("packer could not serialize a simple message")
329 raise ValueError("packer could not serialize a simple message")
331
330
332 # ensure packed message is bytes
331 # ensure packed message is bytes
333 if not isinstance(packed, bytes):
332 if not isinstance(packed, bytes):
334 raise ValueError("message packed to %r, but bytes are required"%type(packed))
333 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335
334
336 # check that unpack is pack's inverse
335 # check that unpack is pack's inverse
337 try:
336 try:
338 unpacked = unpack(packed)
337 unpacked = unpack(packed)
339 except Exception:
338 except Exception:
340 raise ValueError("unpacker could not handle the packer's output")
339 raise ValueError("unpacker could not handle the packer's output")
341
340
342 # check datetime support
341 # check datetime support
343 msg = dict(t=datetime.now())
342 msg = dict(t=datetime.now())
344 try:
343 try:
345 unpacked = unpack(pack(msg))
344 unpacked = unpack(pack(msg))
346 except Exception:
345 except Exception:
347 self.pack = lambda o: pack(squash_dates(o))
346 self.pack = lambda o: pack(squash_dates(o))
348 self.unpack = lambda s: extract_dates(unpack(s))
347 self.unpack = lambda s: extract_dates(unpack(s))
349
348
350 def msg_header(self, msg_type):
349 def msg_header(self, msg_type):
351 return msg_header(self.msg_id, msg_type, self.username, self.session)
350 return msg_header(self.msg_id, msg_type, self.username, self.session)
352
351
353 def msg(self, msg_type, content=None, parent=None, subheader=None):
352 def msg(self, msg_type, content=None, parent=None, subheader=None):
354 """Return the nested message dict.
353 """Return the nested message dict.
355
354
356 This format is different from what is sent over the wire. The
355 This format is different from what is sent over the wire. The
357 self.serialize method converts this nested message dict to the wire
356 self.serialize method converts this nested message dict to the wire
358 format, which uses a message list.
357 format, which uses a message list.
359 """
358 """
360 msg = {}
359 msg = {}
361 msg['header'] = self.msg_header(msg_type)
360 msg['header'] = self.msg_header(msg_type)
362 msg['msg_id'] = msg['header']['msg_id']
361 msg['msg_id'] = msg['header']['msg_id']
363 msg['parent_header'] = {} if parent is None else extract_header(parent)
362 msg['parent_header'] = {} if parent is None else extract_header(parent)
364 msg['msg_type'] = msg_type
363 msg['msg_type'] = msg_type
365 msg['content'] = {} if content is None else content
364 msg['content'] = {} if content is None else content
366 sub = {} if subheader is None else subheader
365 sub = {} if subheader is None else subheader
367 msg['header'].update(sub)
366 msg['header'].update(sub)
368 return msg
367 return msg
369
368
370 def sign(self, msg_list):
369 def sign(self, msg_list):
371 """Sign a message with HMAC digest. If no auth, return b''.
370 """Sign a message with HMAC digest. If no auth, return b''.
372
371
373 Parameters
372 Parameters
374 ----------
373 ----------
375 msg_list : list
374 msg_list : list
376 The [p_header,p_parent,p_content] part of the message list.
375 The [p_header,p_parent,p_content] part of the message list.
377 """
376 """
378 if self.auth is None:
377 if self.auth is None:
379 return b''
378 return b''
380 h = self.auth.copy()
379 h = self.auth.copy()
381 for m in msg_list:
380 for m in msg_list:
382 h.update(m)
381 h.update(m)
383 return h.hexdigest()
382 return h.hexdigest()
384
383
385 def serialize(self, msg, ident=None):
384 def serialize(self, msg, ident=None):
386 """Serialize the message components to bytes.
385 """Serialize the message components to bytes.
387
386
388 Parameters
387 Parameters
389 ----------
388 ----------
390 msg : dict or Message
389 msg : dict or Message
391 The nexted message dict as returned by the self.msg method.
390 The nexted message dict as returned by the self.msg method.
392
391
393 Returns
392 Returns
394 -------
393 -------
395 msg_list : list
394 msg_list : list
396 The list of bytes objects to be sent with the format:
395 The list of bytes objects to be sent with the format:
397 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
396 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
398 buffer1,buffer2,...]. In this list, the p_* entities are
397 buffer1,buffer2,...]. In this list, the p_* entities are
399 the packed or serialized versions, so if JSON is used, these
398 the packed or serialized versions, so if JSON is used, these
400 are uft8 encoded JSON strings.
399 are uft8 encoded JSON strings.
401 """
400 """
402 content = msg.get('content', {})
401 content = msg.get('content', {})
403 if content is None:
402 if content is None:
404 content = self.none
403 content = self.none
405 elif isinstance(content, dict):
404 elif isinstance(content, dict):
406 content = self.pack(content)
405 content = self.pack(content)
407 elif isinstance(content, bytes):
406 elif isinstance(content, bytes):
408 # content is already packed, as in a relayed message
407 # content is already packed, as in a relayed message
409 pass
408 pass
410 elif isinstance(content, unicode):
409 elif isinstance(content, unicode):
411 # should be bytes, but JSON often spits out unicode
410 # should be bytes, but JSON often spits out unicode
412 content = content.encode('utf8')
411 content = content.encode('utf8')
413 else:
412 else:
414 raise TypeError("Content incorrect type: %s"%type(content))
413 raise TypeError("Content incorrect type: %s"%type(content))
415
414
416 real_message = [self.pack(msg['header']),
415 real_message = [self.pack(msg['header']),
417 self.pack(msg['parent_header']),
416 self.pack(msg['parent_header']),
418 content
417 content
419 ]
418 ]
420
419
421 to_send = []
420 to_send = []
422
421
423 if isinstance(ident, list):
422 if isinstance(ident, list):
424 # accept list of idents
423 # accept list of idents
425 to_send.extend(ident)
424 to_send.extend(ident)
426 elif ident is not None:
425 elif ident is not None:
427 to_send.append(ident)
426 to_send.append(ident)
428 to_send.append(DELIM)
427 to_send.append(DELIM)
429
428
430 signature = self.sign(real_message)
429 signature = self.sign(real_message)
431 to_send.append(signature)
430 to_send.append(signature)
432
431
433 to_send.extend(real_message)
432 to_send.extend(real_message)
434
433
435 return to_send
434 return to_send
436
435
437 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
436 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
438 buffers=None, subheader=None, track=False):
437 buffers=None, subheader=None, track=False):
439 """Build and send a message via stream or socket.
438 """Build and send a message via stream or socket.
440
439
441 The message format used by this function internally is as follows:
440 The message format used by this function internally is as follows:
442
441
443 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
442 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
444 buffer1,buffer2,...]
443 buffer1,buffer2,...]
445
444
446 The self.serialize method converts the nested message dict into this
445 The self.serialize method converts the nested message dict into this
447 format.
446 format.
448
447
449 Parameters
448 Parameters
450 ----------
449 ----------
451
450
452 stream : zmq.Socket or ZMQStream
451 stream : zmq.Socket or ZMQStream
453 the socket-like object used to send the data
452 the socket-like object used to send the data
454 msg_or_type : str or Message/dict
453 msg_or_type : str or Message/dict
455 Normally, msg_or_type will be a msg_type unless a message is being
454 Normally, msg_or_type will be a msg_type unless a message is being
456 sent more than once.
455 sent more than once.
457
456
458 content : dict or None
457 content : dict or None
459 the content of the message (ignored if msg_or_type is a message)
458 the content of the message (ignored if msg_or_type is a message)
460 parent : Message or dict or None
459 parent : Message or dict or None
461 the parent or parent header describing the parent of this message
460 the parent or parent header describing the parent of this message
462 ident : bytes or list of bytes
461 ident : bytes or list of bytes
463 the zmq.IDENTITY routing path
462 the zmq.IDENTITY routing path
464 subheader : dict or None
463 subheader : dict or None
465 extra header keys for this message's header
464 extra header keys for this message's header
466 buffers : list or None
465 buffers : list or None
467 the already-serialized buffers to be appended to the message
466 the already-serialized buffers to be appended to the message
468 track : bool
467 track : bool
469 whether to track. Only for use with Sockets,
468 whether to track. Only for use with Sockets,
470 because ZMQStream objects cannot track messages.
469 because ZMQStream objects cannot track messages.
471
470
472 Returns
471 Returns
473 -------
472 -------
474 msg : message dict
473 msg : message dict
475 the constructed message
474 the constructed message
476 (msg,tracker) : (message dict, MessageTracker)
475 (msg,tracker) : (message dict, MessageTracker)
477 if track=True, then a 2-tuple will be returned,
476 if track=True, then a 2-tuple will be returned,
478 the first element being the constructed
477 the first element being the constructed
479 message, and the second being the MessageTracker
478 message, and the second being the MessageTracker
480
479
481 """
480 """
482
481
483 if not isinstance(stream, (zmq.Socket, ZMQStream)):
482 if not isinstance(stream, (zmq.Socket, ZMQStream)):
484 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
483 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
485 elif track and isinstance(stream, ZMQStream):
484 elif track and isinstance(stream, ZMQStream):
486 raise TypeError("ZMQStream cannot track messages")
485 raise TypeError("ZMQStream cannot track messages")
487
486
488 if isinstance(msg_or_type, (Message, dict)):
487 if isinstance(msg_or_type, (Message, dict)):
489 # we got a Message, not a msg_type
488 # we got a Message, not a msg_type
490 # don't build a new Message
489 # don't build a new Message
491 msg = msg_or_type
490 msg = msg_or_type
492 else:
491 else:
493 msg = self.msg(msg_or_type, content, parent, subheader)
492 msg = self.msg(msg_or_type, content, parent, subheader)
494
493
495 buffers = [] if buffers is None else buffers
494 buffers = [] if buffers is None else buffers
496 to_send = self.serialize(msg, ident)
495 to_send = self.serialize(msg, ident)
497 flag = 0
496 flag = 0
498 if buffers:
497 if buffers:
499 flag = zmq.SNDMORE
498 flag = zmq.SNDMORE
500 _track = False
499 _track = False
501 else:
500 else:
502 _track=track
501 _track=track
503 if track:
502 if track:
504 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
503 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
505 else:
504 else:
506 tracker = stream.send_multipart(to_send, flag, copy=False)
505 tracker = stream.send_multipart(to_send, flag, copy=False)
507 for b in buffers[:-1]:
506 for b in buffers[:-1]:
508 stream.send(b, flag, copy=False)
507 stream.send(b, flag, copy=False)
509 if buffers:
508 if buffers:
510 if track:
509 if track:
511 tracker = stream.send(buffers[-1], copy=False, track=track)
510 tracker = stream.send(buffers[-1], copy=False, track=track)
512 else:
511 else:
513 tracker = stream.send(buffers[-1], copy=False)
512 tracker = stream.send(buffers[-1], copy=False)
514
513
515 # omsg = Message(msg)
514 # omsg = Message(msg)
516 if self.debug:
515 if self.debug:
517 pprint.pprint(msg)
516 pprint.pprint(msg)
518 pprint.pprint(to_send)
517 pprint.pprint(to_send)
519 pprint.pprint(buffers)
518 pprint.pprint(buffers)
520
519
521 msg['tracker'] = tracker
520 msg['tracker'] = tracker
522
521
523 return msg
522 return msg
524
523
525 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
524 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
526 """Send a raw message via ident path.
525 """Send a raw message via ident path.
527
526
528 This method is used to send a already serialized message.
527 This method is used to send a already serialized message.
529
528
530 Parameters
529 Parameters
531 ----------
530 ----------
532 stream : ZMQStream or Socket
531 stream : ZMQStream or Socket
533 The ZMQ stream or socket to use for sending the message.
532 The ZMQ stream or socket to use for sending the message.
534 msg_list : list
533 msg_list : list
535 The serialized list of messages to send. This only includes the
534 The serialized list of messages to send. This only includes the
536 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
535 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
537 the message.
536 the message.
538 ident : ident or list
537 ident : ident or list
539 A single ident or a list of idents to use in sending.
538 A single ident or a list of idents to use in sending.
540 """
539 """
541 to_send = []
540 to_send = []
542 if isinstance(ident, bytes):
541 if isinstance(ident, bytes):
543 ident = [ident]
542 ident = [ident]
544 if ident is not None:
543 if ident is not None:
545 to_send.extend(ident)
544 to_send.extend(ident)
546
545
547 to_send.append(DELIM)
546 to_send.append(DELIM)
548 to_send.append(self.sign(msg_list))
547 to_send.append(self.sign(msg_list))
549 to_send.extend(msg_list)
548 to_send.extend(msg_list)
550 stream.send_multipart(msg_list, flags, copy=copy)
549 stream.send_multipart(msg_list, flags, copy=copy)
551
550
552 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
551 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
553 """Receive and unpack a message.
552 """Receive and unpack a message.
554
553
555 Parameters
554 Parameters
556 ----------
555 ----------
557 socket : ZMQStream or Socket
556 socket : ZMQStream or Socket
558 The socket or stream to use in receiving.
557 The socket or stream to use in receiving.
559
558
560 Returns
559 Returns
561 -------
560 -------
562 [idents], msg
561 [idents], msg
563 [idents] is a list of idents and msg is a nested message dict of
562 [idents] is a list of idents and msg is a nested message dict of
564 same format as self.msg returns.
563 same format as self.msg returns.
565 """
564 """
566 if isinstance(socket, ZMQStream):
565 if isinstance(socket, ZMQStream):
567 socket = socket.socket
566 socket = socket.socket
568 try:
567 try:
569 msg_list = socket.recv_multipart(mode)
568 msg_list = socket.recv_multipart(mode)
570 except zmq.ZMQError as e:
569 except zmq.ZMQError as e:
571 if e.errno == zmq.EAGAIN:
570 if e.errno == zmq.EAGAIN:
572 # We can convert EAGAIN to None as we know in this case
571 # We can convert EAGAIN to None as we know in this case
573 # recv_multipart won't return None.
572 # recv_multipart won't return None.
574 return None,None
573 return None,None
575 else:
574 else:
576 raise
575 raise
577 # split multipart message into identity list and message dict
576 # split multipart message into identity list and message dict
578 # invalid large messages can cause very expensive string comparisons
577 # invalid large messages can cause very expensive string comparisons
579 idents, msg_list = self.feed_identities(msg_list, copy)
578 idents, msg_list = self.feed_identities(msg_list, copy)
580 try:
579 try:
581 return idents, self.unpack_message(msg_list, content=content, copy=copy)
580 return idents, self.unpack_message(msg_list, content=content, copy=copy)
582 except Exception as e:
581 except Exception as e:
583 print (idents, msg_list)
582 print (idents, msg_list)
584 # TODO: handle it
583 # TODO: handle it
585 raise e
584 raise e
586
585
587 def feed_identities(self, msg_list, copy=True):
586 def feed_identities(self, msg_list, copy=True):
588 """Split the identities from the rest of the message.
587 """Split the identities from the rest of the message.
589
588
590 Feed until DELIM is reached, then return the prefix as idents and
589 Feed until DELIM is reached, then return the prefix as idents and
591 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
590 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
592 but that would be silly.
591 but that would be silly.
593
592
594 Parameters
593 Parameters
595 ----------
594 ----------
596 msg_list : a list of Message or bytes objects
595 msg_list : a list of Message or bytes objects
597 The message to be split.
596 The message to be split.
598 copy : bool
597 copy : bool
599 flag determining whether the arguments are bytes or Messages
598 flag determining whether the arguments are bytes or Messages
600
599
601 Returns
600 Returns
602 -------
601 -------
603 (idents,msg_list) : two lists
602 (idents,msg_list) : two lists
604 idents will always be a list of bytes - the indentity prefix
603 idents will always be a list of bytes - the indentity prefix
605 msg_list will be a list of bytes or Messages, unchanged from input
604 msg_list will be a list of bytes or Messages, unchanged from input
606 msg_list should be unpackable via self.unpack_message at this point.
605 msg_list should be unpackable via self.unpack_message at this point.
607 """
606 """
608 if copy:
607 if copy:
609 idx = msg_list.index(DELIM)
608 idx = msg_list.index(DELIM)
610 return msg_list[:idx], msg_list[idx+1:]
609 return msg_list[:idx], msg_list[idx+1:]
611 else:
610 else:
612 failed = True
611 failed = True
613 for idx,m in enumerate(msg_list):
612 for idx,m in enumerate(msg_list):
614 if m.bytes == DELIM:
613 if m.bytes == DELIM:
615 failed = False
614 failed = False
616 break
615 break
617 if failed:
616 if failed:
618 raise ValueError("DELIM not in msg_list")
617 raise ValueError("DELIM not in msg_list")
619 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
618 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
620 return [m.bytes for m in idents], msg_list
619 return [m.bytes for m in idents], msg_list
621
620
622 def unpack_message(self, msg_list, content=True, copy=True):
621 def unpack_message(self, msg_list, content=True, copy=True):
623 """Return a message object from the format
622 """Return a message object from the format
624 sent by self.send.
623 sent by self.send.
625
624
626 Parameters:
625 Parameters:
627 -----------
626 -----------
628
627
629 content : bool (True)
628 content : bool (True)
630 whether to unpack the content dict (True),
629 whether to unpack the content dict (True),
631 or leave it serialized (False)
630 or leave it serialized (False)
632
631
633 copy : bool (True)
632 copy : bool (True)
634 whether to return the bytes (True),
633 whether to return the bytes (True),
635 or the non-copying Message object in each place (False)
634 or the non-copying Message object in each place (False)
636
635
637 """
636 """
638 minlen = 4
637 minlen = 4
639 message = {}
638 message = {}
640 if not copy:
639 if not copy:
641 for i in range(minlen):
640 for i in range(minlen):
642 msg_list[i] = msg_list[i].bytes
641 msg_list[i] = msg_list[i].bytes
643 if self.auth is not None:
642 if self.auth is not None:
644 signature = msg_list[0]
643 signature = msg_list[0]
645 if signature in self.digest_history:
644 if signature in self.digest_history:
646 raise ValueError("Duplicate Signature: %r"%signature)
645 raise ValueError("Duplicate Signature: %r"%signature)
647 self.digest_history.add(signature)
646 self.digest_history.add(signature)
648 check = self.sign(msg_list[1:4])
647 check = self.sign(msg_list[1:4])
649 if not signature == check:
648 if not signature == check:
650 raise ValueError("Invalid Signature: %r"%signature)
649 raise ValueError("Invalid Signature: %r"%signature)
651 if not len(msg_list) >= minlen:
650 if not len(msg_list) >= minlen:
652 raise TypeError("malformed message, must have at least %i elements"%minlen)
651 raise TypeError("malformed message, must have at least %i elements"%minlen)
653 message['header'] = self.unpack(msg_list[1])
652 message['header'] = self.unpack(msg_list[1])
654 message['msg_type'] = message['header']['msg_type']
653 message['msg_type'] = message['header']['msg_type']
655 message['parent_header'] = self.unpack(msg_list[2])
654 message['parent_header'] = self.unpack(msg_list[2])
656 if content:
655 if content:
657 message['content'] = self.unpack(msg_list[3])
656 message['content'] = self.unpack(msg_list[3])
658 else:
657 else:
659 message['content'] = msg_list[3]
658 message['content'] = msg_list[3]
660
659
661 message['buffers'] = msg_list[4:]
660 message['buffers'] = msg_list[4:]
662 return message
661 return message
663
662
664 def test_msg2obj():
663 def test_msg2obj():
665 am = dict(x=1)
664 am = dict(x=1)
666 ao = Message(am)
665 ao = Message(am)
667 assert ao.x == am['x']
666 assert ao.x == am['x']
668
667
669 am['y'] = dict(z=1)
668 am['y'] = dict(z=1)
670 ao = Message(am)
669 ao = Message(am)
671 assert ao.y.z == am['y']['z']
670 assert ao.y.z == am['y']['z']
672
671
673 k1, k2 = 'y', 'z'
672 k1, k2 = 'y', 'z'
674 assert ao[k1][k2] == am[k1][k2]
673 assert ao[k1][k2] == am[k1][k2]
675
674
676 am2 = dict(ao)
675 am2 = dict(ao)
677 assert am['x'] == am2['x']
676 assert am['x'] == am2['x']
678 assert am['y']['z'] == am2['y']['z']
677 assert am['y']['z'] == am2['y']['z']
679
678
General Comments 0
You need to be logged in to leave comments. Login now