Fix rpmlint: non-executable-script...
Thomas Spura -
@@ -1,328 +1,327 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 A base class for objects that are configurable.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez
10 9 * Min RK
11 10 """
12 11
13 12 #-----------------------------------------------------------------------------
14 13 # Copyright (C) 2008-2011 The IPython Development Team
15 14 #
16 15 # Distributed under the terms of the BSD License. The full license is in
17 16 # the file COPYING, distributed as part of this software.
18 17 #-----------------------------------------------------------------------------
19 18
20 19 #-----------------------------------------------------------------------------
21 20 # Imports
22 21 #-----------------------------------------------------------------------------
23 22
24 23 import datetime
25 24 from copy import deepcopy
26 25
27 26 from loader import Config
28 27 from IPython.utils.traitlets import HasTraits, Instance
29 28 from IPython.utils.text import indent, wrap_paragraphs
30 29
31 30
32 31 #-----------------------------------------------------------------------------
33 32 # Helper classes for Configurables
34 33 #-----------------------------------------------------------------------------
35 34
36 35
37 36 class ConfigurableError(Exception):
38 37 pass
39 38
40 39
41 40 class MultipleInstanceError(ConfigurableError):
42 41 pass
43 42
44 43 #-----------------------------------------------------------------------------
45 44 # Configurable implementation
46 45 #-----------------------------------------------------------------------------
47 46
48 47 class Configurable(HasTraits):
49 48
50 49 config = Instance(Config,(),{})
51 50 created = None
52 51
53 52 def __init__(self, **kwargs):
54 53 """Create a configurable given a config config.
55 54
56 55 Parameters
57 56 ----------
58 57 config : Config
59 58 If this is empty, default values are used. If config is a
60 59 :class:`Config` instance, it will be used to configure the
61 60 instance.
62 61
63 62 Notes
64 63 -----
65 64 Subclasses of Configurable must call the :meth:`__init__` method of
66 65 :class:`Configurable` *before* doing anything else and using
67 66 :func:`super`::
68 67
69 68 class MyConfigurable(Configurable):
70 69 def __init__(self, config=None):
71 70 super(MyConfigurable, self).__init__(config=config)
72 71 # Then any other code you need to finish initialization.
73 72
74 73 This ensures that instances will be configured properly.
75 74 """
76 75 config = kwargs.pop('config', None)
77 76 if config is not None:
78 77 # We used to deepcopy, but for now we are trying to just save
79 78 # by reference. This *could* have side effects as all components
80 79 # will share config. In fact, I did find such a side effect in
81 80 # _config_changed below. If a config attribute value was a mutable type
82 81 # all instances of a component were getting the same copy, effectively
83 82 # making that a class attribute.
84 83 # self.config = deepcopy(config)
85 84 self.config = config
86 85 # This should go second so individual keyword arguments override
87 86 # the values in config.
88 87 super(Configurable, self).__init__(**kwargs)
89 88 self.created = datetime.datetime.now()
90 89
91 90 #-------------------------------------------------------------------------
92 91 # Static trait notifications
93 92 #-------------------------------------------------------------------------
94 93
95 94 def _config_changed(self, name, old, new):
96 95 """Update all the class traits having ``config=True`` as metadata.
97 96
98 97 For any class trait with a ``config`` metadata attribute that is
99 98 ``True``, we update the trait with the value of the corresponding
100 99 config entry.
101 100 """
102 101 # Get all traits with a config metadata entry that is True
103 102 traits = self.traits(config=True)
104 103
105 104 # We auto-load config section for this class as well as any parent
106 105 # classes that are Configurable subclasses. This starts with Configurable
107 106 # and works down the mro loading the config for each section.
108 107 section_names = [cls.__name__ for cls in \
109 108 reversed(self.__class__.__mro__) if
110 109 issubclass(cls, Configurable) and issubclass(self.__class__, cls)]
111 110
112 111 for sname in section_names:
113 112 # Don't do a blind getattr as that would cause the config to
114 113 # dynamically create the section with name self.__class__.__name__.
115 114 if new._has_section(sname):
116 115 my_config = new[sname]
117 116 for k, v in traits.iteritems():
118 117 # Don't allow traitlets with config=True to start with
119 118 # uppercase. Otherwise, they are confused with Config
120 119 # subsections. But, developers shouldn't have uppercase
121 120 # attributes anyway! (PEP 8)
122 121 if k[0].upper()==k[0] and not k.startswith('_'):
123 122 raise ConfigurableError('Configurable traitlets with '
124 123 'config=True must start with a lowercase so they are '
125 124 'not confused with Config subsections: %s.%s' % \
126 125 (self.__class__.__name__, k))
127 126 try:
128 127 # Here we grab the value from the config
129 128 # If k has the naming convention of a config
130 129 # section, it will be auto created.
131 130 config_value = my_config[k]
132 131 except KeyError:
133 132 pass
134 133 else:
135 134 # print "Setting %s.%s from %s.%s=%r" % \
136 135 # (self.__class__.__name__,k,sname,k,config_value)
137 136 # We have to do a deepcopy here if we don't deepcopy the entire
138 137 # config object. If we don't, a mutable config_value will be
139 138 # shared by all instances, effectively making it a class attribute.
140 139 setattr(self, k, deepcopy(config_value))
141 140
142 141 @classmethod
143 142 def class_get_help(cls):
144 143 """Get the help string for this class in ReST format."""
145 144 cls_traits = cls.class_traits(config=True)
146 145 final_help = []
147 146 final_help.append(u'%s options' % cls.__name__)
148 147 final_help.append(len(final_help[0])*u'-')
149 148 for k,v in cls.class_traits(config=True).iteritems():
150 149 help = cls.class_get_trait_help(v)
151 150 final_help.append(help)
152 151 return '\n'.join(final_help)
153 152
154 153 @classmethod
155 154 def class_get_trait_help(cls, trait):
156 155 """Get the help string for a single trait."""
157 156 lines = []
158 157 header = "--%s.%s=<%s>" % (cls.__name__, trait.name, trait.__class__.__name__)
159 158 lines.append(header)
160 159 try:
161 160 dvr = repr(trait.get_default_value())
162 161 except Exception:
163 162 dvr = None # ignore defaults we can't construct
164 163 if dvr is not None:
165 164 if len(dvr) > 64:
166 165 dvr = dvr[:61]+'...'
167 166 lines.append(indent('Default: %s'%dvr, 4))
168 167 if 'Enum' in trait.__class__.__name__:
169 168 # include Enum choices
170 169 lines.append(indent('Choices: %r'%(trait.values,)))
171 170
172 171 help = trait.get_metadata('help')
173 172 if help is not None:
174 173 help = '\n'.join(wrap_paragraphs(help, 76))
175 174 lines.append(indent(help, 4))
176 175 return '\n'.join(lines)
177 176
178 177 @classmethod
179 178 def class_print_help(cls):
180 179 """Get the help string for a single trait and print it."""
181 180 print cls.class_get_help()
182 181
183 182 @classmethod
184 183 def class_config_section(cls):
185 184 """Get the config class config section"""
186 185 def c(s):
187 186 """return a commented, wrapped block."""
188 187 s = '\n\n'.join(wrap_paragraphs(s, 78))
189 188
190 189 return '# ' + s.replace('\n', '\n# ')
191 190
192 191 # section header
193 192 breaker = '#' + '-'*78
194 193 s = "# %s configuration"%cls.__name__
195 194 lines = [breaker, s, breaker, '']
196 195 # get the description trait
197 196 desc = cls.class_traits().get('description')
198 197 if desc:
199 198 desc = desc.default_value
200 199 else:
201 200 # no description trait, use __doc__
202 201 desc = getattr(cls, '__doc__', '')
203 202 if desc:
204 203 lines.append(c(desc))
205 204 lines.append('')
206 205
207 206 parents = []
208 207 for parent in cls.mro():
209 208 # only include parents that are not base classes
210 209 # and are not the class itself
211 210 if issubclass(parent, Configurable) and \
212 211 not parent in (Configurable, SingletonConfigurable, cls):
213 212 parents.append(parent)
214 213
215 214 if parents:
216 215 pstr = ', '.join([ p.__name__ for p in parents ])
217 216 lines.append(c('%s will inherit config from: %s'%(cls.__name__, pstr)))
218 217 lines.append('')
219 218
220 219 for name,trait in cls.class_traits(config=True).iteritems():
221 220 help = trait.get_metadata('help') or ''
222 221 lines.append(c(help))
223 222 lines.append('# c.%s.%s = %r'%(cls.__name__, name, trait.get_default_value()))
224 223 lines.append('')
225 224 return '\n'.join(lines)
226 225
227 226
228 227
229 228 class SingletonConfigurable(Configurable):
230 229 """A configurable that only allows one instance.
231 230
232 231 This class is for classes that should only have one instance of itself
233 232 or *any* subclass. To create and retrieve such a class use the
234 233 :meth:`SingletonConfigurable.instance` method.
235 234 """
236 235
237 236 _instance = None
238 237
239 238 @classmethod
240 239 def _walk_mro(cls):
241 240 """Walk the cls.mro() for parent classes that are also singletons
242 241
243 242 For use in instance()
244 243 """
245 244
246 245 for subclass in cls.mro():
247 246 if issubclass(cls, subclass) and \
248 247 issubclass(subclass, SingletonConfigurable) and \
249 248 subclass != SingletonConfigurable:
250 249 yield subclass
251 250
252 251 @classmethod
253 252 def clear_instance(cls):
254 253 """unset _instance for this class and singleton parents.
255 254 """
256 255 if not cls.initialized():
257 256 return
258 257 for subclass in cls._walk_mro():
259 258 if isinstance(subclass._instance, cls):
260 259 # only clear instances that are instances
261 260 # of the calling class
262 261 subclass._instance = None
263 262
264 263 @classmethod
265 264 def instance(cls, *args, **kwargs):
266 265 """Returns a global instance of this class.
267 266
268 267 This method creates a new instance if none has previously been created
269 268 and returns the previously created instance if one already exists.
270 269
271 270 The arguments and keyword arguments passed to this method are passed
272 271 on to the :meth:`__init__` method of the class upon instantiation.
273 272
274 273 Examples
275 274 --------
276 275
277 276 Create a singleton class using instance, and retrieve it::
278 277
279 278 >>> from IPython.config.configurable import SingletonConfigurable
280 279 >>> class Foo(SingletonConfigurable): pass
281 280 >>> foo = Foo.instance()
282 281 >>> foo == Foo.instance()
283 282 True
284 283
285 284 Create a subclass that is retrieved using the base class instance::
286 285
287 286 >>> class Bar(SingletonConfigurable): pass
288 287 >>> class Bam(Bar): pass
289 288 >>> bam = Bam.instance()
290 289 >>> bam == Bar.instance()
291 290 True
292 291 """
293 292 # Create and save the instance
294 293 if cls._instance is None:
295 294 inst = cls(*args, **kwargs)
296 295 # Now make sure that the instance will also be returned by
297 296 # parent classes' _instance attribute.
298 297 for subclass in cls._walk_mro():
299 298 subclass._instance = inst
300 299
301 300 if isinstance(cls._instance, cls):
302 301 return cls._instance
303 302 else:
304 303 raise MultipleInstanceError(
305 304 'Multiple incompatible subclass instances of '
306 305 '%s are being created.' % cls.__name__
307 306 )
308 307
309 308 @classmethod
310 309 def initialized(cls):
311 310 """Has an instance been created?"""
312 311 return hasattr(cls, "_instance") and cls._instance is not None
313 312
314 313
315 314 class LoggingConfigurable(Configurable):
316 315 """A parent class for Configurables that log.
317 316
318 317 Subclasses have a log trait, and the default behavior
319 318 is to get the logger from the currently running Application
320 319 via Application.instance().log.
321 320 """
322 321
323 322 log = Instance('logging.Logger')
324 323 def _log_default(self):
325 324 from IPython.config.application import Application
326 325 return Application.instance().log
327 326
328 327
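A quick illustration of the configuration pattern implemented in this file (a sketch only, assuming the IPython 0.11-era `IPython.config` layout shown in the imports; the `Weather` class and its traits are invented for the example):

from IPython.config.loader import Config
from IPython.config.configurable import Configurable
from IPython.utils.traitlets import Int, Unicode

class Weather(Configurable):
    # config=True marks the traits that _config_changed() will load from
    # the Config section named after the class (here: config.Weather.*).
    temperature = Int(20, config=True, help="Temperature in Celsius.")
    city = Unicode(u'nowhere', config=True, help="City name.")

cfg = Config()
cfg.Weather.temperature = 30          # section name matches the class name
w = Weather(config=cfg)
assert w.temperature == 30            # picked up from the config
assert w.city == u'nowhere'           # no config entry, default kept

# Keyword arguments are applied after the config, so they win:
w2 = Weather(config=cfg, city=u'Paris')
assert w2.city == u'Paris'
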
@@ -1,166 +1,165 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Tests for IPython.config.configurable
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez (design help)
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2010 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 from unittest import TestCase
24 23
25 24 from IPython.config.configurable import (
26 25 Configurable,
27 26 SingletonConfigurable
28 27 )
29 28
30 29 from IPython.utils.traitlets import (
31 30 Int, Float, Unicode
32 31 )
33 32
34 33 from IPython.config.loader import Config
35 34
36 35
37 36 #-----------------------------------------------------------------------------
38 37 # Test cases
39 38 #-----------------------------------------------------------------------------
40 39
41 40
42 41 class MyConfigurable(Configurable):
43 42 a = Int(1, config=True, help="The integer a.")
44 43 b = Float(1.0, config=True, help="The integer b.")
45 44 c = Unicode('no config')
46 45
47 46
48 47 mc_help=u"""MyConfigurable options
49 48 ----------------------
50 49 --MyConfigurable.a=<Int>
51 50 Default: 1
52 51 The integer a.
53 52 --MyConfigurable.b=<Float>
54 53 Default: 1.0
55 54 The integer b."""
56 55
57 56 class Foo(Configurable):
58 57 a = Int(0, config=True, help="The integer a.")
59 58 b = Unicode('nope', config=True)
60 59
61 60
62 61 class Bar(Foo):
63 62 b = Unicode('gotit', config=False, help="The string b.")
64 63 c = Float(config=True, help="The string c.")
65 64
66 65
67 66 class TestConfigurable(TestCase):
68 67
69 68 def test_default(self):
70 69 c1 = Configurable()
71 70 c2 = Configurable(config=c1.config)
72 71 c3 = Configurable(config=c2.config)
73 72 self.assertEquals(c1.config, c2.config)
74 73 self.assertEquals(c2.config, c3.config)
75 74
76 75 def test_custom(self):
77 76 config = Config()
78 77 config.foo = 'foo'
79 78 config.bar = 'bar'
80 79 c1 = Configurable(config=config)
81 80 c2 = Configurable(config=c1.config)
82 81 c3 = Configurable(config=c2.config)
83 82 self.assertEquals(c1.config, config)
84 83 self.assertEquals(c2.config, config)
85 84 self.assertEquals(c3.config, config)
86 85 # Test that copies are not made
87 86 self.assert_(c1.config is config)
88 87 self.assert_(c2.config is config)
89 88 self.assert_(c3.config is config)
90 89 self.assert_(c1.config is c2.config)
91 90 self.assert_(c2.config is c3.config)
92 91
93 92 def test_inheritance(self):
94 93 config = Config()
95 94 config.MyConfigurable.a = 2
96 95 config.MyConfigurable.b = 2.0
97 96 c1 = MyConfigurable(config=config)
98 97 c2 = MyConfigurable(config=c1.config)
99 98 self.assertEquals(c1.a, config.MyConfigurable.a)
100 99 self.assertEquals(c1.b, config.MyConfigurable.b)
101 100 self.assertEquals(c2.a, config.MyConfigurable.a)
102 101 self.assertEquals(c2.b, config.MyConfigurable.b)
103 102
104 103 def test_parent(self):
105 104 config = Config()
106 105 config.Foo.a = 10
107 106 config.Foo.b = "wow"
108 107 config.Bar.b = 'later'
109 108 config.Bar.c = 100.0
110 109 f = Foo(config=config)
111 110 b = Bar(config=f.config)
112 111 self.assertEquals(f.a, 10)
113 112 self.assertEquals(f.b, 'wow')
114 113 self.assertEquals(b.b, 'gotit')
115 114 self.assertEquals(b.c, 100.0)
116 115
117 116 def test_override1(self):
118 117 config = Config()
119 118 config.MyConfigurable.a = 2
120 119 config.MyConfigurable.b = 2.0
121 120 c = MyConfigurable(a=3, config=config)
122 121 self.assertEquals(c.a, 3)
123 122 self.assertEquals(c.b, config.MyConfigurable.b)
124 123 self.assertEquals(c.c, 'no config')
125 124
126 125 def test_override2(self):
127 126 config = Config()
128 127 config.Foo.a = 1
129 128 config.Bar.b = 'or' # Up above b is config=False, so this won't do it.
130 129 config.Bar.c = 10.0
131 130 c = Bar(config=config)
132 131 self.assertEquals(c.a, config.Foo.a)
133 132 self.assertEquals(c.b, 'gotit')
134 133 self.assertEquals(c.c, config.Bar.c)
135 134 c = Bar(a=2, b='and', c=20.0, config=config)
136 135 self.assertEquals(c.a, 2)
137 136 self.assertEquals(c.b, 'and')
138 137 self.assertEquals(c.c, 20.0)
139 138
140 139 def test_help(self):
141 140 self.assertEquals(MyConfigurable.class_get_help(), mc_help)
142 141
143 142
144 143 class TestSingletonConfigurable(TestCase):
145 144
146 145 def test_instance(self):
147 146 from IPython.config.configurable import SingletonConfigurable
148 147 class Foo(SingletonConfigurable): pass
149 148 self.assertEquals(Foo.initialized(), False)
150 149 foo = Foo.instance()
151 150 self.assertEquals(Foo.initialized(), True)
152 151 self.assertEquals(foo, Foo.instance())
153 152 self.assertEquals(SingletonConfigurable._instance, None)
154 153
155 154 def test_inheritance(self):
156 155 class Bar(SingletonConfigurable): pass
157 156 class Bam(Bar): pass
158 157 self.assertEquals(Bar.initialized(), False)
159 158 self.assertEquals(Bam.initialized(), False)
160 159 bam = Bam.instance()
161 160 bam == Bar.instance()
162 161 self.assertEquals(Bar.initialized(), True)
163 162 self.assertEquals(Bam.initialized(), True)
164 163 self.assertEquals(bam, Bam._instance)
165 164 self.assertEquals(bam, Bar._instance)
166 165 self.assertEquals(SingletonConfigurable._instance, None)
@@ -1,226 +1,225 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Tests for IPython.config.loader
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez (design help)
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2009 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 import os
24 23 import sys
25 24 from tempfile import mkstemp
26 25 from unittest import TestCase
27 26
28 27 from nose import SkipTest
29 28
30 29 from IPython.testing.tools import mute_warn
31 30
32 31 from IPython.utils.traitlets import Int, Unicode
33 32 from IPython.config.configurable import Configurable
34 33 from IPython.config.loader import (
35 34 Config,
36 35 PyFileConfigLoader,
37 36 KeyValueConfigLoader,
38 37 ArgParseConfigLoader,
39 38 ConfigError
40 39 )
41 40
42 41 #-----------------------------------------------------------------------------
43 42 # Actual tests
44 43 #-----------------------------------------------------------------------------
45 44
46 45
47 46 pyfile = """
48 47 c = get_config()
49 48 c.a=10
50 49 c.b=20
51 50 c.Foo.Bar.value=10
52 51 c.Foo.Bam.value=range(10)
53 52 c.D.C.value='hi there'
54 53 """
55 54
56 55 class TestPyFileCL(TestCase):
57 56
58 57 def test_basic(self):
59 58 fd, fname = mkstemp('.py')
60 59 f = os.fdopen(fd, 'w')
61 60 f.write(pyfile)
62 61 f.close()
63 62 # Unlink the file
64 63 cl = PyFileConfigLoader(fname)
65 64 config = cl.load_config()
66 65 self.assertEquals(config.a, 10)
67 66 self.assertEquals(config.b, 20)
68 67 self.assertEquals(config.Foo.Bar.value, 10)
69 68 self.assertEquals(config.Foo.Bam.value, range(10))
70 69 self.assertEquals(config.D.C.value, 'hi there')
71 70
72 71 class MyLoader1(ArgParseConfigLoader):
73 72 def _add_arguments(self):
74 73 p = self.parser
75 74 p.add_argument('-f', '--foo', dest='Global.foo', type=str)
76 75 p.add_argument('-b', dest='MyClass.bar', type=int)
77 76 p.add_argument('-n', dest='n', action='store_true')
78 77 p.add_argument('Global.bam', type=str)
79 78
80 79 class MyLoader2(ArgParseConfigLoader):
81 80 def _add_arguments(self):
82 81 subparsers = self.parser.add_subparsers(dest='subparser_name')
83 82 subparser1 = subparsers.add_parser('1')
84 83 subparser1.add_argument('-x',dest='Global.x')
85 84 subparser2 = subparsers.add_parser('2')
86 85 subparser2.add_argument('y')
87 86
88 87 class TestArgParseCL(TestCase):
89 88
90 89 def test_basic(self):
91 90 cl = MyLoader1()
92 91 config = cl.load_config('-f hi -b 10 -n wow'.split())
93 92 self.assertEquals(config.Global.foo, 'hi')
94 93 self.assertEquals(config.MyClass.bar, 10)
95 94 self.assertEquals(config.n, True)
96 95 self.assertEquals(config.Global.bam, 'wow')
97 96 config = cl.load_config(['wow'])
98 97 self.assertEquals(config.keys(), ['Global'])
99 98 self.assertEquals(config.Global.keys(), ['bam'])
100 99 self.assertEquals(config.Global.bam, 'wow')
101 100
102 101 def test_add_arguments(self):
103 102 cl = MyLoader2()
104 103 config = cl.load_config('2 frobble'.split())
105 104 self.assertEquals(config.subparser_name, '2')
106 105 self.assertEquals(config.y, 'frobble')
107 106 config = cl.load_config('1 -x frobble'.split())
108 107 self.assertEquals(config.subparser_name, '1')
109 108 self.assertEquals(config.Global.x, 'frobble')
110 109
111 110 def test_argv(self):
112 111 cl = MyLoader1(argv='-f hi -b 10 -n wow'.split())
113 112 config = cl.load_config()
114 113 self.assertEquals(config.Global.foo, 'hi')
115 114 self.assertEquals(config.MyClass.bar, 10)
116 115 self.assertEquals(config.n, True)
117 116 self.assertEquals(config.Global.bam, 'wow')
118 117
119 118
120 119 class TestKeyValueCL(TestCase):
121 120
122 121 def test_basic(self):
123 122 cl = KeyValueConfigLoader()
124 123 argv = ['--'+s.strip('c.') for s in pyfile.split('\n')[2:-1]]
125 124 with mute_warn():
126 125 config = cl.load_config(argv)
127 126 self.assertEquals(config.a, 10)
128 127 self.assertEquals(config.b, 20)
129 128 self.assertEquals(config.Foo.Bar.value, 10)
130 129 self.assertEquals(config.Foo.Bam.value, range(10))
131 130 self.assertEquals(config.D.C.value, 'hi there')
132 131
133 132 def test_extra_args(self):
134 133 cl = KeyValueConfigLoader()
135 134 with mute_warn():
136 135 config = cl.load_config(['--a=5', 'b', '--c=10', 'd'])
137 136 self.assertEquals(cl.extra_args, ['b', 'd'])
138 137 self.assertEquals(config.a, 5)
139 138 self.assertEquals(config.c, 10)
140 139 with mute_warn():
141 140 config = cl.load_config(['--', '--a=5', '--c=10'])
142 141 self.assertEquals(cl.extra_args, ['--a=5', '--c=10'])
143 142
144 143 def test_unicode_args(self):
145 144 cl = KeyValueConfigLoader()
146 145 argv = [u'--a=épsîlön']
147 146 with mute_warn():
148 147 config = cl.load_config(argv)
149 148 self.assertEquals(config.a, u'épsîlön')
150 149
151 150 def test_unicode_bytes_args(self):
152 151 uarg = u'--a=é'
153 152 try:
154 153 barg = uarg.encode(sys.stdin.encoding)
155 154 except (TypeError, UnicodeEncodeError):
156 155 raise SkipTest("sys.stdin.encoding can't handle 'é'")
157 156
158 157 cl = KeyValueConfigLoader()
159 158 with mute_warn():
160 159 config = cl.load_config([barg])
161 160 self.assertEquals(config.a, u'é')
162 161
163 162
164 163 class TestConfig(TestCase):
165 164
166 165 def test_setget(self):
167 166 c = Config()
168 167 c.a = 10
169 168 self.assertEquals(c.a, 10)
170 169 self.assertEquals(c.has_key('b'), False)
171 170
172 171 def test_auto_section(self):
173 172 c = Config()
174 173 self.assertEquals(c.has_key('A'), True)
175 174 self.assertEquals(c._has_section('A'), False)
176 175 A = c.A
177 176 A.foo = 'hi there'
178 177 self.assertEquals(c._has_section('A'), True)
179 178 self.assertEquals(c.A.foo, 'hi there')
180 179 del c.A
181 180 self.assertEquals(len(c.A.keys()),0)
182 181
183 182 def test_merge_doesnt_exist(self):
184 183 c1 = Config()
185 184 c2 = Config()
186 185 c2.bar = 10
187 186 c2.Foo.bar = 10
188 187 c1._merge(c2)
189 188 self.assertEquals(c1.Foo.bar, 10)
190 189 self.assertEquals(c1.bar, 10)
191 190 c2.Bar.bar = 10
192 191 c1._merge(c2)
193 192 self.assertEquals(c1.Bar.bar, 10)
194 193
195 194 def test_merge_exists(self):
196 195 c1 = Config()
197 196 c2 = Config()
198 197 c1.Foo.bar = 10
199 198 c1.Foo.bam = 30
200 199 c2.Foo.bar = 20
201 200 c2.Foo.wow = 40
202 201 c1._merge(c2)
203 202 self.assertEquals(c1.Foo.bam, 30)
204 203 self.assertEquals(c1.Foo.bar, 20)
205 204 self.assertEquals(c1.Foo.wow, 40)
206 205 c2.Foo.Bam.bam = 10
207 206 c1._merge(c2)
208 207 self.assertEquals(c1.Foo.Bam.bam, 10)
209 208
210 209 def test_deepcopy(self):
211 210 c1 = Config()
212 211 c1.Foo.bar = 10
213 212 c1.Foo.bam = 30
214 213 c1.a = 'asdf'
215 214 c1.b = range(10)
216 215 import copy
217 216 c2 = copy.deepcopy(c1)
218 217 self.assertEquals(c1, c2)
219 218 self.assert_(c1 is not c2)
220 219 self.assert_(c1.Foo is not c2.Foo)
221 220
222 221 def test_builtin(self):
223 222 c1 = Config()
224 223 exec 'foo = True' in c1
225 224 self.assertEquals(c1.foo, True)
226 225 self.assertRaises(ConfigError, setattr, c1, 'ValueError', 10)
@@ -1,264 +1,263 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 System command aliases.
5 4
6 5 Authors:
7 6
8 7 * Fernando Perez
9 8 * Brian Granger
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2010 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License.
16 15 #
17 16 # The full license is in the file COPYING.txt, distributed with this software.
18 17 #-----------------------------------------------------------------------------
19 18
20 19 #-----------------------------------------------------------------------------
21 20 # Imports
22 21 #-----------------------------------------------------------------------------
23 22
24 23 import __builtin__
25 24 import keyword
26 25 import os
27 26 import re
28 27 import sys
29 28
30 29 from IPython.config.configurable import Configurable
31 30 from IPython.core.splitinput import split_user_input
32 31
33 32 from IPython.utils.traitlets import List, Instance
34 33 from IPython.utils.autoattr import auto_attr
35 34 from IPython.utils.warn import warn, error
36 35
37 36 #-----------------------------------------------------------------------------
38 37 # Utilities
39 38 #-----------------------------------------------------------------------------
40 39
41 40 # This is used as the pattern for calls to split_user_input.
42 41 shell_line_split = re.compile(r'^(\s*)(\S*\s*)(.*$)')
43 42
44 43 def default_aliases():
45 44 """Return list of shell aliases to auto-define.
46 45 """
47 46 # Note: the aliases defined here should be safe to use on a kernel
48 47 # regardless of what frontend it is attached to. Frontends that use a
49 48 # kernel in-process can define additional aliases that will only work in
50 49 # their case. For example, things like 'less' or 'clear' that manipulate
51 50 # the terminal should NOT be declared here, as they will only work if the
52 51 # kernel is running inside a true terminal, and not over the network.
53 52
54 53 if os.name == 'posix':
55 54 default_aliases = [('mkdir', 'mkdir'), ('rmdir', 'rmdir'),
56 55 ('mv', 'mv -i'), ('rm', 'rm -i'), ('cp', 'cp -i'),
57 56 ('cat', 'cat'),
58 57 ]
59 58 # Useful set of ls aliases. The GNU and BSD options are a little
60 59 # different, so we make aliases that provide as similar as possible
61 60 # behavior in ipython, by passing the right flags for each platform
62 61 if sys.platform.startswith('linux'):
63 62 ls_aliases = [('ls', 'ls -F --color'),
64 63 # long ls
65 64 ('ll', 'ls -F -o --color'),
66 65 # ls normal files only
67 66 ('lf', 'ls -F -o --color %l | grep ^-'),
68 67 # ls symbolic links
69 68 ('lk', 'ls -F -o --color %l | grep ^l'),
70 69 # directories or links to directories,
71 70 ('ldir', 'ls -F -o --color %l | grep /$'),
72 71 # things which are executable
73 72 ('lx', 'ls -F -o --color %l | grep ^-..x'),
74 73 ]
75 74 else:
76 75 # BSD, OSX, etc.
77 76 ls_aliases = [('ls', 'ls -F'),
78 77 # long ls
79 78 ('ll', 'ls -F -l'),
80 79 # ls normal files only
81 80 ('lf', 'ls -F -l %l | grep ^-'),
82 81 # ls symbolic links
83 82 ('lk', 'ls -F -l %l | grep ^l'),
84 83 # directories or links to directories,
85 84 ('ldir', 'ls -F -l %l | grep /$'),
86 85 # things which are executable
87 86 ('lx', 'ls -F -l %l | grep ^-..x'),
88 87 ]
89 88 default_aliases = default_aliases + ls_aliases
90 89 elif os.name in ['nt', 'dos']:
91 90 default_aliases = [('ls', 'dir /on'),
92 91 ('ddir', 'dir /ad /on'), ('ldir', 'dir /ad /on'),
93 92 ('mkdir', 'mkdir'), ('rmdir', 'rmdir'),
94 93 ('echo', 'echo'), ('ren', 'ren'), ('copy', 'copy'),
95 94 ]
96 95 else:
97 96 default_aliases = []
98 97
99 98 return default_aliases
100 99
101 100
102 101 class AliasError(Exception):
103 102 pass
104 103
105 104
106 105 class InvalidAliasError(AliasError):
107 106 pass
108 107
109 108 #-----------------------------------------------------------------------------
110 109 # Main AliasManager class
111 110 #-----------------------------------------------------------------------------
112 111
113 112 class AliasManager(Configurable):
114 113
115 114 default_aliases = List(default_aliases(), config=True)
116 115 user_aliases = List(default_value=[], config=True)
117 116 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
118 117
119 118 def __init__(self, shell=None, config=None):
120 119 super(AliasManager, self).__init__(shell=shell, config=config)
121 120 self.alias_table = {}
122 121 self.exclude_aliases()
123 122 self.init_aliases()
124 123
125 124 def __contains__(self, name):
126 125 return name in self.alias_table
127 126
128 127 @property
129 128 def aliases(self):
130 129 return [(item[0], item[1][1]) for item in self.alias_table.iteritems()]
131 130
132 131 def exclude_aliases(self):
133 132 # set of things NOT to alias (keywords, builtins and some magics)
134 133 no_alias = set(['cd','popd','pushd','dhist','alias','unalias'])
135 134 no_alias.update(set(keyword.kwlist))
136 135 no_alias.update(set(__builtin__.__dict__.keys()))
137 136 self.no_alias = no_alias
138 137
139 138 def init_aliases(self):
140 139 # Load default aliases
141 140 for name, cmd in self.default_aliases:
142 141 self.soft_define_alias(name, cmd)
143 142
144 143 # Load user aliases
145 144 for name, cmd in self.user_aliases:
146 145 self.soft_define_alias(name, cmd)
147 146
148 147 def clear_aliases(self):
149 148 self.alias_table.clear()
150 149
151 150 def soft_define_alias(self, name, cmd):
152 151 """Define an alias, but don't raise on an AliasError."""
153 152 try:
154 153 self.define_alias(name, cmd)
155 154 except AliasError, e:
156 155 error("Invalid alias: %s" % e)
157 156
158 157 def define_alias(self, name, cmd):
159 158 """Define a new alias after validating it.
160 159
161 160 This will raise an :exc:`AliasError` if there are validation
162 161 problems.
163 162 """
164 163 nargs = self.validate_alias(name, cmd)
165 164 self.alias_table[name] = (nargs, cmd)
166 165
167 166 def undefine_alias(self, name):
168 167 if self.alias_table.has_key(name):
169 168 del self.alias_table[name]
170 169
171 170 def validate_alias(self, name, cmd):
172 171 """Validate an alias and return the its number of arguments."""
173 172 if name in self.no_alias:
174 173 raise InvalidAliasError("The name %s can't be aliased "
175 174 "because it is a keyword or builtin." % name)
176 175 if not (isinstance(cmd, basestring)):
177 176 raise InvalidAliasError("An alias command must be a string, "
178 177 "got: %r" % name)
179 178 nargs = cmd.count('%s')
180 179 if nargs>0 and cmd.find('%l')>=0:
181 180 raise InvalidAliasError('The %s and %l specifiers are mutually '
182 181 'exclusive in alias definitions.')
183 182 return nargs
184 183
185 184 def call_alias(self, alias, rest=''):
186 185 """Call an alias given its name and the rest of the line."""
187 186 cmd = self.transform_alias(alias, rest)
188 187 try:
189 188 self.shell.system(cmd)
190 189 except:
191 190 self.shell.showtraceback()
192 191
193 192 def transform_alias(self, alias,rest=''):
194 193 """Transform alias to system command string."""
195 194 nargs, cmd = self.alias_table[alias]
196 195
197 196 if ' ' in cmd and os.path.isfile(cmd):
198 197 cmd = '"%s"' % cmd
199 198
200 199 # Expand the %l special to be the user's input line
201 200 if cmd.find('%l') >= 0:
202 201 cmd = cmd.replace('%l', rest)
203 202 rest = ''
204 203 if nargs==0:
205 204 # Simple, argument-less aliases
206 205 cmd = '%s %s' % (cmd, rest)
207 206 else:
208 207 # Handle aliases with positional arguments
209 208 args = rest.split(None, nargs)
210 209 if len(args) < nargs:
211 210 raise AliasError('Alias <%s> requires %s arguments, %s given.' %
212 211 (alias, nargs, len(args)))
213 212 cmd = '%s %s' % (cmd % tuple(args[:nargs]),' '.join(args[nargs:]))
214 213 return cmd
215 214
216 215 def expand_alias(self, line):
217 216 """ Expand an alias in the command line
218 217
219 218 Returns the provided command line, possibly with the first word
220 219 (command) translated according to alias expansion rules.
221 220
222 221 [ipython]|16> _ip.expand_aliases("np myfile.txt")
223 222 <16> 'q:/opt/np/notepad++.exe myfile.txt'
224 223 """
225 224
226 225 pre,fn,rest = split_user_input(line)
227 226 res = pre + self.expand_aliases(fn, rest)
228 227 return res
229 228
230 229 def expand_aliases(self, fn, rest):
231 230 """Expand multiple levels of aliases:
232 231
233 232 if:
234 233
235 234 alias foo bar /tmp
236 235 alias baz foo
237 236
238 237 then:
239 238
240 239 baz huhhahhei -> bar /tmp huhhahhei
241 240 """
242 241 line = fn + " " + rest
243 242
244 243 done = set()
245 244 while 1:
246 245 pre,fn,rest = split_user_input(line, shell_line_split)
247 246 if fn in self.alias_table:
248 247 if fn in done:
249 248 warn("Cyclic alias definition, repeated '%s'" % fn)
250 249 return ""
251 250 done.add(fn)
252 251
253 252 l2 = self.transform_alias(fn, rest)
254 253 if l2 == line:
255 254 break
256 255 # ls -> ls -F should not recurse forever
257 256 if l2.split(None,1)[0] == line.split(None,1)[0]:
258 257 line = l2
259 258 break
260 259 line=l2
261 260 else:
262 261 break
263 262
264 263 return line
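A minimal sketch of how the alias table above behaves (the module path is taken from the `from IPython.core.alias import AliasManager` import in prefilter.py below; no real shell is attached, so only definition and expansion are exercised, not execution):

from IPython.core.alias import AliasManager

am = AliasManager(shell=None)
# %s placeholders make an alias positional; %l (mutually exclusive with %s)
# would splice in the whole rest of the line instead.
am.define_alias('parts', 'echo first %s second %s')
assert 'parts' in am
# The two positional slots are filled and the remainder is appended verbatim.
assert am.transform_alias('parts', 'a b c') == 'echo first a second b c'
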
@@ -1,71 +1,70 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Autocall capabilities for IPython.core.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez
10 9 * Thomas Kluyver
11 10
12 11 Notes
13 12 -----
14 13 """
15 14
16 15 #-----------------------------------------------------------------------------
17 16 # Copyright (C) 2008-2009 The IPython Development Team
18 17 #
19 18 # Distributed under the terms of the BSD License. The full license is in
20 19 # the file COPYING, distributed as part of this software.
21 20 #-----------------------------------------------------------------------------
22 21
23 22 #-----------------------------------------------------------------------------
24 23 # Imports
25 24 #-----------------------------------------------------------------------------
26 25
27 26
28 27 #-----------------------------------------------------------------------------
29 28 # Code
30 29 #-----------------------------------------------------------------------------
31 30
32 31 class IPyAutocall(object):
33 32 """ Instances of this class are always autocalled
34 33
35 34 This happens regardless of 'autocall' variable state. Use this to
36 35 develop macro-like mechanisms.
37 36 """
38 37 _ip = None
39 38 rewrite = True
40 39 def __init__(self, ip=None):
41 40 self._ip = ip
42 41
43 42 def set_ip(self, ip):
44 43 """ Will be used to set _ip point to current ipython instance b/f call
45 44
46 45 Override this method if you don't want this to happen.
47 46
48 47 """
49 48 self._ip = ip
50 49
51 50
52 51 class ExitAutocall(IPyAutocall):
53 52 """An autocallable object which will be added to the user namespace so that
54 53 exit, exit(), quit or quit() are all valid ways to close the shell."""
55 54 rewrite = False
56 55
57 56 def __call__(self):
58 57 self._ip.ask_exit()
59 58
60 59 class ZMQExitAutocall(ExitAutocall):
61 60 """Exit IPython. Autocallable, so it needn't be explicitly called.
62 61
63 62 Parameters
64 63 ----------
65 64 keep_kernel : bool
66 65 If True, leave the kernel alive. Otherwise, tell the kernel to exit too
67 66 (default).
68 67 """
69 68 def __call__(self, keep_kernel=False):
70 69 self._ip.keepkernel_on_exit = keep_kernel
71 70 self._ip.ask_exit()
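A tiny sketch of the intended use of IPyAutocall (the import path matches the one used by prefilter.py below; placing the instance in the user namespace is normally done by an extension or by IPython itself, as with ExitAutocall):

from IPython.core.autocall import IPyAutocall

class HelloCaller(IPyAutocall):
    """Typing just `hello` at the prompt calls this instance, regardless
    of the user's autocall setting."""
    rewrite = False                      # don't rewrite the input line

    def __call__(self):
        # set_ip() has pointed self._ip at the running shell before the call.
        return 'hello from an autocalled object'

hello = HelloCaller()                    # would be placed in user_ns
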
@@ -1,71 +1,70 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 A context manager for handling sys.displayhook.
5 4
6 5 Authors:
7 6
8 7 * Robert Kern
9 8 * Brian Granger
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2009 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 import sys
24 23
25 24 from IPython.config.configurable import Configurable
26 25 from IPython.utils.traitlets import Any
27 26
28 27 #-----------------------------------------------------------------------------
29 28 # Classes and functions
30 29 #-----------------------------------------------------------------------------
31 30
32 31
33 32 class DisplayTrap(Configurable):
34 33 """Object to manage sys.displayhook.
35 34
36 35 This came from IPython.core.kernel.display_hook, but is simplified
37 36 (no callbacks or formatters) until more of the core is refactored.
38 37 """
39 38
40 39 hook = Any
41 40
42 41 def __init__(self, hook=None):
43 42 super(DisplayTrap, self).__init__(hook=hook, config=None)
44 43 self.old_hook = None
45 44 # We define this to track if a single BuiltinTrap is nested.
46 45 # Only turn off the trap when the outermost call to __exit__ is made.
47 46 self._nested_level = 0
48 47
49 48 def __enter__(self):
50 49 if self._nested_level == 0:
51 50 self.set()
52 51 self._nested_level += 1
53 52 return self
54 53
55 54 def __exit__(self, type, value, traceback):
56 55 if self._nested_level == 1:
57 56 self.unset()
58 57 self._nested_level -= 1
59 58 # Returning False will cause exceptions to propagate
60 59 return False
61 60
62 61 def set(self):
63 62 """Set the hook."""
64 63 if sys.displayhook is not self.hook:
65 64 self.old_hook = sys.displayhook
66 65 sys.displayhook = self.hook
67 66
68 67 def unset(self):
69 68 """Unset the hook."""
70 69 sys.displayhook = self.old_hook
71 70
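A short usage sketch for the context manager above (assuming the module is importable as IPython.core.display_trap, as the surrounding layout suggests):

import sys
from IPython.core.display_trap import DisplayTrap

def my_hook(obj):
    if obj is not None:
        sys.stdout.write('<<%r>>\n' % (obj,))

trap = DisplayTrap(hook=my_hook)
with trap:                                # set(): installs my_hook
    assert sys.displayhook is my_hook
assert sys.displayhook is not my_hook     # unset(): original hook restored
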
@@ -1,52 +1,51 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Global exception classes for IPython.core.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez
10 9
11 10 Notes
12 11 -----
13 12 """
14 13
15 14 #-----------------------------------------------------------------------------
16 15 # Copyright (C) 2008-2009 The IPython Development Team
17 16 #
18 17 # Distributed under the terms of the BSD License. The full license is in
19 18 # the file COPYING, distributed as part of this software.
20 19 #-----------------------------------------------------------------------------
21 20
22 21 #-----------------------------------------------------------------------------
23 22 # Imports
24 23 #-----------------------------------------------------------------------------
25 24
26 25 #-----------------------------------------------------------------------------
27 26 # Exception classes
28 27 #-----------------------------------------------------------------------------
29 28
30 29 class IPythonCoreError(Exception):
31 30 pass
32 31
33 32
34 33 class TryNext(IPythonCoreError):
35 34 """Try next hook exception.
36 35
37 36 Raise this in your hook function to indicate that the next hook handler
38 37 should be used to handle the operation. If you pass arguments to the
39 38 constructor those arguments will be used by the next hook instead of the
40 39 original ones.
41 40 """
42 41
43 42 def __init__(self, *args, **kwargs):
44 43 self.args = args
45 44 self.kwargs = kwargs
46 45
47 46 class UsageError(IPythonCoreError):
48 47 """Error in magic function arguments, etc.
49 48
50 49 Something that probably won't warrant a full traceback, but should
51 50 nevertheless interrupt a macro / batch file.
52 51 """
@@ -1,30 +1,29 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 This module is *completely* deprecated and should no longer be used for
5 4 any purpose. Currently, we have a few parts of the core that have
6 5 not been componentized and thus, still rely on this module. When everything
7 6 has been made into a component, this module will be sent to deathrow.
8 7 """
9 8
10 9 #-----------------------------------------------------------------------------
11 10 # Copyright (C) 2008-2009 The IPython Development Team
12 11 #
13 12 # Distributed under the terms of the BSD License. The full license is in
14 13 # the file COPYING, distributed as part of this software.
15 14 #-----------------------------------------------------------------------------
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Imports
19 18 #-----------------------------------------------------------------------------
20 19
21 20 #-----------------------------------------------------------------------------
22 21 # Classes and functions
23 22 #-----------------------------------------------------------------------------
24 23
25 24
26 25 def get():
27 26 """Get the global InteractiveShell instance."""
28 27 from IPython.core.interactiveshell import InteractiveShell
29 28 return InteractiveShell.instance()
30 29
@@ -1,327 +1,326 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Paging capabilities for IPython.core
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez
10 9
11 10 Notes
12 11 -----
13 12
14 13 For now this uses ipapi, so it can't be in IPython.utils. If we can get
15 14 rid of that dependency, we could move it there.
16 15 -----
17 16 """
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Copyright (C) 2008-2009 The IPython Development Team
21 20 #
22 21 # Distributed under the terms of the BSD License. The full license is in
23 22 # the file COPYING, distributed as part of this software.
24 23 #-----------------------------------------------------------------------------
25 24
26 25 #-----------------------------------------------------------------------------
27 26 # Imports
28 27 #-----------------------------------------------------------------------------
29 28
30 29 import os
31 30 import re
32 31 import sys
33 32 import tempfile
34 33
35 34 from IPython.core import ipapi
36 35 from IPython.core.error import TryNext
37 36 from IPython.utils.cursesimport import use_curses
38 37 from IPython.utils.data import chop
39 38 from IPython.utils import io
40 39 from IPython.utils.process import system
41 40 from IPython.utils.terminal import get_terminal_size
42 41
43 42
44 43 #-----------------------------------------------------------------------------
45 44 # Classes and functions
46 45 #-----------------------------------------------------------------------------
47 46
48 47 esc_re = re.compile(r"(\x1b[^m]+m)")
49 48
50 49 def page_dumb(strng, start=0, screen_lines=25):
51 50 """Very dumb 'pager' in Python, for when nothing else works.
52 51
53 52 Only moves forward, same interface as page(), except for pager_cmd and
54 53 mode."""
55 54
56 55 out_ln = strng.splitlines()[start:]
57 56 screens = chop(out_ln,screen_lines-1)
58 57 if len(screens) == 1:
59 58 print >>io.stdout, os.linesep.join(screens[0])
60 59 else:
61 60 last_escape = ""
62 61 for scr in screens[0:-1]:
63 62 hunk = os.linesep.join(scr)
64 63 print >>io.stdout, last_escape + hunk
65 64 if not page_more():
66 65 return
67 66 esc_list = esc_re.findall(hunk)
68 67 if len(esc_list) > 0:
69 68 last_escape = esc_list[-1]
70 69 print >>io.stdout, last_escape + os.linesep.join(screens[-1])
71 70
72 71
73 72 def page(strng, start=0, screen_lines=0, pager_cmd=None):
74 73 """Print a string, piping through a pager after a certain length.
75 74
76 75 The screen_lines parameter specifies the number of *usable* lines of your
77 76 terminal screen (total lines minus lines you need to reserve to show other
78 77 information).
79 78
80 79 If you set screen_lines to a number <=0, page() will try to auto-determine
81 80 your screen size and will only use up to (screen_size+screen_lines) for
82 81 printing, paging after that. That is, if you want auto-detection but need
83 82 to reserve the bottom 3 lines of the screen, use screen_lines = -3, and for
84 83 auto-detection without any lines reserved simply use screen_lines = 0.
85 84
86 85 If a string won't fit in the allowed lines, it is sent through the
87 86 specified pager command. If none given, look for PAGER in the environment,
88 87 and ultimately default to less.
89 88
90 89 If no system pager works, the string is sent through a 'dumb pager'
91 90 written in python, very simplistic.
92 91 """
93 92
94 93 # Some routines may auto-compute start offsets incorrectly and pass a
95 94 # negative value. Offset to 0 for robustness.
96 95 start = max(0, start)
97 96
98 97 # first, try the hook
99 98 ip = ipapi.get()
100 99 if ip:
101 100 try:
102 101 ip.hooks.show_in_pager(strng)
103 102 return
104 103 except TryNext:
105 104 pass
106 105
107 106 # Ugly kludge, but calling curses.initscr() flat out crashes in emacs
108 107 TERM = os.environ.get('TERM','dumb')
109 108 if TERM in ['dumb','emacs'] and os.name != 'nt':
110 109 print strng
111 110 return
112 111 # chop off the topmost part of the string we don't want to see
113 112 str_lines = strng.splitlines()[start:]
114 113 str_toprint = os.linesep.join(str_lines)
115 114 num_newlines = len(str_lines)
116 115 len_str = len(str_toprint)
117 116
118 117 # Dumb heuristics to guesstimate number of on-screen lines the string
119 118 # takes. Very basic, but good enough for docstrings in reasonable
120 119 # terminals. If someone later feels like refining it, it's not hard.
121 120 numlines = max(num_newlines,int(len_str/80)+1)
122 121
123 122 screen_lines_def = get_terminal_size()[1]
124 123
125 124 # auto-determine screen size
126 125 if screen_lines <= 0:
127 126 if (TERM=='xterm' or TERM=='xterm-color') and sys.platform != 'sunos5':
128 127 local_use_curses = use_curses
129 128 else:
130 129 # curses causes problems on many terminals other than xterm, and
131 130 # some termios calls lock up on Sun OS5.
132 131 local_use_curses = False
133 132 if local_use_curses:
134 133 import termios
135 134 import curses
136 135 # There is a bug in curses, where *sometimes* it fails to properly
137 136 # initialize, and then after the endwin() call is made, the
138 137 # terminal is left in an unusable state. Rather than trying to
139 138 # check every time for this (by requesting and comparing termios
140 139 # flags each time), we just save the initial terminal state and
141 140 # unconditionally reset it every time. It's cheaper than making
142 141 # the checks.
143 142 term_flags = termios.tcgetattr(sys.stdout)
144 143
145 144 # Curses modifies the stdout buffer size by default, which messes
146 145 # up Python's normal stdout buffering. This would manifest itself
147 146 # to IPython users as delayed printing on stdout after having used
148 147 # the pager.
149 148 #
150 149 # We can prevent this by manually setting the NCURSES_NO_SETBUF
151 150 # environment variable. For more details, see:
152 151 # http://bugs.python.org/issue10144
153 152 NCURSES_NO_SETBUF = os.environ.get('NCURSES_NO_SETBUF', None)
154 153 os.environ['NCURSES_NO_SETBUF'] = ''
155 154
156 155 # Proceed with curses initialization
157 156 scr = curses.initscr()
158 157 screen_lines_real,screen_cols = scr.getmaxyx()
159 158 curses.endwin()
160 159
161 160 # Restore environment
162 161 if NCURSES_NO_SETBUF is None:
163 162 del os.environ['NCURSES_NO_SETBUF']
164 163 else:
165 164 os.environ['NCURSES_NO_SETBUF'] = NCURSES_NO_SETBUF
166 165
167 166 # Restore terminal state in case endwin() didn't.
168 167 termios.tcsetattr(sys.stdout,termios.TCSANOW,term_flags)
169 168 # Now we have what we needed: the screen size in rows/columns
170 169 screen_lines += screen_lines_real
171 170 #print '***Screen size:',screen_lines_real,'lines x',\
172 171 #screen_cols,'columns.' # dbg
173 172 else:
174 173 screen_lines += screen_lines_def
175 174
176 175 #print 'numlines',numlines,'screenlines',screen_lines # dbg
177 176 if numlines <= screen_lines :
178 177 #print '*** normal print' # dbg
179 178 print >>io.stdout, str_toprint
180 179 else:
181 180 # Try to open pager and default to internal one if that fails.
182 181 # All failure modes are tagged as 'retval=1', to match the return
183 182 # value of a failed system command. If any intermediate attempt
184 183 # sets retval to 1, at the end we resort to our own page_dumb() pager.
185 184 pager_cmd = get_pager_cmd(pager_cmd)
186 185 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
187 186 if os.name == 'nt':
188 187 if pager_cmd.startswith('type'):
189 188 # The default WinXP 'type' command is failing on complex strings.
190 189 retval = 1
191 190 else:
192 191 tmpname = tempfile.mktemp('.txt')
193 192 tmpfile = file(tmpname,'wt')
194 193 tmpfile.write(strng)
195 194 tmpfile.close()
196 195 cmd = "%s < %s" % (pager_cmd,tmpname)
197 196 if os.system(cmd):
198 197 retval = 1
199 198 else:
200 199 retval = None
201 200 os.remove(tmpname)
202 201 else:
203 202 try:
204 203 retval = None
205 204 # if I use popen4, things hang. No idea why.
206 205 #pager,shell_out = os.popen4(pager_cmd)
207 206 pager = os.popen(pager_cmd,'w')
208 207 pager.write(strng)
209 208 pager.close()
210 209 retval = pager.close() # success returns None
211 210 except IOError,msg: # broken pipe when user quits
212 211 if msg.args == (32,'Broken pipe'):
213 212 retval = None
214 213 else:
215 214 retval = 1
216 215 except OSError:
217 216 # Other strange problems, sometimes seen in Win2k/cygwin
218 217 retval = 1
219 218 if retval is not None:
220 219 page_dumb(strng,screen_lines=screen_lines)
221 220
222 221
223 222 def page_file(fname, start=0, pager_cmd=None):
224 223 """Page a file, using an optional pager command and starting line.
225 224 """
226 225
227 226 pager_cmd = get_pager_cmd(pager_cmd)
228 227 pager_cmd += ' ' + get_pager_start(pager_cmd,start)
229 228
230 229 try:
231 230 if os.environ['TERM'] in ['emacs','dumb']:
232 231 raise EnvironmentError
233 232 system(pager_cmd + ' ' + fname)
234 233 except:
235 234 try:
236 235 if start > 0:
237 236 start -= 1
238 237 page(open(fname).read(),start)
239 238 except:
240 239 print 'Unable to show file',`fname`
241 240
242 241
243 242 def get_pager_cmd(pager_cmd=None):
244 243 """Return a pager command.
245 244
246 245 Makes some attempts at finding an OS-correct one.
247 246 """
248 247 if os.name == 'posix':
249 248 default_pager_cmd = 'less -r' # -r for color control sequences
250 249 elif os.name in ['nt','dos']:
251 250 default_pager_cmd = 'type'
252 251
253 252 if pager_cmd is None:
254 253 try:
255 254 pager_cmd = os.environ['PAGER']
256 255 except:
257 256 pager_cmd = default_pager_cmd
258 257 return pager_cmd
259 258
260 259
261 260 def get_pager_start(pager, start):
262 261 """Return the string for paging files with an offset.
263 262
264 263 This is the '+N' argument which less and more (under Unix) accept.
265 264 """
266 265
267 266 if pager in ['less','more']:
268 267 if start:
269 268 start_string = '+' + str(start)
270 269 else:
271 270 start_string = ''
272 271 else:
273 272 start_string = ''
274 273 return start_string
275 274
276 275
277 276 # (X)emacs on win32 doesn't like to be bypassed with msvcrt.getch()
278 277 if os.name == 'nt' and os.environ.get('TERM','dumb') != 'emacs':
279 278 import msvcrt
280 279 def page_more():
281 280 """ Smart pausing between pages
282 281
283 282 @return: True if more lines should be printed, False to quit
284 283 """
285 284 io.stdout.write('---Return to continue, q to quit--- ')
286 285 ans = msvcrt.getch()
287 286 if ans in ("q", "Q"):
288 287 result = False
289 288 else:
290 289 result = True
291 290 io.stdout.write("\b"*37 + " "*37 + "\b"*37)
292 291 return result
293 292 else:
294 293 def page_more():
295 294 ans = raw_input('---Return to continue, q to quit--- ')
296 295 if ans.lower().startswith('q'):
297 296 return False
298 297 else:
299 298 return True
300 299
301 300
302 301 def snip_print(str,width = 75,print_full = 0,header = ''):
303 302 """Print a string snipping the midsection to fit in width.
304 303
305 304 print_full: mode control:
306 305 - 0: only snip long strings
307 306 - 1: send to page() directly.
308 307 - 2: snip long strings and ask for full length viewing with page()
309 308 Return 1 if snipping was necessary, 0 otherwise."""
310 309
311 310 if print_full == 1:
312 311 page(header+str)
313 312 return 0
314 313
315 314 print header,
316 315 if len(str) < width:
317 316 print str
318 317 snip = 0
319 318 else:
320 319 whalf = int((width -5)/2)
321 320 print str[:whalf] + ' <...> ' + str[-whalf:]
322 321 snip = 1
323 322 if snip and print_full == 2:
324 323 if raw_input(header+' Snipped. View (y/n)? [N]').lower() == 'y':
325 324 page(str)
326 325 return snip
327 326
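A small illustration of the pager-command helpers defined above (a sketch; the result depends on $PAGER and the platform):

from IPython.core import page

cmd = page.get_pager_cmd()                 # $PAGER if set, else e.g. 'less -r'
start = page.get_pager_start('less', 10)   # '+10'; only less/more accept it
print cmd, start
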
@@ -1,97 +1,96 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 A payload based version of page.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2010 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 # Third-party
24 23 try:
25 24 from docutils.core import publish_string
26 25 except ImportError:
27 26 # html paging won't be available, but we don't raise any errors. It's a
28 27 # purely optional feature.
29 28 pass
30 29
31 30 # Our own
32 31 from IPython.core.interactiveshell import InteractiveShell
33 32
34 33 #-----------------------------------------------------------------------------
35 34 # Classes and functions
36 35 #-----------------------------------------------------------------------------
37 36
38 37 def page(strng, start=0, screen_lines=0, pager_cmd=None,
39 38 html=None, auto_html=False):
40 39 """Print a string, piping through a pager.
41 40
42 41 This version ignores the screen_lines and pager_cmd arguments and uses
43 42 IPython's payload system instead.
44 43
45 44 Parameters
46 45 ----------
47 46 strng : str
48 47 Text to page.
49 48
50 49 start : int
51 50 Starting line at which to place the display.
52 51
53 52 html : str, optional
54 53 If given, an html string to send as well.
55 54
56 55 auto_html : bool, optional
57 56 If true, the input string is assumed to be valid reStructuredText and is
58 57 converted to HTML with docutils. Note that if docutils is not found,
59 58 this option is silently ignored.
60 59
61 60 Note
62 61 ----
63 62
64 63 Only one of the ``html`` and ``auto_html`` options can be given, not
65 64 both.
66 65 """
67 66
68 67 # Some routines may auto-compute start offsets incorrectly and pass a
69 68 # negative value. Offset to 0 for robustness.
70 69 start = max(0, start)
71 70 shell = InteractiveShell.instance()
72 71
73 72 if auto_html:
74 73 try:
75 74 # These defaults ensure user configuration variables for docutils
76 75 # are not loaded, only our config is used here.
77 76 defaults = {'file_insertion_enabled': 0,
78 77 'raw_enabled': 0,
79 78 '_disable_config': 1}
80 79 html = publish_string(strng, writer_name='html',
81 80 settings_overrides=defaults)
82 81 except:
83 82 pass
84 83
85 84 payload = dict(
86 85 source='IPython.zmq.page.page',
87 86 text=strng,
88 87 html=html,
89 88 start_line_number=start
90 89 )
91 90 shell.payload_manager.write_payload(payload)
92 91
93 92
94 93 def install_payload_page():
95 94 """Install this version of page as IPython.core.page.page."""
96 95 from IPython.core import page as corepage
97 96 corepage.page = page
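
A rough usage sketch (not part of this changeset): assuming an IPython 0.11-era kernel session, where InteractiveShell.instance() and its payload_manager are available, the payload pager can be installed and exercised like this:

    # Swap the payload-based pager in for IPython.core.page.page and page
    # a long string; the attached frontend receives the text payload.
    from IPython.zmq.page import install_payload_page, page

    install_payload_page()
    page('A rather long help text...\n' * 200, start=0)
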
@@ -1,1013 +1,1012 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Prefiltering components.
5 4
6 5 Prefilters transform user input before it is exec'd by Python. These
7 6 transforms are used to implement additional syntax such as !ls and %magic.
8 7
9 8 Authors:
10 9
11 10 * Brian Granger
12 11 * Fernando Perez
13 12 * Dan Milstein
14 13 * Ville Vainio
15 14 """
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Copyright (C) 2008-2009 The IPython Development Team
19 18 #
20 19 # Distributed under the terms of the BSD License. The full license is in
21 20 # the file COPYING, distributed as part of this software.
22 21 #-----------------------------------------------------------------------------
23 22
24 23 #-----------------------------------------------------------------------------
25 24 # Imports
26 25 #-----------------------------------------------------------------------------
27 26
28 27 import __builtin__
29 28 import codeop
30 29 import re
31 30
32 31 from IPython.core.alias import AliasManager
33 32 from IPython.core.autocall import IPyAutocall
34 33 from IPython.config.configurable import Configurable
35 34 from IPython.core.macro import Macro
36 35 from IPython.core.splitinput import split_user_input
37 36 from IPython.core import page
38 37
39 38 from IPython.utils.traitlets import List, Int, Any, Unicode, CBool, Bool, Instance
40 39 from IPython.utils.text import make_quoted_expr
41 40 from IPython.utils.autoattr import auto_attr
42 41
43 42 #-----------------------------------------------------------------------------
44 43 # Global utilities, errors and constants
45 44 #-----------------------------------------------------------------------------
46 45
47 46 # Warning, these cannot be changed unless various regular expressions
48 47 # are updated in a number of places. Not great, but at least we told you.
49 48 ESC_SHELL = '!'
50 49 ESC_SH_CAP = '!!'
51 50 ESC_HELP = '?'
52 51 ESC_MAGIC = '%'
53 52 ESC_QUOTE = ','
54 53 ESC_QUOTE2 = ';'
55 54 ESC_PAREN = '/'
56 55
57 56
58 57 class PrefilterError(Exception):
59 58 pass
60 59
61 60
62 61 # RegExp to identify potential function names
63 62 re_fun_name = re.compile(r'[a-zA-Z_]([a-zA-Z0-9_.]*) *$')
64 63
65 64 # RegExp to exclude strings with this start from autocalling. In
66 65 # particular, all binary operators should be excluded, so that if foo is
67 66 # callable, foo OP bar doesn't become foo(OP bar), which is invalid. The
68 67 # characters '!=()' don't need to be checked for, as the checkPythonChars
69 68 # routine explicitly does so, to catch direct calls and rebindings of
70 69 # existing names.
71 70
72 71 # Warning: the '-' HAS TO BE AT THE END of the first group, otherwise
73 72 # it affects the rest of the group in square brackets.
74 73 re_exclude_auto = re.compile(r'^[,&^\|\*/\+-]'
75 74 r'|^is |^not |^in |^and |^or ')
76 75
77 76 # try to catch also methods for stuff in lists/tuples/dicts: off
78 77 # (experimental). For this to work, the line_split regexp would need
79 78 # to be modified so it wouldn't break things at '['. That line is
80 79 # nasty enough that I shouldn't change it until I can test it _well_.
81 80 #self.re_fun_name = re.compile (r'[a-zA-Z_]([a-zA-Z0-9_.\[\]]*) ?$')
82 81
83 82
84 83 # Handler Check Utilities
85 84 def is_shadowed(identifier, ip):
86 85 """Is the given identifier defined in one of the namespaces which shadow
87 86 the alias and magic namespaces? Note that an identifier is different
88 87 from an ifun, because it cannot contain a '.' character.
89 88 # This is much safer than calling ofind, which can change state
90 89 return (identifier in ip.user_ns \
91 90 or identifier in ip.internal_ns \
92 91 or identifier in ip.ns_table['builtin'])
93 92
94 93
95 94 #-----------------------------------------------------------------------------
96 95 # The LineInfo class used throughout
97 96 #-----------------------------------------------------------------------------
98 97
99 98
100 99 class LineInfo(object):
101 100 """A single line of input and associated info.
102 101
103 102 Includes the following as properties:
104 103
105 104 line
106 105 The original, raw line
107 106
108 107 continue_prompt
109 108 Is this line a continuation in a sequence of multiline input?
110 109
111 110 pre
112 111 The initial esc character or whitespace.
113 112
114 113 pre_char
115 114 The escape character(s) in pre or the empty string if there isn't one.
116 115 Note that '!!' is a possible value for pre_char. Otherwise it will
117 116 always be a single character.
118 117
119 118 pre_whitespace
120 119 The leading whitespace from pre if it exists. If there is a pre_char,
121 120 this is just ''.
122 121
123 122 ifun
124 123 The 'function part', which is basically the maximal initial sequence
125 124 of valid python identifiers and the '.' character. This is what is
126 125 checked for alias and magic transformations, used for auto-calling,
127 126 etc.
128 127
129 128 the_rest
130 129 Everything else on the line.
131 130 """
132 131 def __init__(self, line, continue_prompt):
133 132 self.line = line
134 133 self.continue_prompt = continue_prompt
135 134 self.pre, self.ifun, self.the_rest = split_user_input(line)
136 135
137 136 self.pre_char = self.pre.strip()
138 137 if self.pre_char:
139 138 self.pre_whitespace = '' # No whitespace allowed before esc chars
140 139 else:
141 140 self.pre_whitespace = self.pre
142 141
143 142 self._oinfo = None
144 143
145 144 def ofind(self, ip):
146 145 """Do a full, attribute-walking lookup of the ifun in the various
147 146 namespaces for the given IPython InteractiveShell instance.
148 147
149 148 Return a dict with keys: found,obj,ospace,ismagic
150 149
151 150 Note: can cause state changes because of calling getattr, but should
152 151 only be run if autocall is on and if the line hasn't matched any
153 152 other, less dangerous handlers.
154 153
155 154 Does cache the results of the call, so can be called multiple times
156 155 without worrying about *further* damaging state.
157 156 """
158 157 if not self._oinfo:
159 158 # ip.shell._ofind is actually on the Magic class!
160 159 self._oinfo = ip.shell._ofind(self.ifun)
161 160 return self._oinfo
162 161
163 162 def __str__(self):
164 163 return "Lineinfo [%s|%s|%s]" %(self.pre, self.ifun, self.the_rest)
165 164
166 165
167 166 #-----------------------------------------------------------------------------
168 167 # Main Prefilter manager
169 168 #-----------------------------------------------------------------------------
170 169
171 170
172 171 class PrefilterManager(Configurable):
173 172 """Main prefilter component.
174 173
175 174 The IPython prefilter is run on all user input before it is executed. The
176 175 prefilter consumes lines of input and produces transformed lines of
177 176 input.
178 177
179 178 The implementation consists of two phases:
180 179
181 180 1. Transformers
182 181 2. Checkers and handlers
183 182
184 183 Over time, we plan on deprecating the checkers and handlers and doing
185 184 everything in the transformers.
186 185
187 186 The transformers are instances of :class:`PrefilterTransformer` and have
188 187 a single method :meth:`transform` that takes a line and returns a
189 188 transformed line. The transformation can be accomplished using any
190 189 tool, but our current ones use regular expressions for speed. We also
191 190 ship :mod:`pyparsing` in :mod:`IPython.external` for use in transformers.
192 191
193 192 After all the transformers have been run, the line is fed to the checkers,
194 193 which are instances of :class:`PrefilterChecker`. The line is passed to
195 194 the :meth:`check` method, which either returns `None` or a
196 195 :class:`PrefilterHandler` instance. If `None` is returned, the other
197 196 checkers are tried. If a :class:`PrefilterHandler` instance is returned,
198 197 the line is passed to the :meth:`handle` method of the returned
199 198 handler and no further checkers are tried.
200 199
201 200 Both transformers and checkers have a `priority` attribute that determines
202 201 the order in which they are called. Smaller priorities are tried first.
203 202
204 203 Both transformers and checkers also have an `enabled` attribute, which is
205 204 a boolean that determines if the instance is used.
206 205
207 206 Users or developers can change the priority or enabled attribute of
208 207 transformers or checkers, but they must call the :meth:`sort_checkers`
209 208 or :meth:`sort_transformers` method after changing the priority.
210 209 """
211 210
212 211 multi_line_specials = CBool(True, config=True)
213 212 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
214 213
215 214 def __init__(self, shell=None, config=None):
216 215 super(PrefilterManager, self).__init__(shell=shell, config=config)
217 216 self.shell = shell
218 217 self.init_transformers()
219 218 self.init_handlers()
220 219 self.init_checkers()
221 220
222 221 #-------------------------------------------------------------------------
223 222 # API for managing transformers
224 223 #-------------------------------------------------------------------------
225 224
226 225 def init_transformers(self):
227 226 """Create the default transformers."""
228 227 self._transformers = []
229 228 for transformer_cls in _default_transformers:
230 229 transformer_cls(
231 230 shell=self.shell, prefilter_manager=self, config=self.config
232 231 )
233 232
234 233 def sort_transformers(self):
235 234 """Sort the transformers by priority.
236 235
237 236 This must be called after the priority of a transformer is changed.
238 237 The :meth:`register_transformer` method calls this automatically.
239 238 """
240 239 self._transformers.sort(key=lambda x: x.priority)
241 240
242 241 @property
243 242 def transformers(self):
244 243 """Return a list of checkers, sorted by priority."""
245 244 return self._transformers
246 245
247 246 def register_transformer(self, transformer):
248 247 """Register a transformer instance."""
249 248 if transformer not in self._transformers:
250 249 self._transformers.append(transformer)
251 250 self.sort_transformers()
252 251
253 252 def unregister_transformer(self, transformer):
254 253 """Unregister a transformer instance."""
255 254 if transformer in self._transformers:
256 255 self._transformers.remove(transformer)
257 256
258 257 #-------------------------------------------------------------------------
259 258 # API for managing checkers
260 259 #-------------------------------------------------------------------------
261 260
262 261 def init_checkers(self):
263 262 """Create the default checkers."""
264 263 self._checkers = []
265 264 for checker in _default_checkers:
266 265 checker(
267 266 shell=self.shell, prefilter_manager=self, config=self.config
268 267 )
269 268
270 269 def sort_checkers(self):
271 270 """Sort the checkers by priority.
272 271
273 272 This must be called after the priority of a checker is changed.
274 273 The :meth:`register_checker` method calls this automatically.
275 274 """
276 275 self._checkers.sort(key=lambda x: x.priority)
277 276
278 277 @property
279 278 def checkers(self):
280 279 """Return a list of checkers, sorted by priority."""
281 280 return self._checkers
282 281
283 282 def register_checker(self, checker):
284 283 """Register a checker instance."""
285 284 if checker not in self._checkers:
286 285 self._checkers.append(checker)
287 286 self.sort_checkers()
288 287
289 288 def unregister_checker(self, checker):
290 289 """Unregister a checker instance."""
291 290 if checker in self._checkers:
292 291 self._checkers.remove(checker)
293 292
294 293 #-------------------------------------------------------------------------
295 294 # API for managing handlers
296 295 #-------------------------------------------------------------------------
297 296
298 297 def init_handlers(self):
299 298 """Create the default handlers."""
300 299 self._handlers = {}
301 300 self._esc_handlers = {}
302 301 for handler in _default_handlers:
303 302 handler(
304 303 shell=self.shell, prefilter_manager=self, config=self.config
305 304 )
306 305
307 306 @property
308 307 def handlers(self):
309 308 """Return a dict of all the handlers."""
310 309 return self._handlers
311 310
312 311 def register_handler(self, name, handler, esc_strings):
313 312 """Register a handler instance by name with esc_strings."""
314 313 self._handlers[name] = handler
315 314 for esc_str in esc_strings:
316 315 self._esc_handlers[esc_str] = handler
317 316
318 317 def unregister_handler(self, name, handler, esc_strings):
319 318 """Unregister a handler instance by name with esc_strings."""
320 319 try:
321 320 del self._handlers[name]
322 321 except KeyError:
323 322 pass
324 323 for esc_str in esc_strings:
325 324 h = self._esc_handlers.get(esc_str)
326 325 if h is handler:
327 326 del self._esc_handlers[esc_str]
328 327
329 328 def get_handler_by_name(self, name):
330 329 """Get a handler by its name."""
331 330 return self._handlers.get(name)
332 331
333 332 def get_handler_by_esc(self, esc_str):
334 333 """Get a handler by its escape string."""
335 334 return self._esc_handlers.get(esc_str)
336 335
337 336 #-------------------------------------------------------------------------
338 337 # Main prefiltering API
339 338 #-------------------------------------------------------------------------
340 339
341 340 def prefilter_line_info(self, line_info):
342 341 """Prefilter a line that has been converted to a LineInfo object.
343 342
344 343 This implements the checker/handler part of the prefilter pipe.
345 344 """
346 345 # print "prefilter_line_info: ", line_info
347 346 handler = self.find_handler(line_info)
348 347 return handler.handle(line_info)
349 348
350 349 def find_handler(self, line_info):
351 350 """Find a handler for the line_info by trying checkers."""
352 351 for checker in self.checkers:
353 352 if checker.enabled:
354 353 handler = checker.check(line_info)
355 354 if handler:
356 355 return handler
357 356 return self.get_handler_by_name('normal')
358 357
359 358 def transform_line(self, line, continue_prompt):
360 359 """Calls the enabled transformers in order of increasing priority."""
361 360 for transformer in self.transformers:
362 361 if transformer.enabled:
363 362 line = transformer.transform(line, continue_prompt)
364 363 return line
365 364
366 365 def prefilter_line(self, line, continue_prompt=False):
367 366 """Prefilter a single input line as text.
368 367
369 368 This method prefilters a single line of text by calling the
370 369 transformers and then the checkers/handlers.
371 370 """
372 371
373 372 # print "prefilter_line: ", line, continue_prompt
374 373 # All handlers *must* return a value, even if it's blank ('').
375 374
376 375 # save the line away in case we crash, so the post-mortem handler can
377 376 # record it
378 377 self.shell._last_input_line = line
379 378
380 379 if not line:
381 380 # Return immediately on purely empty lines, so that if the user
382 381 # previously typed some whitespace that started a continuation
383 382 # prompt, he can break out of that loop with just an empty line.
384 383 # This is how the default python prompt works.
385 384 return ''
386 385
387 386 # At this point, we invoke our transformers.
388 387 if not continue_prompt or (continue_prompt and self.multi_line_specials):
389 388 line = self.transform_line(line, continue_prompt)
390 389
391 390 # Now we compute line_info for the checkers and handlers
392 391 line_info = LineInfo(line, continue_prompt)
393 392
394 393 # the input history needs to track even empty lines
395 394 stripped = line.strip()
396 395
397 396 normal_handler = self.get_handler_by_name('normal')
398 397 if not stripped:
399 398 if not continue_prompt:
400 399 self.shell.displayhook.prompt_count -= 1
401 400
402 401 return normal_handler.handle(line_info)
403 402
404 403 # special handlers are only allowed for single line statements
405 404 if continue_prompt and not self.multi_line_specials:
406 405 return normal_handler.handle(line_info)
407 406
408 407 prefiltered = self.prefilter_line_info(line_info)
409 408 # print "prefiltered line: %r" % prefiltered
410 409 return prefiltered
411 410
412 411 def prefilter_lines(self, lines, continue_prompt=False):
413 412 """Prefilter multiple input lines of text.
414 413
415 414 This is the main entry point for prefiltering multiple lines of
416 415 input. This simply calls :meth:`prefilter_line` for each line of
417 416 input.
418 417
419 418 This covers cases where there are multiple lines in the user entry,
420 419 which is the case when the user goes back to a multiline history
421 420 entry and presses enter.
422 421 """
423 422 llines = lines.rstrip('\n').split('\n')
424 423 # We can get multiple lines in one shot, where multiline input 'blends'
425 424 # into one line, in cases like recalling from the readline history
426 425 # buffer. We need to make sure that in such cases, we correctly
427 426 # communicate downstream which line is first and which are continuation
428 427 # ones.
429 428 if len(llines) > 1:
430 429 out = '\n'.join([self.prefilter_line(line, lnum>0)
431 430 for lnum, line in enumerate(llines) ])
432 431 else:
433 432 out = self.prefilter_line(llines[0], continue_prompt)
434 433
435 434 return out
436 435
437 436 #-----------------------------------------------------------------------------
438 437 # Prefilter transformers
439 438 #-----------------------------------------------------------------------------
440 439
441 440
442 441 class PrefilterTransformer(Configurable):
443 442 """Transform a line of user input."""
444 443
445 444 priority = Int(100, config=True)
446 445 # Transformers don't currently use shell or prefilter_manager, but as we
447 446 # move away from checkers and handlers, they will need them.
448 447 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
449 448 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
450 449 enabled = Bool(True, config=True)
451 450
452 451 def __init__(self, shell=None, prefilter_manager=None, config=None):
453 452 super(PrefilterTransformer, self).__init__(
454 453 shell=shell, prefilter_manager=prefilter_manager, config=config
455 454 )
456 455 self.prefilter_manager.register_transformer(self)
457 456
458 457 def transform(self, line, continue_prompt):
459 458 """Transform a line, returning the new one."""
460 459 return None
461 460
462 461 def __repr__(self):
463 462 return "<%s(priority=%r, enabled=%r)>" % (
464 463 self.__class__.__name__, self.priority, self.enabled)
465 464
466 465
467 466 _assign_system_re = re.compile(r'(?P<lhs>(\s*)([\w\.]+)((\s*,\s*[\w\.]+)*))'
468 467 r'\s*=\s*!(?P<cmd>.*)')
469 468
470 469
471 470 class AssignSystemTransformer(PrefilterTransformer):
472 471 """Handle the `files = !ls` syntax."""
473 472
474 473 priority = Int(100, config=True)
475 474
476 475 def transform(self, line, continue_prompt):
477 476 m = _assign_system_re.match(line)
478 477 if m is not None:
479 478 cmd = m.group('cmd')
480 479 lhs = m.group('lhs')
481 480 expr = make_quoted_expr("sc =%s" % cmd)
482 481 new_line = '%s = get_ipython().magic(%s)' % (lhs, expr)
483 482 return new_line
484 483 return line
485 484
486 485
487 486 _assign_magic_re = re.compile(r'(?P<lhs>(\s*)([\w\.]+)((\s*,\s*[\w\.]+)*))'
488 487 r'\s*=\s*%(?P<cmd>.*)')
489 488
490 489 class AssignMagicTransformer(PrefilterTransformer):
491 490 """Handle the `a = %who` syntax."""
492 491
493 492 priority = Int(200, config=True)
494 493
495 494 def transform(self, line, continue_prompt):
496 495 m = _assign_magic_re.match(line)
497 496 if m is not None:
498 497 cmd = m.group('cmd')
499 498 lhs = m.group('lhs')
500 499 expr = make_quoted_expr(cmd)
501 500 new_line = '%s = get_ipython().magic(%s)' % (lhs, expr)
502 501 return new_line
503 502 return line
504 503
505 504
506 505 _classic_prompt_re = re.compile(r'(^[ \t]*>>> |^[ \t]*\.\.\. )')
507 506
508 507 class PyPromptTransformer(PrefilterTransformer):
509 508 """Handle inputs that start with '>>> ' syntax."""
510 509
511 510 priority = Int(50, config=True)
512 511
513 512 def transform(self, line, continue_prompt):
514 513
515 514 if not line or line.isspace() or line.strip() == '...':
516 515 # This allows us to recognize multiple input prompts separated by
517 516 # blank lines and pasted in a single chunk, very common when
518 517 # pasting doctests or long tutorial passages.
519 518 return ''
520 519 m = _classic_prompt_re.match(line)
521 520 if m:
522 521 return line[len(m.group(0)):]
523 522 else:
524 523 return line
525 524
526 525
527 526 _ipy_prompt_re = re.compile(r'(^[ \t]*In \[\d+\]: |^[ \t]*\ \ \ \.\.\.+: )')
528 527
529 528 class IPyPromptTransformer(PrefilterTransformer):
530 529 """Handle inputs that start classic IPython prompt syntax."""
531 530
532 531 priority = Int(50, config=True)
533 532
534 533 def transform(self, line, continue_prompt):
535 534
536 535 if not line or line.isspace() or line.strip() == '...':
537 536 # This allows us to recognize multiple input prompts separated by
538 537 # blank lines and pasted in a single chunk, very common when
539 538 # pasting doctests or long tutorial passages.
540 539 return ''
541 540 m = _ipy_prompt_re.match(line)
542 541 if m:
543 542 return line[len(m.group(0)):]
544 543 else:
545 544 return line
546 545
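
For illustration only, a minimal sketch of a new transformer following the base class above; the DebugPrintTransformer name is made up, and the commented registration assumes the running shell exposes its manager as ``prefilter_manager``:

    class DebugPrintTransformer(PrefilterTransformer):
        """Hypothetical transformer: rewrite lines starting with 'dp ' into prints."""

        priority = Int(150, config=True)

        def transform(self, line, continue_prompt):
            if not continue_prompt and line.startswith('dp '):
                return 'print repr(%s)' % line[3:]
            return line

    # Instantiating the class registers it (see PrefilterTransformer.__init__):
    # shell = get_ipython()
    # DebugPrintTransformer(shell=shell,
    #                       prefilter_manager=shell.prefilter_manager,
    #                       config=shell.config)
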
547 546 #-----------------------------------------------------------------------------
548 547 # Prefilter checkers
549 548 #-----------------------------------------------------------------------------
550 549
551 550
552 551 class PrefilterChecker(Configurable):
553 552 """Inspect an input line and return a handler for that line."""
554 553
555 554 priority = Int(100, config=True)
556 555 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
557 556 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
558 557 enabled = Bool(True, config=True)
559 558
560 559 def __init__(self, shell=None, prefilter_manager=None, config=None):
561 560 super(PrefilterChecker, self).__init__(
562 561 shell=shell, prefilter_manager=prefilter_manager, config=config
563 562 )
564 563 self.prefilter_manager.register_checker(self)
565 564
566 565 def check(self, line_info):
567 566 """Inspect line_info and return a handler instance or None."""
568 567 return None
569 568
570 569 def __repr__(self):
571 570 return "<%s(priority=%r, enabled=%r)>" % (
572 571 self.__class__.__name__, self.priority, self.enabled)
573 572
574 573
575 574 class EmacsChecker(PrefilterChecker):
576 575
577 576 priority = Int(100, config=True)
578 577 enabled = Bool(False, config=True)
579 578
580 579 def check(self, line_info):
581 580 "Emacs ipython-mode tags certain input lines."
582 581 if line_info.line.endswith('# PYTHON-MODE'):
583 582 return self.prefilter_manager.get_handler_by_name('emacs')
584 583 else:
585 584 return None
586 585
587 586
588 587 class ShellEscapeChecker(PrefilterChecker):
589 588
590 589 priority = Int(200, config=True)
591 590
592 591 def check(self, line_info):
593 592 if line_info.line.lstrip().startswith(ESC_SHELL):
594 593 return self.prefilter_manager.get_handler_by_name('shell')
595 594
596 595
597 596 class MacroChecker(PrefilterChecker):
598 597
599 598 priority = Int(250, config=True)
600 599
601 600 def check(self, line_info):
602 601 obj = self.shell.user_ns.get(line_info.ifun)
603 602 if isinstance(obj, Macro):
604 603 return self.prefilter_manager.get_handler_by_name('macro')
605 604 else:
606 605 return None
607 606
608 607
609 608 class IPyAutocallChecker(PrefilterChecker):
610 609
611 610 priority = Int(300, config=True)
612 611
613 612 def check(self, line_info):
614 613 "Instances of IPyAutocall in user_ns get autocalled immediately"
615 614 obj = self.shell.user_ns.get(line_info.ifun, None)
616 615 if isinstance(obj, IPyAutocall):
617 616 obj.set_ip(self.shell)
618 617 return self.prefilter_manager.get_handler_by_name('auto')
619 618 else:
620 619 return None
621 620
622 621
623 622 class MultiLineMagicChecker(PrefilterChecker):
624 623
625 624 priority = Int(400, config=True)
626 625
627 626 def check(self, line_info):
628 627 "Allow ! and !! in multi-line statements if multi_line_specials is on"
629 628 # Note that this is one of the only places we check the first character of
630 629 # ifun and *not* the pre_char. Also note that the below test matches
631 630 # both ! and !!.
632 631 if line_info.continue_prompt \
633 632 and self.prefilter_manager.multi_line_specials:
634 633 if line_info.ifun.startswith(ESC_MAGIC):
635 634 return self.prefilter_manager.get_handler_by_name('magic')
636 635 else:
637 636 return None
638 637
639 638
640 639 class EscCharsChecker(PrefilterChecker):
641 640
642 641 priority = Int(500, config=True)
643 642
644 643 def check(self, line_info):
645 644 """Check for escape character and return either a handler to handle it,
646 645 or None if there is no escape char."""
647 646 if line_info.line[-1] == ESC_HELP \
648 647 and line_info.pre_char != ESC_SHELL \
649 648 and line_info.pre_char != ESC_SH_CAP:
650 649 # the ? can be at the end, but *not* for either kind of shell escape,
651 650 # because a ? can be a valid final char in a shell cmd
652 651 return self.prefilter_manager.get_handler_by_name('help')
653 652 else:
654 653 # This returns None like it should if no handler exists
655 654 return self.prefilter_manager.get_handler_by_esc(line_info.pre_char)
656 655
657 656
658 657 class AssignmentChecker(PrefilterChecker):
659 658
660 659 priority = Int(600, config=True)
661 660
662 661 def check(self, line_info):
663 662 """Check to see if user is assigning to a var for the first time, in
664 663 which case we want to avoid any sort of automagic / autocall games.
665 664
666 665 This allows users to assign true python variables to names that would
667 666 otherwise be treated as aliases or magics (the magic/alias systems always
668 667 take second seat to true python code). E.g. ls='hi', or ls,that=1,2
669 668 if line_info.the_rest:
670 669 if line_info.the_rest[0] in '=,':
671 670 return self.prefilter_manager.get_handler_by_name('normal')
672 671 else:
673 672 return None
674 673
675 674
676 675 class AutoMagicChecker(PrefilterChecker):
677 676
678 677 priority = Int(700, config=True)
679 678
680 679 def check(self, line_info):
681 680 """If the ifun is magic, and automagic is on, run it. Note: normal,
682 681 non-auto magic would already have been triggered via '%' in
683 682 check_esc_chars. This just checks for automagic. Also, before
684 683 triggering the magic handler, make sure that there is nothing in the
685 684 user namespace which could shadow it."""
686 685 if not self.shell.automagic or not hasattr(self.shell,'magic_'+line_info.ifun):
687 686 return None
688 687
689 688 # We have a likely magic method. Make sure we should actually call it.
690 689 if line_info.continue_prompt and not self.prefilter_manager.multi_line_specials:
691 690 return None
692 691
693 692 head = line_info.ifun.split('.',1)[0]
694 693 if is_shadowed(head, self.shell):
695 694 return None
696 695
697 696 return self.prefilter_manager.get_handler_by_name('magic')
698 697
699 698
700 699 class AliasChecker(PrefilterChecker):
701 700
702 701 priority = Int(800, config=True)
703 702
704 703 def check(self, line_info):
705 704 "Check if the initital identifier on the line is an alias."
706 705 # Note: aliases can not contain '.'
707 706 head = line_info.ifun.split('.',1)[0]
708 707 if line_info.ifun not in self.shell.alias_manager \
709 708 or head not in self.shell.alias_manager \
710 709 or is_shadowed(head, self.shell):
711 710 return None
712 711
713 712 return self.prefilter_manager.get_handler_by_name('alias')
714 713
715 714
716 715 class PythonOpsChecker(PrefilterChecker):
717 716
718 717 priority = Int(900, config=True)
719 718
720 719 def check(self, line_info):
721 720 """If the 'rest' of the line begins with a function call or pretty much
722 721 any python operator, we should simply execute the line (regardless of
723 722 whether or not there's a possible autocall expansion). This avoids
724 723 spurious (and very confusing) getattr() accesses."""
725 724 if line_info.the_rest and line_info.the_rest[0] in '!=()<>,+*/%^&|':
726 725 return self.prefilter_manager.get_handler_by_name('normal')
727 726 else:
728 727 return None
729 728
730 729
731 730 class AutocallChecker(PrefilterChecker):
732 731
733 732 priority = Int(1000, config=True)
734 733
735 734 def check(self, line_info):
736 735 "Check if the initial word/function is callable and autocall is on."
737 736 if not self.shell.autocall:
738 737 return None
739 738
740 739 oinfo = line_info.ofind(self.shell) # This can mutate state via getattr
741 740 if not oinfo['found']:
742 741 return None
743 742
744 743 if callable(oinfo['obj']) \
745 744 and (not re_exclude_auto.match(line_info.the_rest)) \
746 745 and re_fun_name.match(line_info.ifun):
747 746 return self.prefilter_manager.get_handler_by_name('auto')
748 747 else:
749 748 return None
750 749
751 750
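
Likewise, a hedged sketch of a custom checker (the TodoChecker name is made up): returning a handler stops the chain, while returning None lets the remaining checkers (larger priority numbers) have a look:

    class TodoChecker(PrefilterChecker):
        """Hypothetical checker: route lines starting with 'todo:' to a handler."""

        priority = Int(150, config=True)

        def check(self, line_info):
            if line_info.line.lstrip().startswith('todo:'):
                return self.prefilter_manager.get_handler_by_name('normal')
            return None
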
752 751 #-----------------------------------------------------------------------------
753 752 # Prefilter handlers
754 753 #-----------------------------------------------------------------------------
755 754
756 755
757 756 class PrefilterHandler(Configurable):
758 757
759 758 handler_name = Unicode('normal')
760 759 esc_strings = List([])
761 760 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
762 761 prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
763 762
764 763 def __init__(self, shell=None, prefilter_manager=None, config=None):
765 764 super(PrefilterHandler, self).__init__(
766 765 shell=shell, prefilter_manager=prefilter_manager, config=config
767 766 )
768 767 self.prefilter_manager.register_handler(
769 768 self.handler_name,
770 769 self,
771 770 self.esc_strings
772 771 )
773 772
774 773 def handle(self, line_info):
775 774 # print "normal: ", line_info
776 775 """Handle normal input lines. Use as a template for handlers."""
777 776
778 777 # With autoindent on, we need some way to exit the input loop, and I
779 778 # don't want to force the user to have to backspace all the way to
780 779 # clear the line. The rule will be in this case, that either two
781 780 # lines of pure whitespace in a row, or a line of pure whitespace but
782 781 # of a size different to the indent level, will exit the input loop.
783 782 line = line_info.line
784 783 continue_prompt = line_info.continue_prompt
785 784
786 785 if (continue_prompt and
787 786 self.shell.autoindent and
788 787 line.isspace() and
789 788 0 < abs(len(line) - self.shell.indent_current_nsp) <= 2):
790 789 line = ''
791 790
792 791 return line
793 792
794 793 def __str__(self):
795 794 return "<%s(name=%s)>" % (self.__class__.__name__, self.handler_name)
796 795
797 796
798 797 class AliasHandler(PrefilterHandler):
799 798
800 799 handler_name = Unicode('alias')
801 800
802 801 def handle(self, line_info):
803 802 """Handle alias input lines. """
804 803 transformed = self.shell.alias_manager.expand_aliases(line_info.ifun,line_info.the_rest)
805 804 # pre is needed, because it carries the leading whitespace. Otherwise
806 805 # aliases won't work in indented sections.
807 806 line_out = '%sget_ipython().system(%s)' % (line_info.pre_whitespace,
808 807 make_quoted_expr(transformed))
809 808
810 809 return line_out
811 810
812 811
813 812 class ShellEscapeHandler(PrefilterHandler):
814 813
815 814 handler_name = Unicode('shell')
816 815 esc_strings = List([ESC_SHELL, ESC_SH_CAP])
817 816
818 817 def handle(self, line_info):
819 818 """Execute the line in a shell, empty return value"""
820 819 magic_handler = self.prefilter_manager.get_handler_by_name('magic')
821 820
822 821 line = line_info.line
823 822 if line.lstrip().startswith(ESC_SH_CAP):
824 823 # rewrite LineInfo's line, ifun and the_rest to properly hold the
825 824 # call to %sx and the actual command to be executed, so
826 825 # handle_magic can work correctly. Note that this works even if
827 826 # the line is indented, so it handles multi_line_specials
828 827 # properly.
829 828 new_rest = line.lstrip()[2:]
830 829 line_info.line = '%ssx %s' % (ESC_MAGIC, new_rest)
831 830 line_info.ifun = 'sx'
832 831 line_info.the_rest = new_rest
833 832 return magic_handler.handle(line_info)
834 833 else:
835 834 cmd = line.lstrip().lstrip(ESC_SHELL)
836 835 line_out = '%sget_ipython().system(%s)' % (line_info.pre_whitespace,
837 836 make_quoted_expr(cmd))
838 837 return line_out
839 838
840 839
841 840 class MacroHandler(PrefilterHandler):
842 841 handler_name = Unicode("macro")
843 842
844 843 def handle(self, line_info):
845 844 obj = self.shell.user_ns.get(line_info.ifun)
846 845 pre_space = line_info.pre_whitespace
847 846 line_sep = "\n" + pre_space
848 847 return pre_space + line_sep.join(obj.value.splitlines())
849 848
850 849
851 850 class MagicHandler(PrefilterHandler):
852 851
853 852 handler_name = Unicode('magic')
854 853 esc_strings = List([ESC_MAGIC])
855 854
856 855 def handle(self, line_info):
857 856 """Execute magic functions."""
858 857 ifun = line_info.ifun
859 858 the_rest = line_info.the_rest
860 859 cmd = '%sget_ipython().magic(%s)' % (line_info.pre_whitespace,
861 860 make_quoted_expr(ifun + " " + the_rest))
862 861 return cmd
863 862
864 863
865 864 class AutoHandler(PrefilterHandler):
866 865
867 866 handler_name = Unicode('auto')
868 867 esc_strings = List([ESC_PAREN, ESC_QUOTE, ESC_QUOTE2])
869 868
870 869 def handle(self, line_info):
871 870 """Handle lines which can be auto-executed, quoting if requested."""
872 871 line = line_info.line
873 872 ifun = line_info.ifun
874 873 the_rest = line_info.the_rest
875 874 pre = line_info.pre
876 875 continue_prompt = line_info.continue_prompt
877 876 obj = line_info.ofind(self)['obj']
878 877 #print 'pre <%s> ifun <%s> rest <%s>' % (pre,ifun,the_rest) # dbg
879 878
880 879 # This should only be active for single-line input!
881 880 if continue_prompt:
882 881 return line
883 882
884 883 force_auto = isinstance(obj, IPyAutocall)
885 884 auto_rewrite = getattr(obj, 'rewrite', True)
886 885
887 886 if pre == ESC_QUOTE:
888 887 # Auto-quote splitting on whitespace
889 888 newcmd = '%s("%s")' % (ifun,'", "'.join(the_rest.split()) )
890 889 elif pre == ESC_QUOTE2:
891 890 # Auto-quote whole string
892 891 newcmd = '%s("%s")' % (ifun,the_rest)
893 892 elif pre == ESC_PAREN:
894 893 newcmd = '%s(%s)' % (ifun,",".join(the_rest.split()))
895 894 else:
896 895 # Auto-paren.
897 896 # We only apply it to argument-less calls if the autocall
898 897 # parameter is set to 2. We only need to check that autocall is <
899 898 # 2, since this function isn't called unless it's at least 1.
900 899 if not the_rest and (self.shell.autocall < 2) and not force_auto:
901 900 newcmd = '%s %s' % (ifun,the_rest)
902 901 auto_rewrite = False
903 902 else:
904 903 if not force_auto and the_rest.startswith('['):
905 904 if hasattr(obj,'__getitem__'):
906 905 # Don't autocall in this case: item access for an object
907 906 # which is BOTH callable and implements __getitem__.
908 907 newcmd = '%s %s' % (ifun,the_rest)
909 908 auto_rewrite = False
910 909 else:
911 910 # if the object doesn't support [] access, go ahead and
912 911 # autocall
913 912 newcmd = '%s(%s)' % (ifun.rstrip(),the_rest)
914 913 elif the_rest.endswith(';'):
915 914 newcmd = '%s(%s);' % (ifun.rstrip(),the_rest[:-1])
916 915 else:
917 916 newcmd = '%s(%s)' % (ifun.rstrip(), the_rest)
918 917
919 918 if auto_rewrite:
920 919 self.shell.auto_rewrite_input(newcmd)
921 920
922 921 return newcmd
923 922
924 923
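
To make the branches above concrete, for a callable ``f`` with autocall active, handle() produces rewrites along these lines (illustrative; spacing approximate):

    /f a b    ->  f(a,b)         # ESC_PAREN: comma-join the arguments
    ,f a b    ->  f("a", "b")    # ESC_QUOTE: auto-quote, splitting on whitespace
    ;f a b    ->  f("a b")       # ESC_QUOTE2: auto-quote the whole remainder
    f a       ->  f(a)           # plain autocall wraps the rest in parens
    f         ->  f              # no arguments: left alone unless autocall == 2
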
925 924 class HelpHandler(PrefilterHandler):
926 925
927 926 handler_name = Unicode('help')
928 927 esc_strings = List([ESC_HELP])
929 928
930 929 def handle(self, line_info):
931 930 """Try to get some help for the object.
932 931
933 932 obj? or ?obj -> basic information.
934 933 obj?? or ??obj -> more details.
935 934 """
936 935 normal_handler = self.prefilter_manager.get_handler_by_name('normal')
937 936 line = line_info.line
938 937 # We need to make sure that we don't process lines which would be
939 938 # otherwise valid python, such as "x=1 # what?"
940 939 try:
941 940 codeop.compile_command(line)
942 941 except SyntaxError:
943 942 # We should only handle as help stuff which is NOT valid syntax
944 943 if line[0]==ESC_HELP:
945 944 line = line[1:]
946 945 elif line[-1]==ESC_HELP:
947 946 line = line[:-1]
948 947 if line:
949 948 #print 'line:<%r>' % line # dbg
950 949 self.shell.magic_pinfo(line)
951 950 else:
952 951 self.shell.show_usage()
953 952 return '' # Empty string is needed here!
954 953 except:
955 954 raise
956 955 # Pass any other exceptions through to the normal handler
957 956 return normal_handler.handle(line_info)
958 957 else:
959 958 # If the code compiles ok, we should handle it normally
960 959 return normal_handler.handle(line_info)
961 960
962 961
963 962 class EmacsHandler(PrefilterHandler):
964 963
965 964 handler_name = Unicode('emacs')
966 965 esc_strings = List([])
967 966
968 967 def handle(self, line_info):
969 968 """Handle input lines marked by python-mode."""
970 969
971 970 # Currently, nothing is done. Later more functionality can be added
972 971 # here if needed.
973 972
974 973 # The input cache shouldn't be updated
975 974 return line_info.line
976 975
977 976
978 977 #-----------------------------------------------------------------------------
979 978 # Defaults
980 979 #-----------------------------------------------------------------------------
981 980
982 981
983 982 _default_transformers = [
984 983 AssignSystemTransformer,
985 984 AssignMagicTransformer,
986 985 PyPromptTransformer,
987 986 IPyPromptTransformer,
988 987 ]
989 988
990 989 _default_checkers = [
991 990 EmacsChecker,
992 991 ShellEscapeChecker,
993 992 MacroChecker,
994 993 IPyAutocallChecker,
995 994 MultiLineMagicChecker,
996 995 EscCharsChecker,
997 996 AssignmentChecker,
998 997 AutoMagicChecker,
999 998 AliasChecker,
1000 999 PythonOpsChecker,
1001 1000 AutocallChecker
1002 1001 ]
1003 1002
1004 1003 _default_handlers = [
1005 1004 PrefilterHandler,
1006 1005 AliasHandler,
1007 1006 ShellEscapeHandler,
1008 1007 MacroHandler,
1009 1008 MagicHandler,
1010 1009 AutoHandler,
1011 1010 HelpHandler,
1012 1011 EmacsHandler
1013 1012 ]
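
Putting the defaults together, a hedged sketch of driving the pipeline by hand from a live session (this assumes the shell exposes the manager as ``prefilter_manager``, which is wired up outside this file): transformers run first in priority order, then the checkers select a handler, which returns the rewritten source:

    ip = get_ipython()
    pm = ip.prefilter_manager

    # Checkers in the order they are consulted (smallest priority first):
    print [c.__class__.__name__ for c in pm.checkers]

    # A shell escape and a magic, through the full transform/check/handle path:
    print pm.prefilter_line('!ls -l')    # -> roughly: get_ipython().system(u'ls -l')
    print pm.prefilter_line('%who int')  # -> roughly: get_ipython().magic(u'who int')
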
@@ -1,253 +1,252 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 A mixin for :class:`~IPython.core.application.Application` classes that
5 4 launch InteractiveShell instances, load extensions, etc.
6 5
7 6 Authors
8 7 -------
9 8
10 9 * Min Ragan-Kelley
11 10 """
12 11
13 12 #-----------------------------------------------------------------------------
14 13 # Copyright (C) 2008-2011 The IPython Development Team
15 14 #
16 15 # Distributed under the terms of the BSD License. The full license is in
17 16 # the file COPYING, distributed as part of this software.
18 17 #-----------------------------------------------------------------------------
19 18
20 19 #-----------------------------------------------------------------------------
21 20 # Imports
22 21 #-----------------------------------------------------------------------------
23 22
24 23 from __future__ import absolute_import
25 24
26 25 import os
27 26 import sys
28 27
29 28 from IPython.config.application import boolean_flag
30 29 from IPython.config.configurable import Configurable
31 30 from IPython.config.loader import Config
32 31 from IPython.utils.path import filefind
33 32 from IPython.utils.traitlets import Unicode, Instance, List
34 33
35 34 #-----------------------------------------------------------------------------
36 35 # Aliases and Flags
37 36 #-----------------------------------------------------------------------------
38 37
39 38 shell_flags = {}
40 39
41 40 addflag = lambda *args: shell_flags.update(boolean_flag(*args))
42 41 addflag('autoindent', 'InteractiveShell.autoindent',
43 42 'Turn on autoindenting.', 'Turn off autoindenting.'
44 43 )
45 44 addflag('automagic', 'InteractiveShell.automagic',
46 45 """Turn on the auto calling of magic commands. Type %%magic at the
47 46 IPython prompt for more information.""",
48 47 'Turn off the auto calling of magic commands.'
49 48 )
50 49 addflag('pdb', 'InteractiveShell.pdb',
51 50 "Enable auto calling the pdb debugger after every exception.",
52 51 "Disable auto calling the pdb debugger after every exception."
53 52 )
54 53 addflag('pprint', 'PlainTextFormatter.pprint',
55 54 "Enable auto pretty printing of results.",
56 55 "Disable auto auto pretty printing of results."
57 56 )
58 57 addflag('color-info', 'InteractiveShell.color_info',
59 58 """IPython can display information about objects via a set of func-
60 59 tions, and optionally can use colors for this, syntax highlighting
61 60 source code and various other elements. However, because this
62 61 information is passed through a pager (like 'less') and many pagers get
63 62 confused with color codes, this option is off by default. You can test
64 63 it and turn it on permanently in your ipython_config.py file if it
65 64 works for you. Test it and turn it on permanently if it works with
66 65 your system. The magic function %%color_info allows you to toggle this
67 66 interactively for testing.""",
68 67 "Disable using colors for info related things."
69 68 )
70 69 addflag('deep-reload', 'InteractiveShell.deep_reload',
71 70 """Enable deep (recursive) reloading by default. IPython can use the
72 71 deep_reload module which reloads changes in modules recursively (it
73 72 replaces the reload() function, so you don't need to change anything to
74 73 use it). deep_reload() forces a full reload of modules whose code may
75 74 have changed, which the default reload() function does not. When
76 75 deep_reload is off, IPython will use the normal reload(), but
77 76 deep_reload will still be available as dreload(). This feature is off
78 77 by default [which means that you have both normal reload() and
79 78 dreload()].""",
80 79 "Disable deep (recursive) reloading by default."
81 80 )
82 81 nosep_config = Config()
83 82 nosep_config.InteractiveShell.separate_in = ''
84 83 nosep_config.InteractiveShell.separate_out = ''
85 84 nosep_config.InteractiveShell.separate_out2 = ''
86 85
87 86 shell_flags['nosep']=(nosep_config, "Eliminate all spacing between prompts.")
88 87
89 88
90 89 # it's possible we don't want short aliases for *all* of these:
91 90 shell_aliases = dict(
92 91 autocall='InteractiveShell.autocall',
93 92 colors='InteractiveShell.colors',
94 93 logfile='InteractiveShell.logfile',
95 94 logappend='InteractiveShell.logappend',
96 95 c='InteractiveShellApp.code_to_run',
97 96 ext='InteractiveShellApp.extra_extension',
98 97 )
99 98 shell_aliases['cache-size'] = 'InteractiveShell.cache_size'
100 99
101 100 #-----------------------------------------------------------------------------
102 101 # Main classes and functions
103 102 #-----------------------------------------------------------------------------
104 103
105 104 class InteractiveShellApp(Configurable):
106 105 """A Mixin for applications that start InteractiveShell instances.
107 106
108 107 Provides configurables for loading extensions and executing files
109 108 as part of configuring a Shell environment.
110 109
111 110 Provides init_extensions() and init_code() methods, to be called
112 111 after init_shell(), which must be implemented by subclasses.
113 112 """
114 113 extensions = List(Unicode, config=True,
115 114 help="A list of dotted module names of IPython extensions to load."
116 115 )
117 116 extra_extension = Unicode('', config=True,
118 117 help="dotted module name of an IPython extension to load."
119 118 )
120 119 def _extra_extension_changed(self, name, old, new):
121 120 if new:
122 121 # add to self.extensions
123 122 self.extensions.append(new)
124 123
125 124 exec_files = List(Unicode, config=True,
126 125 help="""List of files to run at IPython startup."""
127 126 )
128 127 file_to_run = Unicode('', config=True,
129 128 help="""A file to be run""")
130 129
131 130 exec_lines = List(Unicode, config=True,
132 131 help="""lines of code to run at IPython startup."""
133 132 )
134 133 code_to_run = Unicode('', config=True,
135 134 help="Execute the given command string."
136 135 )
137 136 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
138 137
139 138 def init_shell(self):
140 139 raise NotImplementedError("Override in subclasses")
141 140
142 141 def init_extensions(self):
143 142 """Load all IPython extensions in IPythonApp.extensions.
144 143
145 144 This uses :meth:`ExtensionManager.load_extension` to load all
146 145 the extensions listed in ``self.extensions``.
147 146 """
148 147 if not self.extensions:
149 148 return
150 149 try:
151 150 self.log.debug("Loading IPython extensions...")
152 151 extensions = self.extensions
153 152 for ext in extensions:
154 153 try:
155 154 self.log.info("Loading IPython extension: %s" % ext)
156 155 self.shell.extension_manager.load_extension(ext)
157 156 except:
158 157 self.log.warn("Error in loading extension: %s" % ext)
159 158 self.shell.showtraceback()
160 159 except:
161 160 self.log.warn("Unknown error in loading extensions:")
162 161 self.shell.showtraceback()
163 162
164 163 def init_code(self):
165 164 """run the pre-flight code, specified via exec_lines"""
166 165 self._run_exec_lines()
167 166 self._run_exec_files()
168 167 self._run_cmd_line_code()
169 168
170 169 def _run_exec_lines(self):
171 170 """Run lines of code in IPythonApp.exec_lines in the user's namespace."""
172 171 if not self.exec_lines:
173 172 return
174 173 try:
175 174 self.log.debug("Running code from IPythonApp.exec_lines...")
176 175 for line in self.exec_lines:
177 176 try:
178 177 self.log.info("Running code in user namespace: %s" %
179 178 line)
180 179 self.shell.run_cell(line, store_history=False)
181 180 except:
182 181 self.log.warn("Error in executing line in user "
183 182 "namespace: %s" % line)
184 183 self.shell.showtraceback()
185 184 except:
186 185 self.log.warn("Unknown error in handling IPythonApp.exec_lines:")
187 186 self.shell.showtraceback()
188 187
189 188 def _exec_file(self, fname):
190 189 try:
191 190 full_filename = filefind(fname, [u'.', self.ipython_dir])
192 191 except IOError as e:
193 192 self.log.warn("File not found: %r"%fname)
194 193 return
195 194 # Make sure that the running script gets a proper sys.argv as if it
196 195 # were run from a system shell.
197 196 save_argv = sys.argv
198 197 sys.argv = [full_filename] + self.extra_args[1:]
199 198 try:
200 199 if os.path.isfile(full_filename):
201 200 if full_filename.endswith('.ipy'):
202 201 self.log.info("Running file in user namespace: %s" %
203 202 full_filename)
204 203 self.shell.safe_execfile_ipy(full_filename)
205 204 else:
206 205 # default to python, even without extension
207 206 self.log.info("Running file in user namespace: %s" %
208 207 full_filename)
209 208 # Ensure that __file__ is always defined to match Python behavior
210 209 self.shell.user_ns['__file__'] = fname
211 210 try:
212 211 self.shell.safe_execfile(full_filename, self.shell.user_ns)
213 212 finally:
214 213 del self.shell.user_ns['__file__']
215 214 finally:
216 215 sys.argv = save_argv
217 216
218 217 def _run_exec_files(self):
219 218 """Run files from IPythonApp.exec_files"""
220 219 if not self.exec_files:
221 220 return
222 221
223 222 self.log.debug("Running files in IPythonApp.exec_files...")
224 223 try:
225 224 for fname in self.exec_files:
226 225 self._exec_file(fname)
227 226 except:
228 227 self.log.warn("Unknown error in handling IPythonApp.exec_files:")
229 228 self.shell.showtraceback()
230 229
231 230 def _run_cmd_line_code(self):
232 231 """Run code or file specified at the command-line"""
233 232 if self.code_to_run:
234 233 line = self.code_to_run
235 234 try:
236 235 self.log.info("Running code given at command line (c=): %s" %
237 236 line)
238 237 self.shell.run_cell(line, store_history=False)
239 238 except:
240 239 self.log.warn("Error in executing line in user namespace: %s" %
241 240 line)
242 241 self.shell.showtraceback()
243 242
244 243 # Like Python itself, ignore the second if the first of these is present
245 244 elif self.file_to_run:
246 245 fname = self.file_to_run
247 246 try:
248 247 self._exec_file(fname)
249 248 except:
250 249 self.log.warn("Error in executing file in user namespace: %s" %
251 250 fname)
252 251 self.shell.showtraceback()
253 252
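
For context, a hedged sketch of how these traits are usually driven from a profile's configuration file; the extension and file names below are purely illustrative:

    # e.g. in ipython_config.py; get_config() is provided by the config loader.
    c = get_config()

    c.InteractiveShellApp.extensions = ['autoreload']            # illustrative
    c.InteractiveShellApp.exec_lines = ['import os, sys',
                                        'print "startup lines ran"']
    c.InteractiveShellApp.exec_files = ['startup_helpers.py']    # hypothetical file
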
@@ -1,91 +1,90 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Simple utility for splitting user input.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Fernando Perez
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2009 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 import re
24 23 import sys
25 24
26 25 #-----------------------------------------------------------------------------
27 26 # Main function
28 27 #-----------------------------------------------------------------------------
29 28
30 29
31 30 # RegExp for splitting line contents into pre-char//first word-method//rest.
33 32 # For clarity, each group is on one line.
33 32
34 33 # WARNING: update the regexp if the escapes in interactiveshell are changed, as they
35 34 # are hardwired in.
36 35
37 36 # Although it's not solely driven by the regex, note that:
38 37 # ,;/% only trigger if they are the first character on the line
39 38 # ! and !! trigger if they are first char(s) *or* follow an indent
40 39 # ? triggers as first or last char.
41 40
42 41 # The three parts of the regex are:
43 42 # 1) pre: pre_char *or* initial whitespace
44 43 # 2) ifun: first word/method (mix of \w and '.')
45 44 # 3) the_rest: rest of line (separated from ifun by space if non-empty)
46 45 line_split = re.compile(r'^([,;/%?]|!!?|\s*)'
47 46 r'\s*([\w\.]+)'
48 47 r'(\s+.*$|$)')
49 48
50 49 # r'[\w\.]+'
51 50 # r'\s*=\s*%.*'
52 51
53 52 def split_user_input(line, pattern=None):
54 53 """Split user input into pre-char/whitespace, function part and rest.
55 54
56 55 This currently handles lines with '=' in them in a very inconsistent
57 56 manner.
58 57 """
59 58 # We need to ensure that the rest of this routine deals only with unicode
60 59 if type(line)==str:
61 60 codec = sys.stdin.encoding
62 61 if codec is None:
63 62 codec = 'utf-8'
64 63 line = line.decode(codec)
65 64
66 65 if pattern is None:
67 66 pattern = line_split
68 67 match = pattern.match(line)
69 68 if not match:
70 69 # print "match failed for line '%s'" % line
71 70 try:
72 71 ifun, the_rest = line.split(None,1)
73 72 except ValueError:
74 73 # print "split failed for line '%s'" % line
75 74 ifun, the_rest = line, u''
76 75 pre = re.match('^(\s*)(.*)',line).groups()[0]
77 76 else:
78 77 pre,ifun,the_rest = match.groups()
79 78
80 79 # ifun has to be a valid python identifier, so it better encode into
81 80 # ascii. We do still make it a unicode string so that we consistently
82 81 # return unicode, but it will be one that is guaranteed to be pure ascii
83 82 try:
84 83 ifun = unicode(ifun.encode('ascii'))
85 84 except UnicodeEncodeError:
86 85 the_rest = ifun + u' ' + the_rest
87 86 ifun = u''
88 87
89 88 #print 'line:<%s>' % line # dbg
90 89 #print 'pre <%s> ifun <%s> rest <%s>' % (pre,ifun.strip(),the_rest) # dbg
91 90 return pre, ifun.strip(), the_rest.lstrip()
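
A few concrete calls, roughly (everything comes back as unicode; ifun is stripped and the_rest left-stripped):

    >>> from IPython.core.splitinput import split_user_input
    >>> split_user_input('%timeit x = 1')
    (u'%', u'timeit', u'x = 1')
    >>> split_user_input('!ls -l')
    (u'!', u'ls', u'-l')
    >>> split_user_input('x = 5')
    (u'', u'x', u'= 5')
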
@@ -1,59 +1,58 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2
4 3 def test_import_completer():
5 4 from IPython.core import completer
6 5
7 6 def test_import_crashhandler():
8 7 from IPython.core import crashhandler
9 8
10 9 def test_import_debugger():
11 10 from IPython.core import debugger
12 11
13 12 def test_import_fakemodule():
14 13 from IPython.core import fakemodule
15 14
16 15 def test_import_excolors():
17 16 from IPython.core import excolors
18 17
19 18 def test_import_history():
20 19 from IPython.core import history
21 20
22 21 def test_import_hooks():
23 22 from IPython.core import hooks
24 23
25 24 def test_import_ipapi():
26 25 from IPython.core import ipapi
27 26
28 27 def test_import_interactiveshell():
29 28 from IPython.core import interactiveshell
30 29
31 30 def test_import_logger():
32 31 from IPython.core import logger
33 32
34 33 def test_import_macro():
35 34 from IPython.core import macro
36 35
37 36 def test_import_magic():
38 37 from IPython.core import magic
39 38
40 39 def test_import_oinspect():
41 40 from IPython.core import oinspect
42 41
43 42 def test_import_prefilter():
44 43 from IPython.core import prefilter
45 44
46 45 def test_import_prompts():
47 46 from IPython.core import prompts
48 47
49 48 def test_import_release():
50 49 from IPython.core import release
51 50
52 51 def test_import_shadowns():
53 52 from IPython.core import shadowns
54 53
55 54 def test_import_ultratb():
56 55 from IPython.core import ultratb
57 56
58 57 def test_import_usage():
59 58 from IPython.core import usage
@@ -1,300 +1,299 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2
4 3 """Magic command interface for interactive parallel work."""
5 4
6 5 #-----------------------------------------------------------------------------
7 6 # Copyright (C) 2008-2009 The IPython Development Team
8 7 #
9 8 # Distributed under the terms of the BSD License. The full license is in
10 9 # the file COPYING, distributed as part of this software.
11 10 #-----------------------------------------------------------------------------
12 11
13 12 #-----------------------------------------------------------------------------
14 13 # Imports
15 14 #-----------------------------------------------------------------------------
16 15
17 16 import ast
18 17 import re
19 18
20 19 from IPython.core.plugin import Plugin
21 20 from IPython.utils.traitlets import Bool, Any, Instance
22 21 from IPython.testing.skipdoctest import skip_doctest
23 22
24 23 #-----------------------------------------------------------------------------
25 24 # Definitions of magic functions for use with IPython
26 25 #-----------------------------------------------------------------------------
27 26
28 27
29 28 NO_ACTIVE_VIEW = """
30 29 Use activate() on a DirectView object to activate it for magics.
31 30 """
32 31
33 32
34 33 class ParalleMagic(Plugin):
35 34 """A component to manage the %result, %px and %autopx magics."""
36 35
37 36 active_view = Instance('IPython.parallel.client.view.DirectView')
38 37 verbose = Bool(False, config=True)
39 38 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
40 39
41 40 def __init__(self, shell=None, config=None):
42 41 super(ParalleMagic, self).__init__(shell=shell, config=config)
43 42 self._define_magics()
44 43 # A flag showing if autopx is activated or not
45 44 self.autopx = False
46 45
47 46 def _define_magics(self):
48 47 """Define the magic functions."""
49 48 self.shell.define_magic('result', self.magic_result)
50 49 self.shell.define_magic('px', self.magic_px)
51 50 self.shell.define_magic('autopx', self.magic_autopx)
52 51
53 52 @skip_doctest
54 53 def magic_result(self, ipself, parameter_s=''):
55 54 """Print the result of command i on all engines..
56 55
57 56 To use this a :class:`DirectView` instance must be created
58 57 and then activated by calling its :meth:`activate` method.
59 58
60 59 Then you can do the following::
61 60
62 61 In [23]: %result
63 62 Out[23]:
64 63 <Results List>
65 64 [0] In [6]: a = 10
66 65 [1] In [6]: a = 10
67 66
68 67 In [22]: %result 6
69 68 Out[22]:
70 69 <Results List>
71 70 [0] In [6]: a = 10
72 71 [1] In [6]: a = 10
73 72 """
74 73 if self.active_view is None:
75 74 print NO_ACTIVE_VIEW
76 75 return
77 76
78 77 try:
79 78 index = int(parameter_s)
80 79 except:
81 80 index = None
82 81 result = self.active_view.get_result(index)
83 82 return result
84 83
85 84 @skip_doctest
86 85 def magic_px(self, ipself, parameter_s=''):
87 86 """Executes the given python command in parallel.
88 87
89 88 To use this a :class:`DirectView` instance must be created
90 89 and then activated by calling its :meth:`activate` method.
91 90
92 91 Then you can do the following::
93 92
94 93 In [24]: %px a = 5
95 94 Parallel execution on engine(s): all
96 95 Out[24]:
97 96 <Results List>
98 97 [0] In [7]: a = 5
99 98 [1] In [7]: a = 5
100 99 """
101 100
102 101 if self.active_view is None:
103 102 print NO_ACTIVE_VIEW
104 103 return
105 104 print "Parallel execution on engine(s): %s" % self.active_view.targets
106 105 result = self.active_view.execute(parameter_s, block=False)
107 106 if self.active_view.block:
108 107 result.get()
109 108 self._maybe_display_output(result)
110 109
111 110 @skip_doctest
112 111 def magic_autopx(self, ipself, parameter_s=''):
113 112 """Toggles auto parallel mode.
114 113
115 114 To use this a :class:`DirectView` instance must be created
116 115 and then activated by calling its :meth:`activate` method. Once this
117 116 is called, all commands typed at the command line are sent to
118 117 the engines to be executed in parallel. To control which engines
119 118 are used, set the ``targets`` attribute of the multiengine client
120 119 before entering ``%autopx`` mode.
121 120
122 121 Then you can do the following::
123 122
124 123 In [25]: %autopx
125 124 %autopx enabled
126 125
127 126 In [26]: a = 10
128 127 Parallel execution on engine(s): [0,1,2,3]
129 128 In [27]: print a
130 129 Parallel execution on engine(s): [0,1,2,3]
131 130 [stdout:0] 10
132 131 [stdout:1] 10
133 132 [stdout:2] 10
134 133 [stdout:3] 10
135 134
136 135
137 136 In [27]: %autopx
138 137 %autopx disabled
139 138 """
140 139 if self.autopx:
141 140 self._disable_autopx()
142 141 else:
143 142 self._enable_autopx()
144 143
145 144 def _enable_autopx(self):
146 145 """Enable %autopx mode by saving the original run_cell and installing
147 146 pxrun_cell.
148 147 """
149 148 if self.active_view is None:
150 149 print NO_ACTIVE_VIEW
151 150 return
152 151
153 152 # override run_cell and run_code
154 153 self._original_run_cell = self.shell.run_cell
155 154 self.shell.run_cell = self.pxrun_cell
156 155 self._original_run_code = self.shell.run_code
157 156 self.shell.run_code = self.pxrun_code
158 157
159 158 self.autopx = True
160 159 print "%autopx enabled"
161 160
162 161 def _disable_autopx(self):
163 162 """Disable %autopx by restoring the original InteractiveShell.run_cell.
164 163 """
165 164 if self.autopx:
166 165 self.shell.run_cell = self._original_run_cell
167 166 self.shell.run_code = self._original_run_code
168 167 self.autopx = False
169 168 print "%autopx disabled"
170 169
171 170 def _maybe_display_output(self, result):
172 171 """Maybe display the output of a parallel result.
173 172
174 173 If self.active_view.block is True, wait for the result
175 174 and display the result. Otherwise, this is a noop.
176 175 """
177 176 if isinstance(result.stdout, basestring):
178 177 # single result
179 178 stdouts = [result.stdout.rstrip()]
180 179 else:
181 180 stdouts = [s.rstrip() for s in result.stdout]
182 181
183 182 targets = self.active_view.targets
184 183 if isinstance(targets, int):
185 184 targets = [targets]
186 185 elif targets == 'all':
187 186 targets = self.active_view.client.ids
188 187
189 188 if any(stdouts):
190 189 for eid,stdout in zip(targets, stdouts):
191 190 print '[stdout:%i]'%eid, stdout
192 191
193 192
194 193 def pxrun_cell(self, raw_cell, store_history=True):
195 194 """drop-in replacement for InteractiveShell.run_cell.
196 195
197 196 This executes code remotely, instead of in the local namespace.
198 197
199 198 See InteractiveShell.run_cell for details.
200 199 """
201 200
202 201 if (not raw_cell) or raw_cell.isspace():
203 202 return
204 203
205 204 ipself = self.shell
206 205
207 206 with ipself.builtin_trap:
208 207 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
209 208
210 209 # Store raw and processed history
211 210 if store_history:
212 211 ipself.history_manager.store_inputs(ipself.execution_count,
213 212 cell, raw_cell)
214 213
215 214 # ipself.logger.log(cell, raw_cell)
216 215
217 216 cell_name = ipself.compile.cache(cell, ipself.execution_count)
218 217
219 218 try:
220 219 code_ast = ast.parse(cell, filename=cell_name)
221 220 except (OverflowError, SyntaxError, ValueError, TypeError, MemoryError):
222 221 # Case 1
223 222 ipself.showsyntaxerror()
224 223 ipself.execution_count += 1
225 224 return None
226 225 except NameError:
227 226 # ignore name errors, because we don't know the remote keys
228 227 pass
229 228
230 229 if store_history:
231 230 # Write output to the database. Does nothing unless
232 231 # history output logging is enabled.
233 232 ipself.history_manager.store_output(ipself.execution_count)
234 233 # Each cell is a *single* input, regardless of how many lines it has
235 234 ipself.execution_count += 1
236 235
237 236 if re.search(r'get_ipython\(\)\.magic\(u?"%?autopx', cell):
238 237 self._disable_autopx()
239 238 return False
240 239 else:
241 240 try:
242 241 result = self.active_view.execute(cell, block=False)
243 242 except:
244 243 ipself.showtraceback()
245 244 return True
246 245 else:
247 246 if self.active_view.block:
248 247 try:
249 248 result.get()
250 249 except:
251 250 self.shell.showtraceback()
252 251 return True
253 252 else:
254 253 self._maybe_display_output(result)
255 254 return False
256 255
257 256 def pxrun_code(self, code_obj):
258 257 """drop-in replacement for InteractiveShell.run_code.
259 258
260 259 This executes code remotely, instead of in the local namespace.
261 260
262 261 See InteractiveShell.run_code for details.
263 262 """
264 263 ipself = self.shell
265 264 # check code object for the autopx magic
266 265 if 'get_ipython' in code_obj.co_names and 'magic' in code_obj.co_names and \
267 266 any( [ isinstance(c, basestring) and 'autopx' in c for c in code_obj.co_consts ]):
268 267 self._disable_autopx()
269 268 return False
270 269 else:
271 270 try:
272 271 result = self.active_view.execute(code_obj, block=False)
273 272 except:
274 273 ipself.showtraceback()
275 274 return True
276 275 else:
277 276 if self.active_view.block:
278 277 try:
279 278 result.get()
280 279 except:
281 280 self.shell.showtraceback()
282 281 return True
283 282 else:
284 283 self._maybe_display_output(result)
285 284 return False
286 285
287 286
288 287
289 288
290 289 _loaded = False
291 290
292 291
293 292 def load_ipython_extension(ip):
294 293 """Load the extension in IPython."""
295 294 global _loaded
296 295 if not _loaded:
297 296 plugin = ParalleMagic(shell=ip, config=ip.config)
298 297 ip.plugin_manager.register_plugin('parallelmagic', plugin)
299 298 _loaded = True
300 299
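A minimal usage sketch for these magics, assuming an ipcluster with running
engines and that ``Client`` and ``DirectView.activate()`` behave as described
in the docstrings above::

    # Interactive session sketch; requires a running ipcluster with engines.
    from IPython.parallel import Client

    rc = Client()        # connect to the controller
    dview = rc[:]        # a DirectView on all engines
    dview.activate()     # make this view the target of %result, %px and %autopx

    # Afterwards, inside IPython:
    #   %px a = 5        # run the statement on every engine
    #   %result          # show the last result from all engines
    #   %autopx          # toggle sending every typed line to the engines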
@@ -1,170 +1,169 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2
4 3 # GUID.py
5 4 # Version 2.6
6 5 #
7 6 # Copyright (c) 2006 Conan C. Albrecht
8 7 #
9 8 # Permission is hereby granted, free of charge, to any person obtaining a copy
10 9 # of this software and associated documentation files (the "Software"), to deal
11 10 # in the Software without restriction, including without limitation the rights
12 11 # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 12 # copies of the Software, and to permit persons to whom the Software is furnished
14 13 # to do so, subject to the following conditions:
15 14 #
16 15 # The above copyright notice and this permission notice shall be included in all
17 16 # copies or substantial portions of the Software.
18 17 #
19 18 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
20 19 # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
21 20 # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
22 21 # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
23 22 # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
24 23 # DEALINGS IN THE SOFTWARE.
25 24
26 25
27 26
28 27 ##################################################################################################
29 28 ### A globally-unique identifier made up of time and ip and 8 digits for a counter:
30 29 ### each GUID is 40 characters wide
31 30 ###
32 31 ### A globally unique identifier that combines ip, time, and a counter. Since the
33 32 ### time is listed first, you can sort records by guid. You can also extract the time
34 33 ### and ip if needed.
35 34 ###
36 35 ### Since the counter has eight hex characters, you can create up to
37 36 ### 0xffffffff (4294967295) GUIDs every millisecond. If your processor
38 37 ### is somehow fast enough to create more than that in a millisecond (looking
39 38 ### toward the future, of course), the function will wait until the next
40 39 ### millisecond to return.
41 40 ###
42 41 ### GUIDs make wonderful database keys. They require no access to the
43 42 ### database (to get the max index number), they are extremely unique, and they sort
44 43 ### automatically by time. GUIDs prevent key clashes when merging
45 44 ### two databases together, combining data, or generating keys in distributed
46 45 ### systems.
47 46 ###
48 47 ### There is an Internet Draft for UUIDs, but this module does not implement it.
49 48 ### If the draft catches on, perhaps I'll conform the module to it.
50 49 ###
51 50
52 51
53 52 # Changelog
54 53 # Sometime, 1997 Created the Java version of GUID
55 54 # Went through many versions in Java
56 55 # Sometime, 2002 Created the Python version of GUID, mirroring the Java version
57 56 # November 24, 2003 Changed Python version to be more pythonic, took out object and made just a module
58 57 # December 2, 2003 Fixed duplicating GUIDs. Sometimes they duplicate if multiples are created
59 58 # in the same millisecond (it checks the last 100 GUIDs now and has a larger random part)
60 59 # December 9, 2003 Fixed MAX_RANDOM, which was going over sys.maxint
61 60 # June 12, 2004 Allowed a custom IP address to be sent in rather than always using the
62 61 # local IP address.
63 62 # November 4, 2005 Changed the random part to a counter variable. Now GUIDs are totally
64 63 # unique and more efficient, as long as they are created by only
65 64 # one runtime on a given machine. The counter part is after the time
66 65 # part so it sorts correctly.
67 66 # November 8, 2005 The counter variable now starts at a random long now and cycles
68 67 # around. This is in case two guids are created on the same
69 68 # machine at the same millisecond (by different processes). Even though
70 69 # it is possible the GUID can be created, this makes it highly unlikely
71 70 # since the counter will likely be different.
72 71 # November 11, 2005 Fixed a bug in the new IP getting algorithm. Also, use IPv6 range
73 72 # for IP when we make it up (when it's not accessible)
74 73 # November 21, 2005 Added better IP-finding code. It finds IP address better now.
75 74 # January 5, 2006 Fixed a small bug caused in old versions of python (random module use)
76 75
77 76 import math
78 77 import socket
79 78 import random
80 79 import sys
81 80 import time
82 81 import threading
83 82
84 83
85 84
86 85 #############################
87 86 ### global module variables
88 87
89 88 #Makes a hex IP from a decimal dot-separated ip (eg: 127.0.0.1)
90 89 make_hexip = lambda ip: ''.join(["%04x" % long(i) for i in ip.split('.')]) # leave space for ip v6 (65K in each sub)
91 90
92 91 MAX_COUNTER = 0xfffffffe
93 92 counter = 0L
94 93 firstcounter = MAX_COUNTER
95 94 lasttime = 0
96 95 ip = ''
97 96 lock = threading.RLock()
98 97 try: # only need to get the IP address once
99 98 ip = socket.getaddrinfo(socket.gethostname(),0)[-1][-1][0]
100 99 hexip = make_hexip(ip)
101 100 except: # if we don't have an ip, default to something in the 10.x.x.x private range
102 101 ip = '10'
103 102 rand = random.Random()
104 103 for i in range(3):
105 104 ip += '.' + str(rand.randrange(1, 0xffff)) # might as well use IPv6 range if we're making it up
106 105 hexip = make_hexip(ip)
107 106
108 107
109 108 #################################
110 109 ### Public module functions
111 110
112 111 def generate(ip=None):
113 112 '''Generates a new guid. A guid is unique in space and time because it combines
114 113 the machine IP with the current time in milliseconds. Be careful about sending in
115 114 a specified IP address because the ip makes it unique in space. You could send in
116 115 the same IP address that is created on another machine.
117 116 '''
118 117 global counter, firstcounter, lasttime
119 118 lock.acquire() # can't generate two guids at the same time
120 119 try:
121 120 parts = []
122 121
123 122 # do we need to wait for the next millisecond (are we out of counters?)
124 123 now = long(time.time() * 1000)
125 124 while lasttime == now and counter == firstcounter:
126 125 time.sleep(.01)
127 126 now = long(time.time() * 1000)
128 127
129 128 # time part
130 129 parts.append("%016x" % now)
131 130
132 131 # counter part
133 132 if lasttime != now: # time to start counter over since we have a different millisecond
134 133 firstcounter = long(random.uniform(1, MAX_COUNTER)) # start at random position
135 134 counter = firstcounter
136 135 counter += 1
137 136 if counter > MAX_COUNTER:
138 137 counter = 0
139 138 lasttime = now
140 139 parts.append("%08x" % (counter))
141 140
142 141 # ip part
143 142 parts.append(hexip)
144 143
145 144 # put them all together
146 145 return ''.join(parts)
147 146 finally:
148 147 lock.release()
149 148
150 149
151 150 def extract_time(guid):
152 151 '''Extracts the time portion out of the guid and returns the
153 152 number of seconds since the epoch as a float'''
154 153 return float(long(guid[0:16], 16)) / 1000.0
155 154
156 155
157 156 def extract_counter(guid):
158 157 '''Extracts the counter from the guid (returns the bits in decimal)'''
159 158 return int(guid[16:24], 16)
160 159
161 160
162 161 def extract_ip(guid):
163 162 '''Extracts the ip portion out of the guid and returns it
164 163 as a string like 10.10.10.10'''
165 164 # there's probably a more elegant way to do this
166 165 thisip = []
167 166 for index in range(24, 40, 4):
168 167 thisip.append(str(int(guid[index: index + 4], 16)))
169 168 return '.'.join(thisip)
170 169
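A short usage sketch of the public functions above; the module name ``guid``
used for the import is an assumption::

    import guid   # hypothetical module name for the file above

    g = guid.generate()              # 40 hex chars: time (16) + counter (8) + ip (16)
    assert len(g) == 40

    seconds = guid.extract_time(g)   # seconds since the epoch, as a float
    count = guid.extract_counter(g)  # the per-millisecond counter, as an int
    addr = guid.extract_ip(g)        # dotted string such as '10.1.2.3'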
@@ -1,90 +1,88 b''
1 #!/usr/bin/env python
2
3 1 #
4 2 # This file is adapted from a paramiko demo, and thus licensed under LGPL 2.1.
5 3 # Original Copyright (C) 2003-2007 Robey Pointer <robeypointer@gmail.com>
6 4 # Edits Copyright (C) 2010 The IPython Team
7 5 #
8 6 # Paramiko is free software; you can redistribute it and/or modify it under the
9 7 # terms of the GNU Lesser General Public License as published by the Free
10 8 # Software Foundation; either version 2.1 of the License, or (at your option)
11 9 # any later version.
12 10 #
13 11 # Paramiko is distributed in the hope that it will be useful, but WITHOUT ANY
14 12 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
15 13 # A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
16 14 # details.
17 15 #
18 16 # You should have received a copy of the GNU Lesser General Public License
19 17 # along with Paramiko; if not, write to the Free Software Foundation, Inc.,
20 18 # 51 Franklin Street, Fifth Floor, Boston, MA 02111-1301 USA.
21 19
22 20 """
23 21 Sample script showing how to do local port forwarding over paramiko.
24 22
25 23 This script connects to the requested SSH server and sets up local port
26 24 forwarding (the openssh -L option) from a local port through a tunneled
27 25 connection to a destination reachable from the SSH server machine.
28 26 """
29 27
30 28 from __future__ import print_function
31 29
32 30 import logging
33 31 import select
34 32 import SocketServer
35 33
36 34 logger = logging.getLogger('ssh')
37 35
38 36 class ForwardServer (SocketServer.ThreadingTCPServer):
39 37 daemon_threads = True
40 38 allow_reuse_address = True
41 39
42 40
43 41 class Handler (SocketServer.BaseRequestHandler):
44 42
45 43 def handle(self):
46 44 try:
47 45 chan = self.ssh_transport.open_channel('direct-tcpip',
48 46 (self.chain_host, self.chain_port),
49 47 self.request.getpeername())
50 48 except Exception, e:
51 49 logger.debug('Incoming request to %s:%d failed: %s' % (self.chain_host,
52 50 self.chain_port,
53 51 repr(e)))
54 52 return
55 53 if chan is None:
56 54 logger.debug('Incoming request to %s:%d was rejected by the SSH server.' %
57 55 (self.chain_host, self.chain_port))
58 56 return
59 57
60 58 logger.debug('Connected! Tunnel open %r -> %r -> %r' % (self.request.getpeername(),
61 59 chan.getpeername(), (self.chain_host, self.chain_port)))
62 60 while True:
63 61 r, w, x = select.select([self.request, chan], [], [])
64 62 if self.request in r:
65 63 data = self.request.recv(1024)
66 64 if len(data) == 0:
67 65 break
68 66 chan.send(data)
69 67 if chan in r:
70 68 data = chan.recv(1024)
71 69 if len(data) == 0:
72 70 break
73 71 self.request.send(data)
74 72 chan.close()
75 73 self.request.close()
76 74 logger.debug('Tunnel closed ')
77 75
78 76
79 77 def forward_tunnel(local_port, remote_host, remote_port, transport):
80 78 # this is a little convoluted, but lets me configure things for the Handler
81 79 # object. (SocketServer doesn't give Handlers any way to access the outer
82 80 # server normally.)
83 81 class SubHander (Handler):
84 82 chain_host = remote_host
85 83 chain_port = remote_port
86 84 ssh_transport = transport
87 85 ForwardServer(('127.0.0.1', local_port), SubHander).serve_forever()
88 86
89 87
90 88 __all__ = ['forward_tunnel']
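A rough sketch of driving ``forward_tunnel`` with a paramiko transport; the
module name used for the import, the hosts, ports and credentials are all
placeholders, and the forwarding server blocks, so real callers would
normally run this in a thread::

    import paramiko
    from forward import forward_tunnel   # hypothetical module name for the file above

    client = paramiko.SSHClient()
    client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    client.connect('ssh.example.com', username='user', password='secret')

    # Forward local port 8080 through the SSH server to remote-db.internal:5432.
    forward_tunnel(8080, 'remote-db.internal', 5432, client.get_transport())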
@@ -1,29 +1,28 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Extra capabilities for IPython
5 4 """
6 5
7 6 #-----------------------------------------------------------------------------
8 7 # Copyright (C) 2008-2009 The IPython Development Team
9 8 #
10 9 # Distributed under the terms of the BSD License. The full license is in
11 10 # the file COPYING, distributed as part of this software.
12 11 #-----------------------------------------------------------------------------
13 12
14 13 #-----------------------------------------------------------------------------
15 14 # Imports
16 15 #-----------------------------------------------------------------------------
17 16
18 17 from IPython.lib.inputhook import (
19 18 enable_wx, disable_wx,
20 19 enable_gtk, disable_gtk,
21 20 enable_qt4, disable_qt4,
22 21 enable_tk, disable_tk,
23 22 set_inputhook, clear_inputhook,
24 23 current_gui
25 24 )
26 25
27 26 #-----------------------------------------------------------------------------
28 27 # Code
29 28 #-----------------------------------------------------------------------------
@@ -1,147 +1,146 b''
1 #!/usr/bin/env python
2 1 # coding: utf-8
3 2 """
4 3 Support for creating GUI apps and starting event loops.
5 4
6 5 IPython's GUI integration allows interactive plotting and GUI usage in IPython
7 6 sessions. IPython has two different types of GUI integration:
8 7
9 8 1. The terminal based IPython supports GUI event loops through Python's
10 9 PyOS_InputHook. PyOS_InputHook is a hook that Python calls periodically
11 10 whenever raw_input is waiting for a user to type code. We implement GUI
12 11 support in the terminal by setting PyOS_InputHook to a function that
13 12 iterates the event loop for a short while. It is important to note that
14 13 in this situation, the real GUI event loop is NOT run in the normal
15 14 manner, so you can't use the normal means to detect that it is running.
16 15 2. In the two process IPython kernel/frontend, the GUI event loop is run in
17 16 the kernel. In this case, the event loop is run in the normal manner by
18 17 calling the function or method of the GUI toolkit that starts the event
19 18 loop.
20 19
21 20 In addition to starting the GUI event loops in one of these two ways, IPython
22 21 will *always* create an appropriate GUI application object when GUI
23 22 integration is enabled.
24 23
25 24 If you want your GUI apps to run in IPython you need to do two things:
26 25
27 26 1. Test to see if there is already an existing main application object. If
28 27 there is, you should use it. If there is not an existing application object
29 28 you should create one.
30 29 2. Test to see if the GUI event loop is running. If it is, you should not
31 30 start it. If the event loop is not running you may start it.
32 31
33 32 This module contains functions for each toolkit that perform these things
34 33 in a consistent manner. Because of how PyOS_InputHook runs the event loop
35 34 you cannot detect if the event loop is running using the traditional calls
36 35 (such as ``wx.GetApp().IsMainLoopRunning()`` in wxPython). If PyOS_InputHook is
37 36 set, these methods will return a false negative. That is, they will say the
38 37 event loop is not running, when it actually is. To work around this limitation
39 38 we propose the following informal protocol:
40 39
41 40 * Whenever someone starts the event loop, they *must* set the ``_in_event_loop``
42 41 attribute of the main application object to ``True``. This should be done
43 42 regardless of how the event loop is actually run.
44 43 * Whenever someone stops the event loop, they *must* set the ``_in_event_loop``
45 44 attribute of the main application object to ``False``.
46 45 * If you want to see if the event loop is running, you *must* use ``hasattr``
47 46 to see if ``_in_event_loop`` attribute has been set. If it is set, you
48 47 *must* use its value. If it has not been set, you can query the toolkit
49 48 in the normal manner.
50 49 * If you want GUI support and no one else has created an application or
51 50 started the event loop you *must* do this. We don't want projects to
52 51 attempt to defer these things to someone else if they themselves need it.
53 52
54 53 The functions below implement this logic for each GUI toolkit. If you need
55 54 to create custom application subclasses, you will likely have to modify this
56 55 code for your own purposes. This code can be copied into your own project
57 56 so you don't have to depend on IPython.
58 57
59 58 """
60 59
61 60 #-----------------------------------------------------------------------------
62 61 # Copyright (C) 2008-2010 The IPython Development Team
63 62 #
64 63 # Distributed under the terms of the BSD License. The full license is in
65 64 # the file COPYING, distributed as part of this software.
66 65 #-----------------------------------------------------------------------------
67 66
68 67 #-----------------------------------------------------------------------------
69 68 # Imports
70 69 #-----------------------------------------------------------------------------
71 70
72 71 #-----------------------------------------------------------------------------
73 72 # wx
74 73 #-----------------------------------------------------------------------------
75 74
76 75 def get_app_wx(*args, **kwargs):
77 76 """Create a new wx app or return an exiting one."""
78 77 import wx
79 78 app = wx.GetApp()
80 79 if app is None:
81 80 if not kwargs.has_key('redirect'):
82 81 kwargs['redirect'] = False
83 82 app = wx.PySimpleApp(*args, **kwargs)
84 83 return app
85 84
86 85 def is_event_loop_running_wx(app=None):
87 86 """Is the wx event loop running."""
88 87 if app is None:
89 88 app = get_app_wx()
90 89 if hasattr(app, '_in_event_loop'):
91 90 return app._in_event_loop
92 91 else:
93 92 return app.IsMainLoopRunning()
94 93
95 94 def start_event_loop_wx(app=None):
96 95 """Start the wx event loop in a consistent manner."""
97 96 if app is None:
98 97 app = get_app_wx()
99 98 if not is_event_loop_running_wx(app):
100 99 app._in_event_loop = True
101 100 app.MainLoop()
102 101 app._in_event_loop = False
103 102 else:
104 103 app._in_event_loop = True
105 104
106 105 #-----------------------------------------------------------------------------
107 106 # qt4
108 107 #-----------------------------------------------------------------------------
109 108
110 109 def get_app_qt4(*args, **kwargs):
111 110 """Create a new qt4 app or return an existing one."""
112 111 from IPython.external.qt_for_kernel import QtGui
113 112 app = QtGui.QApplication.instance()
114 113 if app is None:
115 114 if not args:
116 115 args = ([''],)
117 116 app = QtGui.QApplication(*args, **kwargs)
118 117 return app
119 118
120 119 def is_event_loop_running_qt4(app=None):
121 120 """Is the qt4 event loop running."""
122 121 if app is None:
123 122 app = get_app_qt4([''])
124 123 if hasattr(app, '_in_event_loop'):
125 124 return app._in_event_loop
126 125 else:
127 126 # Does qt4 provide another way to detect this?
128 127 return False
129 128
130 129 def start_event_loop_qt4(app=None):
131 130 """Start the qt4 event loop in a consistent manner."""
132 131 if app is None:
133 132 app = get_app_qt4([''])
134 133 if not is_event_loop_running_qt4(app):
135 134 app._in_event_loop = True
136 135 app.exec_()
137 136 app._in_event_loop = False
138 137 else:
139 138 app._in_event_loop = True
140 139
141 140 #-----------------------------------------------------------------------------
142 141 # Tk
143 142 #-----------------------------------------------------------------------------
144 143
145 144 #-----------------------------------------------------------------------------
146 145 # gtk
147 146 #-----------------------------------------------------------------------------
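A minimal sketch of the two-step recipe from the module docstring (reuse or
create the application object, then start the event loop only if it is not
already running), using the qt4 helpers defined above; the import paths are
the ones these helpers have inside IPython and are an assumption here::

    from IPython.lib.guisupport import get_app_qt4, start_event_loop_qt4
    from IPython.external.qt_for_kernel import QtGui

    app = get_app_qt4()       # reuse an existing QApplication or create a new one
    win = QtGui.QMainWindow()
    win.show()
    # Calls app.exec_() only if the loop is not already running (e.g. under the inputhook).
    start_event_loop_qt4(app)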
@@ -1,345 +1,344 b''
1 #!/usr/bin/env python
2 1 # coding: utf-8
3 2 """
4 3 Inputhook management for GUI event loop integration.
5 4 """
6 5
7 6 #-----------------------------------------------------------------------------
8 7 # Copyright (C) 2008-2009 The IPython Development Team
9 8 #
10 9 # Distributed under the terms of the BSD License. The full license is in
11 10 # the file COPYING, distributed as part of this software.
12 11 #-----------------------------------------------------------------------------
13 12
14 13 #-----------------------------------------------------------------------------
15 14 # Imports
16 15 #-----------------------------------------------------------------------------
17 16
18 17 import ctypes
19 18 import sys
20 19 import warnings
21 20
22 21 #-----------------------------------------------------------------------------
23 22 # Constants
24 23 #-----------------------------------------------------------------------------
25 24
26 25 # Constants for identifying the GUI toolkits.
27 26 GUI_WX = 'wx'
28 27 GUI_QT = 'qt'
29 28 GUI_QT4 = 'qt4'
30 29 GUI_GTK = 'gtk'
31 30 GUI_TK = 'tk'
32 31 GUI_OSX = 'osx'
33 32
34 33 #-----------------------------------------------------------------------------
35 34 # Utility classes
36 35 #-----------------------------------------------------------------------------
37 36
38 37
39 38 #-----------------------------------------------------------------------------
40 39 # Main InputHookManager class
41 40 #-----------------------------------------------------------------------------
42 41
43 42
44 43 class InputHookManager(object):
45 44 """Manage PyOS_InputHook for different GUI toolkits.
46 45
47 46 This class installs various hooks under ``PyOS_InputHook`` to handle
48 47 GUI event loop integration.
49 48 """
50 49
51 50 def __init__(self):
52 51 self.PYFUNC = ctypes.PYFUNCTYPE(ctypes.c_int)
53 52 self._apps = {}
54 53 self._reset()
55 54
56 55 def _reset(self):
57 56 self._callback_pyfunctype = None
58 57 self._callback = None
59 58 self._installed = False
60 59 self._current_gui = None
61 60
62 61 def get_pyos_inputhook(self):
63 62 """Return the current PyOS_InputHook as a ctypes.c_void_p."""
64 63 return ctypes.c_void_p.in_dll(ctypes.pythonapi,"PyOS_InputHook")
65 64
66 65 def get_pyos_inputhook_as_func(self):
67 66 """Return the current PyOS_InputHook as a ctypes.PYFUNCYPE."""
68 67 return self.PYFUNC.in_dll(ctypes.pythonapi,"PyOS_InputHook")
69 68
70 69 def set_inputhook(self, callback):
71 70 """Set PyOS_InputHook to callback and return the previous one."""
72 71 self._callback = callback
73 72 self._callback_pyfunctype = self.PYFUNC(callback)
74 73 pyos_inputhook_ptr = self.get_pyos_inputhook()
75 74 original = self.get_pyos_inputhook_as_func()
76 75 pyos_inputhook_ptr.value = \
77 76 ctypes.cast(self._callback_pyfunctype, ctypes.c_void_p).value
78 77 self._installed = True
79 78 return original
80 79
81 80 def clear_inputhook(self, app=None):
82 81 """Set PyOS_InputHook to NULL and return the previous one.
83 82
84 83 Parameters
85 84 ----------
86 85 app : optional, ignored
87 86 This parameter is allowed only so that clear_inputhook() can be
88 87 called with a similar interface as all the ``enable_*`` methods. But
89 88 the actual value of the parameter is ignored. This uniform interface
90 89 makes it easier to have user-level entry points in the main IPython
91 90 app like :meth:`enable_gui`."""
92 91 pyos_inputhook_ptr = self.get_pyos_inputhook()
93 92 original = self.get_pyos_inputhook_as_func()
94 93 pyos_inputhook_ptr.value = ctypes.c_void_p(None).value
95 94 self._reset()
96 95 return original
97 96
98 97 def clear_app_refs(self, gui=None):
99 98 """Clear IPython's internal reference to an application instance.
100 99
101 100 Whenever we create an app for a user on qt4 or wx, we hold a
102 101 reference to the app. This is needed because in some cases bad things
103 102 can happen if a user doesn't hold a reference themselves. This
104 103 method is provided to clear the references we are holding.
105 104
106 105 Parameters
107 106 ----------
108 107 gui : None or str
109 108 If None, clear all app references. If one of ('wx', 'qt4'), clear
110 109 the app for that toolkit. References are not held for gtk or tk
111 110 as those toolkits don't have the notion of an app.
112 111 """
113 112 if gui is None:
114 113 self._apps = {}
115 114 elif self._apps.has_key(gui):
116 115 del self._apps[gui]
117 116
118 117 def enable_wx(self, app=None):
119 118 """Enable event loop integration with wxPython.
120 119
121 120 Parameters
122 121 ----------
123 122 app : WX Application, optional.
124 123 Running application to use. If not given, we probe WX for an
125 124 existing application object, and create a new one if none is found.
126 125
127 126 Notes
128 127 -----
129 128 This method sets the ``PyOS_InputHook`` for wxPython, which allows
130 129 wxPython to integrate with terminal based applications like
131 130 IPython.
132 131
133 132 If ``app`` is not given we probe for an existing one, and return it if
134 133 found. If no existing app is found, we create an :class:`wx.App` as
135 134 follows::
136 135
137 136 import wx
138 137 app = wx.App(redirect=False, clearSigInt=False)
139 138 """
140 139 from IPython.lib.inputhookwx import inputhook_wx
141 140 self.set_inputhook(inputhook_wx)
142 141 self._current_gui = GUI_WX
143 142 import wx
144 143 if app is None:
145 144 app = wx.GetApp()
146 145 if app is None:
147 146 app = wx.App(redirect=False, clearSigInt=False)
148 147 app._in_event_loop = True
149 148 self._apps[GUI_WX] = app
150 149 return app
151 150
152 151 def disable_wx(self):
153 152 """Disable event loop integration with wxPython.
154 153
155 154 This merely sets PyOS_InputHook to NULL.
156 155 """
157 156 if self._apps.has_key(GUI_WX):
158 157 self._apps[GUI_WX]._in_event_loop = False
159 158 self.clear_inputhook()
160 159
161 160 def enable_qt4(self, app=None):
162 161 """Enable event loop integration with PyQt4.
163 162
164 163 Parameters
165 164 ----------
166 165 app : Qt Application, optional.
167 166 Running application to use. If not given, we probe Qt for an
168 167 existing application object, and create a new one if none is found.
169 168
170 169 Notes
171 170 -----
172 171 This method sets the PyOS_InputHook for PyQt4, which allows
173 172 PyQt4 to integrate with terminal based applications like
174 173 IPython.
175 174
176 175 If ``app`` is not given we probe for an existing one, and return it if
177 176 found. If no existing app is found, we create an :class:`QApplication`
178 177 as follows::
179 178
180 179 from PyQt4 import QtGui
181 180 app = QtGui.QApplication(sys.argv)
182 181 """
183 182 from IPython.external.qt_for_kernel import QtCore, QtGui
184 183
185 184 if 'pyreadline' in sys.modules:
186 185 # see IPython GitHub Issue #281 for more info on this issue
187 186 # Similar intermittent behavior has been reported on OSX,
188 187 # but not consistently reproducible
189 188 warnings.warn("""PyReadline's inputhook can conflict with Qt, causing delays
190 189 in interactive input. If you do see this issue, we recommend using another GUI
191 190 toolkit if you can, or disable readline with the configuration option
192 191 'TerminalInteractiveShell.readline_use=False', specified in a config file or
193 192 at the command-line""",
194 193 RuntimeWarning)
195 194
196 195 # PyQt4 has had this since 4.3.1. In version 4.2, PyOS_InputHook
197 196 # was set when QtCore was imported, but if it ever got removed,
198 197 # you couldn't reset it. For earlier versions we can
199 198 # probably implement a ctypes version.
200 199 try:
201 200 QtCore.pyqtRestoreInputHook()
202 201 except AttributeError:
203 202 pass
204 203
205 204 self._current_gui = GUI_QT4
206 205 if app is None:
207 206 app = QtCore.QCoreApplication.instance()
208 207 if app is None:
209 208 app = QtGui.QApplication([" "])
210 209 app._in_event_loop = True
211 210 self._apps[GUI_QT4] = app
212 211 return app
213 212
214 213 def disable_qt4(self):
215 214 """Disable event loop integration with PyQt4.
216 215
217 216 This merely sets PyOS_InputHook to NULL.
218 217 """
219 218 if self._apps.has_key(GUI_QT4):
220 219 self._apps[GUI_QT4]._in_event_loop = False
221 220 self.clear_inputhook()
222 221
223 222 def enable_gtk(self, app=None):
224 223 """Enable event loop integration with PyGTK.
225 224
226 225 Parameters
227 226 ----------
228 227 app : ignored
229 228 Ignored, it's only a placeholder to keep the call signature of all
230 229 gui activation methods consistent, which simplifies the logic of
231 230 supporting magics.
232 231
233 232 Notes
234 233 -----
235 234 This method sets the PyOS_InputHook for PyGTK, which allows
236 235 PyGTK to integrate with terminal based applications like
237 236 IPython.
238 237 """
239 238 import gtk
240 239 try:
241 240 gtk.set_interactive(True)
242 241 self._current_gui = GUI_GTK
243 242 except AttributeError:
244 243 # For older versions of gtk, use our own ctypes version
245 244 from IPython.lib.inputhookgtk import inputhook_gtk
246 245 self.set_inputhook(inputhook_gtk)
247 246 self._current_gui = GUI_GTK
248 247
249 248 def disable_gtk(self):
250 249 """Disable event loop integration with PyGTK.
251 250
252 251 This merely sets PyOS_InputHook to NULL.
253 252 """
254 253 self.clear_inputhook()
255 254
256 255 def enable_tk(self, app=None):
257 256 """Enable event loop integration with Tk.
258 257
259 258 Parameters
260 259 ----------
261 260 app : toplevel :class:`Tkinter.Tk` widget, optional.
262 261 Running toplevel widget to use. If not given, we probe Tk for an
263 262 existing one, and create a new one if none is found.
264 263
265 264 Notes
266 265 -----
267 266 If you have already created a :class:`Tkinter.Tk` object, the only
268 267 thing done by this method is to register with the
269 268 :class:`InputHookManager`, since creating that object automatically
270 269 sets ``PyOS_InputHook``.
271 270 """
272 271 self._current_gui = GUI_TK
273 272 if app is None:
274 273 import Tkinter
275 274 app = Tkinter.Tk()
276 275 app.withdraw()
277 276 self._apps[GUI_TK] = app
278 277 return app
279 278
280 279 def disable_tk(self):
281 280 """Disable event loop integration with Tkinter.
282 281
283 282 This merely sets PyOS_InputHook to NULL.
284 283 """
285 284 self.clear_inputhook()
286 285
287 286 def current_gui(self):
288 287 """Return a string indicating the currently active GUI or None."""
289 288 return self._current_gui
290 289
291 290 inputhook_manager = InputHookManager()
292 291
293 292 enable_wx = inputhook_manager.enable_wx
294 293 disable_wx = inputhook_manager.disable_wx
295 294 enable_qt4 = inputhook_manager.enable_qt4
296 295 disable_qt4 = inputhook_manager.disable_qt4
297 296 enable_gtk = inputhook_manager.enable_gtk
298 297 disable_gtk = inputhook_manager.disable_gtk
299 298 enable_tk = inputhook_manager.enable_tk
300 299 disable_tk = inputhook_manager.disable_tk
301 300 clear_inputhook = inputhook_manager.clear_inputhook
302 301 set_inputhook = inputhook_manager.set_inputhook
303 302 current_gui = inputhook_manager.current_gui
304 303 clear_app_refs = inputhook_manager.clear_app_refs
305 304
306 305
307 306 # Convenience function to switch amongst them
308 307 def enable_gui(gui=None, app=None):
309 308 """Switch amongst GUI input hooks by name.
310 309
311 310 This is just a utility wrapper around the methods of the InputHookManager
312 311 object.
313 312
314 313 Parameters
315 314 ----------
316 315 gui : optional, string or None
317 316 If None, clears input hook, otherwise it must be one of the recognized
318 317 GUI names (see ``GUI_*`` constants in module).
319 318
320 319 app : optional, existing application object.
321 320 For toolkits that have the concept of a global app, you can supply an
322 321 existing one. If not given, the toolkit will be probed for one, and if
323 322 none is found, a new one will be created. Note that GTK does not have
324 323 this concept, and passing an app if `gui`=="GTK" will raise an error.
325 324
326 325 Returns
327 326 -------
328 327 The output of the underlying gui switch routine, typically the actual
329 328 PyOS_InputHook wrapper object or the GUI toolkit app created, if there was
330 329 one.
331 330 """
332 331 guis = {None: clear_inputhook,
333 332 GUI_OSX: lambda app=False: None,
334 333 GUI_TK: enable_tk,
335 334 GUI_GTK: enable_gtk,
336 335 GUI_WX: enable_wx,
337 336 GUI_QT: enable_qt4, # qt3 not supported
338 337 GUI_QT4: enable_qt4 }
339 338 try:
340 339 gui_hook = guis[gui]
341 340 except KeyError:
342 341 e = "Invalid GUI request %r, valid ones are:%s" % (gui, guis.keys())
343 342 raise ValueError(e)
344 343 return gui_hook(app)
345 344
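A minimal sketch of this convenience layer from user code; ``enable_gui``
simply dispatches to the matching ``enable_*`` method of the singleton
``InputHookManager`` (the ``IPython.lib.inputhook`` import path is assumed)::

    from IPython.lib import inputhook

    app = inputhook.enable_gui('qt4')   # install PyOS_InputHook, return the QApplication
    print inputhook.current_gui()       # prints 'qt4'

    inputhook.enable_gui(None)          # clear the hook again via clear_inputhook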
@@ -1,36 +1,35 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Enable pygtk to be used interactively by setting PyOS_InputHook.
5 4
6 5 Authors: Brian Granger
7 6 """
8 7
9 8 #-----------------------------------------------------------------------------
10 9 # Copyright (C) 2008-2009 The IPython Development Team
11 10 #
12 11 # Distributed under the terms of the BSD License. The full license is in
13 12 # the file COPYING, distributed as part of this software.
14 13 #-----------------------------------------------------------------------------
15 14
16 15 #-----------------------------------------------------------------------------
17 16 # Imports
18 17 #-----------------------------------------------------------------------------
19 18
20 19 import sys
21 20 import gtk, gobject
22 21
23 22 #-----------------------------------------------------------------------------
24 23 # Code
25 24 #-----------------------------------------------------------------------------
26 25
27 26
28 27 def _main_quit(*args, **kwargs):
29 28 gtk.main_quit()
30 29 return False
31 30
32 31 def inputhook_gtk():
33 32 gobject.io_add_watch(sys.stdin, gobject.IO_IN, _main_quit)
34 33 gtk.main()
35 34 return 0
36 35
@@ -1,179 +1,178 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2
4 3 """
5 4 Enable wxPython to be used interactively by setting PyOS_InputHook.
6 5
7 6 Authors: Robin Dunn, Brian Granger, Ondrej Certik
8 7 """
9 8
10 9 #-----------------------------------------------------------------------------
11 10 # Copyright (C) 2008-2009 The IPython Development Team
12 11 #
13 12 # Distributed under the terms of the BSD License. The full license is in
14 13 # the file COPYING, distributed as part of this software.
15 14 #-----------------------------------------------------------------------------
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Imports
19 18 #-----------------------------------------------------------------------------
20 19
21 20 import os
22 21 import signal
23 22 import sys
24 23 import time
25 24 from timeit import default_timer as clock
26 25 import wx
27 26
28 27 if os.name == 'posix':
29 28 import select
30 29 elif sys.platform == 'win32':
31 30 import msvcrt
32 31
33 32 #-----------------------------------------------------------------------------
34 33 # Code
35 34 #-----------------------------------------------------------------------------
36 35
37 36 def stdin_ready():
38 37 if os.name == 'posix':
39 38 infds, outfds, erfds = select.select([sys.stdin],[],[],0)
40 39 if infds:
41 40 return True
42 41 else:
43 42 return False
44 43 elif sys.platform == 'win32':
45 44 return msvcrt.kbhit()
46 45
47 46
48 47 def inputhook_wx1():
49 48 """Run the wx event loop by processing pending events only.
50 49
51 50 This approach seems to work, but its performance is not great as it
52 51 relies on having PyOS_InputHook called regularly.
53 52 """
54 53 try:
55 54 app = wx.GetApp()
56 55 if app is not None:
57 56 assert wx.Thread_IsMain()
58 57
59 58 # Make a temporary event loop and process system events until
60 59 # there are no more waiting, then allow idle events (which
61 60 # will also deal with pending or posted wx events.)
62 61 evtloop = wx.EventLoop()
63 62 ea = wx.EventLoopActivator(evtloop)
64 63 while evtloop.Pending():
65 64 evtloop.Dispatch()
66 65 app.ProcessIdle()
67 66 del ea
68 67 except KeyboardInterrupt:
69 68 pass
70 69 return 0
71 70
72 71 class EventLoopTimer(wx.Timer):
73 72
74 73 def __init__(self, func):
75 74 self.func = func
76 75 wx.Timer.__init__(self)
77 76
78 77 def Notify(self):
79 78 self.func()
80 79
81 80 class EventLoopRunner(object):
82 81
83 82 def Run(self, time):
84 83 self.evtloop = wx.EventLoop()
85 84 self.timer = EventLoopTimer(self.check_stdin)
86 85 self.timer.Start(time)
87 86 self.evtloop.Run()
88 87
89 88 def check_stdin(self):
90 89 if stdin_ready():
91 90 self.timer.Stop()
92 91 self.evtloop.Exit()
93 92
94 93 def inputhook_wx2():
95 94 """Run the wx event loop, polling for stdin.
96 95
97 96 This version runs the wx eventloop for an undetermined amount of time,
98 97 during which it periodically checks to see if anything is ready on
99 98 stdin. If anything is ready on stdin, the event loop exits.
100 99
101 100 The argument to elr.Run controls how often the event loop looks at stdin.
102 101 This determines the responsiveness at the keyboard. A setting of 1000
103 102 enables a user to type at most 1 char per second. I have found that a
104 103 setting of 10 gives good keyboard response. We can shorten it further,
105 104 but eventually performance would suffer from calling select/kbhit too
106 105 often.
107 106 """
108 107 try:
109 108 app = wx.GetApp()
110 109 if app is not None:
111 110 assert wx.Thread_IsMain()
112 111 elr = EventLoopRunner()
113 112 # As this time is made shorter, keyboard response improves, but idle
114 113 # CPU load goes up. 10 ms seems like a good compromise.
115 114 elr.Run(time=10) # CHANGE time here to control polling interval
116 115 except KeyboardInterrupt:
117 116 pass
118 117 return 0
119 118
120 119 def inputhook_wx3():
121 120 """Run the wx event loop by processing pending events only.
122 121
123 122 This is like inputhook_wx1, but it keeps processing pending events
124 123 until stdin is ready. After processing all pending events, a call to
125 124 time.sleep is inserted. This is needed, otherwise, CPU usage is at 100%.
126 125 This sleep time should be tuned though for best performance.
127 126 """
128 127 # We need to protect against a user pressing Control-C when IPython is
129 128 # idle and this is running. We trap KeyboardInterrupt and pass.
130 129 try:
131 130 app = wx.GetApp()
132 131 if app is not None:
133 132 assert wx.Thread_IsMain()
134 133
135 134 # The import of wx on Linux sets the handler for signal.SIGINT
136 135 # to 0. This is a bug in wx or gtk. We fix by just setting it
137 136 # back to the Python default.
138 137 if not callable(signal.getsignal(signal.SIGINT)):
139 138 signal.signal(signal.SIGINT, signal.default_int_handler)
140 139
141 140 evtloop = wx.EventLoop()
142 141 ea = wx.EventLoopActivator(evtloop)
143 142 t = clock()
144 143 while not stdin_ready():
145 144 while evtloop.Pending():
146 145 t = clock()
147 146 evtloop.Dispatch()
148 147 app.ProcessIdle()
149 148 # We need to sleep at this point to keep the idle CPU load
150 149 # low. However, if we sleep too long, GUI response is poor. As
151 150 # a compromise, we watch how often GUI events are being processed
152 151 # and switch between a short and long sleep time. Here are some
153 152 # stats useful in helping to tune this.
154 153 # time CPU load
155 154 # 0.001 13%
156 155 # 0.005 3%
157 156 # 0.01 1.5%
158 157 # 0.05 0.5%
159 158 used_time = clock() - t
160 159 if used_time > 5*60.0:
161 160 # print 'Sleep for 5 s' # dbg
162 161 time.sleep(5.0)
163 162 elif used_time > 10.0:
164 163 # print 'Sleep for 1 s' # dbg
165 164 time.sleep(1.0)
166 165 elif used_time > 0.1:
167 166 # Few GUI events coming in, so we can sleep longer
168 167 # print 'Sleep for 0.05 s' # dbg
169 168 time.sleep(0.05)
170 169 else:
171 170 # Many GUI events coming in, so sleep only very little
172 171 time.sleep(0.001)
173 172 del ea
174 173 except KeyboardInterrupt:
175 174 pass
176 175 return 0
177 176
178 177 # This is our default implementation
179 178 inputhook_wx = inputhook_wx3
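Which of the three hooks is installed by default is decided by the
``inputhook_wx = inputhook_wx3`` alias above; a rough sketch of swapping in
the polling variant through the manager from the previous file (the import
paths are the ones used inside IPython and are an assumption here)::

    from IPython.lib.inputhook import inputhook_manager
    from IPython.lib.inputhookwx import inputhook_wx2

    # Install the polling variant (runs the wx loop in 10 ms slices) instead of inputhook_wx3.
    previous_hook = inputhook_manager.set_inputhook(inputhook_wx2)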
@@ -1,14 +1,13 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2
4 3 def test_import_backgroundjobs():
5 4 from IPython.lib import backgroundjobs
6 5
7 6 def test_import_deepreload():
8 7 from IPython.lib import deepreload
9 8
10 9 def test_import_demo():
11 10 from IPython.lib import demo
12 11
13 12 def test_import_irunner():
14 13 from IPython.lib import irunner
@@ -1,241 +1,240 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 The Base Application class for IPython.parallel apps
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Min RK
10 9
11 10 """
12 11
13 12 #-----------------------------------------------------------------------------
14 13 # Copyright (C) 2008-2011 The IPython Development Team
15 14 #
16 15 # Distributed under the terms of the BSD License. The full license is in
17 16 # the file COPYING, distributed as part of this software.
18 17 #-----------------------------------------------------------------------------
19 18
20 19 #-----------------------------------------------------------------------------
21 20 # Imports
22 21 #-----------------------------------------------------------------------------
23 22
24 23 from __future__ import with_statement
25 24
26 25 import os
27 26 import logging
28 27 import re
29 28 import sys
30 29
31 30 from subprocess import Popen, PIPE
32 31
33 32 from IPython.core import release
34 33 from IPython.core.crashhandler import CrashHandler
35 34 from IPython.core.application import (
36 35 BaseIPythonApplication,
37 36 base_aliases as base_ip_aliases,
38 37 base_flags as base_ip_flags
39 38 )
40 39 from IPython.utils.path import expand_path
41 40
42 41 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
43 42
44 43 #-----------------------------------------------------------------------------
45 44 # Module errors
46 45 #-----------------------------------------------------------------------------
47 46
48 47 class PIDFileError(Exception):
49 48 pass
50 49
51 50
52 51 #-----------------------------------------------------------------------------
53 52 # Crash handler for this application
54 53 #-----------------------------------------------------------------------------
55 54
56 55 class ParallelCrashHandler(CrashHandler):
57 56 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
58 57
59 58 def __init__(self, app):
60 59 contact_name = release.authors['Min'][0]
61 60 contact_email = release.authors['Min'][1]
62 61 bug_tracker = 'http://github.com/ipython/ipython/issues'
63 62 super(ParallelCrashHandler,self).__init__(
64 63 app, contact_name, contact_email, bug_tracker
65 64 )
66 65
67 66
68 67 #-----------------------------------------------------------------------------
69 68 # Main application
70 69 #-----------------------------------------------------------------------------
71 70 base_aliases = {}
72 71 base_aliases.update(base_ip_aliases)
73 72 base_aliases.update({
74 73 'profile-dir' : 'ProfileDir.location',
75 74 'work-dir' : 'BaseParallelApplication.work_dir',
76 75 'log-to-file' : 'BaseParallelApplication.log_to_file',
77 76 'clean-logs' : 'BaseParallelApplication.clean_logs',
78 77 'log-url' : 'BaseParallelApplication.log_url',
79 78 })
80 79
81 80 base_flags = {
82 81 'log-to-file' : (
83 82 {'BaseParallelApplication' : {'log_to_file' : True}},
84 83 "send log output to a file"
85 84 )
86 85 }
87 86 base_flags.update(base_ip_flags)
88 87
89 88 class BaseParallelApplication(BaseIPythonApplication):
90 89 """The base Application for IPython.parallel apps
91 90
92 91 Principal extensions to BaseIPythonApplication:
93 92
94 93 * work_dir
95 94 * remote logging via pyzmq
96 95 * IOLoop instance
97 96 """
98 97
99 98 crash_handler_class = ParallelCrashHandler
100 99
101 100 def _log_level_default(self):
102 101 # temporarily override default_log_level to INFO
103 102 return logging.INFO
104 103
105 104 work_dir = Unicode(os.getcwdu(), config=True,
106 105 help='Set the working dir for the process.'
107 106 )
108 107 def _work_dir_changed(self, name, old, new):
109 108 self.work_dir = unicode(expand_path(new))
110 109
111 110 log_to_file = Bool(config=True,
112 111 help="whether to log to a file")
113 112
114 113 clean_logs = Bool(False, config=True,
115 114 help="whether to cleanup old logfiles before starting")
116 115
117 116 log_url = Unicode('', config=True,
118 117 help="The ZMQ URL of the iplogger to aggregate logging.")
119 118
120 119 def _config_files_default(self):
121 120 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
122 121
123 122 loop = Instance('zmq.eventloop.ioloop.IOLoop')
124 123 def _loop_default(self):
125 124 from zmq.eventloop.ioloop import IOLoop
126 125 return IOLoop.instance()
127 126
128 127 aliases = Dict(base_aliases)
129 128 flags = Dict(base_flags)
130 129
131 130 def initialize(self, argv=None):
132 131 """initialize the app"""
133 132 super(BaseParallelApplication, self).initialize(argv)
134 133 self.to_work_dir()
135 134 self.reinit_logging()
136 135
137 136 def to_work_dir(self):
138 137 wd = self.work_dir
139 138 if unicode(wd) != os.getcwdu():
140 139 os.chdir(wd)
141 140 self.log.info("Changing to working dir: %s" % wd)
142 141 # This is the working dir by now.
143 142 sys.path.insert(0, '')
144 143
145 144 def reinit_logging(self):
146 145 # Remove old log files
147 146 log_dir = self.profile_dir.log_dir
148 147 if self.clean_logs:
149 148 for f in os.listdir(log_dir):
150 149 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
151 150 os.remove(os.path.join(log_dir, f))
152 151 if self.log_to_file:
153 152 # Start logging to the new log file
154 153 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
155 154 logfile = os.path.join(log_dir, log_filename)
156 155 open_log_file = open(logfile, 'w')
157 156 else:
158 157 open_log_file = None
159 158 if open_log_file is not None:
160 159 self.log.removeHandler(self._log_handler)
161 160 self._log_handler = logging.StreamHandler(open_log_file)
162 161 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
163 162 self._log_handler.setFormatter(self._log_formatter)
164 163 self.log.addHandler(self._log_handler)
165 164
166 165 def write_pid_file(self, overwrite=False):
167 166 """Create a .pid file in the pid_dir with my pid.
168 167
169 168 This must be called after pre_construct, which sets `self.pid_dir`.
170 169 This raises :exc:`PIDFileError` if the pid file exists already.
171 170 """
172 171 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
173 172 if os.path.isfile(pid_file):
174 173 pid = self.get_pid_from_file()
175 174 if not overwrite:
176 175 raise PIDFileError(
177 176 'The pid file [%s] already exists. \nThis could mean that this '
178 177 'server is already running with [pid=%s].' % (pid_file, pid)
179 178 )
180 179 with open(pid_file, 'w') as f:
181 180 self.log.info("Creating pid file: %s" % pid_file)
182 181 f.write(repr(os.getpid())+'\n')
183 182
184 183 def remove_pid_file(self):
185 184 """Remove the pid file.
186 185
187 186 This should be called at shutdown by registering a callback with
188 187 :func:`reactor.addSystemEventTrigger`. This needs to return
189 188 ``None``.
190 189 """
191 190 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
192 191 if os.path.isfile(pid_file):
193 192 try:
194 193 self.log.info("Removing pid file: %s" % pid_file)
195 194 os.remove(pid_file)
196 195 except:
197 196 self.log.warn("Error removing the pid file: %s" % pid_file)
198 197
199 198 def get_pid_from_file(self):
200 199 """Get the pid from the pid file.
201 200
202 201 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
203 202 """
204 203 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
205 204 if os.path.isfile(pid_file):
206 205 with open(pid_file, 'r') as f:
207 206 s = f.read().strip()
208 207 try:
209 208 pid = int(s)
210 209 except:
211 210 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
212 211 return pid
213 212 else:
214 213 raise PIDFileError('pid file not found: %s' % pid_file)
215 214
216 215 def check_pid(self, pid):
217 216 if os.name == 'nt':
218 217 try:
219 218 import ctypes
220 219 # returns 0 if no such process (of ours) exists
221 220 # positive int otherwise
222 221 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
223 222 except Exception:
224 223 self.log.warn(
225 224 "Could not determine whether pid %i is running via `OpenProcess`. "
226 225 " Making the likely assumption that it is."%pid
227 226 )
228 227 return True
229 228 return bool(p)
230 229 else:
231 230 try:
232 231 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
233 232 output,_ = p.communicate()
234 233 except OSError:
235 234 self.log.warn(
236 235 "Could not determine whether pid %i is running via `ps x`. "
237 236 " Making the likely assumption that it is."%pid
238 237 )
239 238 return True
240 239 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
241 240 return pid in pids
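The ``base_aliases`` and ``base_flags`` tables above map directly onto
command-line options for every app built on this class; a rough sketch of a
minimal subclass exercising them (the module path, the ``MyParallelApp`` name
and the work dir are assumptions)::

    # Hypothetical subclass; the real apps (ipcontroller, ipengine, ipcluster) extend this class.
    from IPython.parallel.apps.baseapp import BaseParallelApplication

    class MyParallelApp(BaseParallelApplication):
        name = u'myapp'   # used for the log and pid file names (myapp-<pid>.log, myapp.pid)

    if __name__ == '__main__':
        app = MyParallelApp.instance()
        # Equivalent to running: myapp --log-to-file --work-dir=/tmp/work
        app.initialize(['--log-to-file', '--work-dir=/tmp/work'])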
@@ -1,1142 +1,1141 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Facilities for launching IPython processes asynchronously.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * MinRK
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2008-2011 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 import copy
24 23 import logging
25 24 import os
26 25 import re
27 26 import stat
28 27
29 28 # signal imports, handling various platforms, versions
30 29
31 30 from signal import SIGINT, SIGTERM
32 31 try:
33 32 from signal import SIGKILL
34 33 except ImportError:
35 34 # Windows
36 35 SIGKILL=SIGTERM
37 36
38 37 try:
39 38 # Windows >= 2.7, 3.2
40 39 from signal import CTRL_C_EVENT as SIGINT
41 40 except ImportError:
42 41 pass
43 42
44 43 from subprocess import Popen, PIPE, STDOUT
45 44 try:
46 45 from subprocess import check_output
47 46 except ImportError:
48 47 # pre-2.7, define check_output with Popen
49 48 def check_output(*args, **kwargs):
50 49 kwargs.update(dict(stdout=PIPE))
51 50 p = Popen(*args, **kwargs)
52 51 out,err = p.communicate()
53 52 return out
54 53
55 54 from zmq.eventloop import ioloop
56 55
57 56 from IPython.config.application import Application
58 57 from IPython.config.configurable import LoggingConfigurable
59 58 from IPython.utils.text import EvalFormatter
60 59 from IPython.utils.traitlets import Any, Int, List, Unicode, Dict, Instance
61 60 from IPython.utils.path import get_ipython_module_path
62 61 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
63 62
64 63 from .win32support import forward_read_events
65 64
66 65 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
67 66
68 67 WINDOWS = os.name == 'nt'
69 68
70 69 #-----------------------------------------------------------------------------
71 70 # Paths to the kernel apps
72 71 #-----------------------------------------------------------------------------
73 72
74 73
75 74 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
76 75 'IPython.parallel.apps.ipclusterapp'
77 76 ))
78 77
79 78 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
80 79 'IPython.parallel.apps.ipengineapp'
81 80 ))
82 81
83 82 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
84 83 'IPython.parallel.apps.ipcontrollerapp'
85 84 ))
86 85
87 86 #-----------------------------------------------------------------------------
88 87 # Base launchers and errors
89 88 #-----------------------------------------------------------------------------
90 89
91 90
92 91 class LauncherError(Exception):
93 92 pass
94 93
95 94
96 95 class ProcessStateError(LauncherError):
97 96 pass
98 97
99 98
100 99 class UnknownStatus(LauncherError):
101 100 pass
102 101
103 102
104 103 class BaseLauncher(LoggingConfigurable):
105 104 """An asbtraction for starting, stopping and signaling a process."""
106 105
107 106 # In all of the launchers, the work_dir is where child processes will be
108 107 # run. This will usually be the profile_dir, but may not be. Any work_dir
109 108 # passed into the __init__ method will override the config value.
110 109 # This should not be used to set the work_dir for the actual engine
111 110 # and controller. Instead, use their own config files or the
112 111 # controller_args, engine_args attributes of the launchers to add
113 112 # the work_dir option.
114 113 work_dir = Unicode(u'.')
115 114 loop = Instance('zmq.eventloop.ioloop.IOLoop')
116 115
117 116 start_data = Any()
118 117 stop_data = Any()
119 118
120 119 def _loop_default(self):
121 120 return ioloop.IOLoop.instance()
122 121
123 122 def __init__(self, work_dir=u'.', config=None, **kwargs):
124 123 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
125 124 self.state = 'before' # can be before, running, after
126 125 self.stop_callbacks = []
127 126 self.start_data = None
128 127 self.stop_data = None
129 128
130 129 @property
131 130 def args(self):
132 131 """A list of cmd and args that will be used to start the process.
133 132
134 133 This is what is passed to :func:`spawnProcess` and the first element
135 134 will be the process name.
136 135 """
137 136 return self.find_args()
138 137
139 138 def find_args(self):
140 139 """The ``.args`` property calls this to find the args list.
141 140
142 141 Subclasses should implement this to construct the cmd and args.
143 142 """
144 143 raise NotImplementedError('find_args must be implemented in a subclass')
145 144
146 145 @property
147 146 def arg_str(self):
148 147 """The string form of the program arguments."""
149 148 return ' '.join(self.args)
150 149
151 150 @property
152 151 def running(self):
153 152 """Am I running."""
154 153 if self.state == 'running':
155 154 return True
156 155 else:
157 156 return False
158 157
159 158 def start(self):
160 159 """Start the process."""
161 160 raise NotImplementedError('start must be implemented in a subclass')
162 161
163 162 def stop(self):
164 163 """Stop the process and notify observers of stopping.
165 164
166 165 This method will return None immediately.
167 166 To observe the actual process stopping, see :meth:`on_stop`.
168 167 """
169 168 raise NotImplementedError('stop must be implemented in a subclass')
170 169
171 170 def on_stop(self, f):
172 171 """Register a callback to be called with this Launcher's stop_data
173 172 when the process actually finishes.
174 173 """
175 174 if self.state=='after':
176 175 return f(self.stop_data)
177 176 else:
178 177 self.stop_callbacks.append(f)
179 178
180 179 def notify_start(self, data):
181 180 """Call this to trigger startup actions.
182 181
183 182 This logs the process startup and sets the state to 'running'. It is
184 183 a pass-through so it can be used as a callback.
185 184 """
186 185
187 186 self.log.info('Process %r started: %r' % (self.args[0], data))
188 187 self.start_data = data
189 188 self.state = 'running'
190 189 return data
191 190
192 191 def notify_stop(self, data):
193 192 """Call this to trigger process stop actions.
194 193
195 194 This logs the process stopping and sets the state to 'after'. Call
196 195 this to trigger callbacks registered via :meth:`on_stop`."""
197 196
198 197 self.log.info('Process %r stopped: %r' % (self.args[0], data))
199 198 self.stop_data = data
200 199 self.state = 'after'
201 200 for i in range(len(self.stop_callbacks)):
202 201 d = self.stop_callbacks.pop()
203 202 d(data)
204 203 return data
205 204
206 205 def signal(self, sig):
207 206 """Signal the process.
208 207
209 208 Parameters
210 209 ----------
211 210 sig : str or int
212 211 'KILL', 'INT', etc., or any signal number
213 212 """
214 213 raise NotImplementedError('signal must be implemented in a subclass')
215 214
216 215
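To make the lifecycle described in the docstrings above concrete (state moves 'before' -> 'running' -> 'after', and notify_stop drains the on_stop callbacks), here is a minimal, hypothetical subclass. It is a sketch of the contract only, not part of this changeset, and assumes the launcher module is importable.

# Hypothetical sketch of the BaseLauncher contract; not part of this changeset.
class EchoLauncher(BaseLauncher):
    def find_args(self):
        return ['echo', 'hello']                    # becomes self.args / self.arg_str

    def start(self):
        # A real launcher would spawn self.args asynchronously here.
        return self.notify_start(dict(pid=12345))   # state -> 'running'

    def stop(self):
        # ...and terminate the process before reporting back.
        return self.notify_stop(dict(exit_code=0))  # state -> 'after', fires on_stop callbacks

def report(data):
    print "stopped:", data

launcher = EchoLauncher()
launcher.on_stop(report)    # queued until notify_stop runs
launcher.start()
launcher.stop()             # logs the stop, then calls report({'exit_code': 0})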
217 216 #-----------------------------------------------------------------------------
218 217 # Local process launchers
219 218 #-----------------------------------------------------------------------------
220 219
221 220
222 221 class LocalProcessLauncher(BaseLauncher):
223 222 """Start and stop an external process in an asynchronous manner.
224 223
225 224 This will launch the external process with a working directory of
226 225 ``self.work_dir``.
227 226 """
228 227
229 228 # This is used to construct self.args, which is passed to
230 229 # spawnProcess.
231 230 cmd_and_args = List([])
232 231 poll_frequency = Int(100) # in ms
233 232
234 233 def __init__(self, work_dir=u'.', config=None, **kwargs):
235 234 super(LocalProcessLauncher, self).__init__(
236 235 work_dir=work_dir, config=config, **kwargs
237 236 )
238 237 self.process = None
239 238 self.poller = None
240 239
241 240 def find_args(self):
242 241 return self.cmd_and_args
243 242
244 243 def start(self):
245 244 if self.state == 'before':
246 245 self.process = Popen(self.args,
247 246 stdout=PIPE,stderr=PIPE,stdin=PIPE,
248 247 env=os.environ,
249 248 cwd=self.work_dir
250 249 )
251 250 if WINDOWS:
252 251 self.stdout = forward_read_events(self.process.stdout)
253 252 self.stderr = forward_read_events(self.process.stderr)
254 253 else:
255 254 self.stdout = self.process.stdout.fileno()
256 255 self.stderr = self.process.stderr.fileno()
257 256 self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
258 257 self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
259 258 self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
260 259 self.poller.start()
261 260 self.notify_start(self.process.pid)
262 261 else:
263 262 s = 'The process was already started and has state: %r' % self.state
264 263 raise ProcessStateError(s)
265 264
266 265 def stop(self):
267 266 return self.interrupt_then_kill()
268 267
269 268 def signal(self, sig):
270 269 if self.state == 'running':
271 270 if WINDOWS and sig != SIGINT:
272 271 # use Windows tree-kill for better child cleanup
273 272 check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
274 273 else:
275 274 self.process.send_signal(sig)
276 275
277 276 def interrupt_then_kill(self, delay=2.0):
278 277 """Send INT, wait a delay and then send KILL."""
279 278 try:
280 279 self.signal(SIGINT)
281 280 except Exception:
282 281 self.log.debug("interrupt failed")
283 282 pass
284 283 self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
285 284 self.killer.start()
286 285
287 286 # callbacks, etc:
288 287
289 288 def handle_stdout(self, fd, events):
290 289 if WINDOWS:
291 290 line = self.stdout.recv()
292 291 else:
293 292 line = self.process.stdout.readline()
294 293 # a stopped process will be readable but return empty strings
295 294 if line:
296 295 self.log.info(line[:-1])
297 296 else:
298 297 self.poll()
299 298
300 299 def handle_stderr(self, fd, events):
301 300 if WINDOWS:
302 301 line = self.stderr.recv()
303 302 else:
304 303 line = self.process.stderr.readline()
305 304 # a stopped process will be readable but return empty strings
306 305 if line:
307 306 self.log.error(line[:-1])
308 307 else:
309 308 self.poll()
310 309
311 310 def poll(self):
312 311 status = self.process.poll()
313 312 if status is not None:
314 313 self.poller.stop()
315 314 self.loop.remove_handler(self.stdout)
316 315 self.loop.remove_handler(self.stderr)
317 316 self.notify_stop(dict(exit_code=status, pid=self.process.pid))
318 317 return status
319 318
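A brief usage sketch for LocalProcessLauncher above, assuming a zmq IOLoop is available; the command shown is arbitrary and the callback name is made up.

# Illustrative only: run an arbitrary command under LocalProcessLauncher.
from zmq.eventloop import ioloop

lp = LocalProcessLauncher(work_dir=u'.')
lp.cmd_and_args = ['echo', 'hello from launcher']

def on_exit(data):
    print "exited:", data               # e.g. {'exit_code': 0, 'pid': ...}
    ioloop.IOLoop.instance().stop()

lp.on_stop(on_exit)
lp.start()                              # spawns the process, wires stdout/stderr into the loop
ioloop.IOLoop.instance().start()        # runs until poll() notices the exit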
320 319 class LocalControllerLauncher(LocalProcessLauncher):
321 320 """Launch a controller as a regular external process."""
322 321
323 322 controller_cmd = List(ipcontroller_cmd_argv, config=True,
324 323 help="""Popen command to launch ipcontroller.""")
325 324 # Command line arguments to ipcontroller.
326 325 controller_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
327 326 help="""command-line args to pass to ipcontroller""")
328 327
329 328 def find_args(self):
330 329 return self.controller_cmd + self.controller_args
331 330
332 331 def start(self, profile_dir):
333 332 """Start the controller by profile_dir."""
334 333 self.controller_args.extend(['--profile-dir=%s'%profile_dir])
335 334 self.profile_dir = unicode(profile_dir)
336 335 self.log.info("Starting LocalControllerLauncher: %r" % self.args)
337 336 return super(LocalControllerLauncher, self).start()
338 337
339 338
340 339 class LocalEngineLauncher(LocalProcessLauncher):
341 340 """Launch a single engine as a regular externall process."""
342 341
343 342 engine_cmd = List(ipengine_cmd_argv, config=True,
344 343 help="""command to launch the Engine.""")
345 344 # Command line arguments for ipengine.
346 345 engine_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
347 346 help="command-line arguments to pass to ipengine"
348 347 )
349 348
350 349 def find_args(self):
351 350 return self.engine_cmd + self.engine_args
352 351
353 352 def start(self, profile_dir):
354 353 """Start the engine by profile_dir."""
355 354 self.engine_args.extend(['--profile-dir=%s'%profile_dir])
356 355 self.profile_dir = unicode(profile_dir)
357 356 return super(LocalEngineLauncher, self).start()
358 357
359 358
360 359 class LocalEngineSetLauncher(BaseLauncher):
361 360 """Launch a set of engines as regular external processes."""
362 361
363 362 # Command line arguments for ipengine.
364 363 engine_args = List(
365 364 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
366 365 help="command-line arguments to pass to ipengine"
367 366 )
368 367 # launcher class
369 368 launcher_class = LocalEngineLauncher
370 369
371 370 launchers = Dict()
372 371 stop_data = Dict()
373 372
374 373 def __init__(self, work_dir=u'.', config=None, **kwargs):
375 374 super(LocalEngineSetLauncher, self).__init__(
376 375 work_dir=work_dir, config=config, **kwargs
377 376 )
378 377 self.stop_data = {}
379 378
380 379 def start(self, n, profile_dir):
381 380 """Start n engines by profile or profile_dir."""
382 381 self.profile_dir = unicode(profile_dir)
383 382 dlist = []
384 383 for i in range(n):
385 384 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
386 385 # Copy the engine args over to each engine launcher.
387 386 el.engine_args = copy.deepcopy(self.engine_args)
388 387 el.on_stop(self._notice_engine_stopped)
389 388 d = el.start(profile_dir)
390 389 if i==0:
391 390 self.log.info("Starting LocalEngineSetLauncher: %r" % el.args)
392 391 self.launchers[i] = el
393 392 dlist.append(d)
394 393 self.notify_start(dlist)
395 394 # The consumeErrors here could be dangerous
396 395 # dfinal = gatherBoth(dlist, consumeErrors=True)
397 396 # dfinal.addCallback(self.notify_start)
398 397 return dlist
399 398
400 399 def find_args(self):
401 400 return ['engine set']
402 401
403 402 def signal(self, sig):
404 403 dlist = []
405 404 for el in self.launchers.itervalues():
406 405 d = el.signal(sig)
407 406 dlist.append(d)
408 407 # dfinal = gatherBoth(dlist, consumeErrors=True)
409 408 return dlist
410 409
411 410 def interrupt_then_kill(self, delay=1.0):
412 411 dlist = []
413 412 for el in self.launchers.itervalues():
414 413 d = el.interrupt_then_kill(delay)
415 414 dlist.append(d)
416 415 # dfinal = gatherBoth(dlist, consumeErrors=True)
417 416 return dlist
418 417
419 418 def stop(self):
420 419 return self.interrupt_then_kill()
421 420
422 421 def _notice_engine_stopped(self, data):
423 422 pid = data['pid']
424 423 for idx,el in self.launchers.iteritems():
425 424 if el.process.pid == pid:
426 425 break
427 426 self.launchers.pop(idx)
428 427 self.stop_data[idx] = data
429 428 if not self.launchers:
430 429 self.notify_stop(self.stop_data)
431 430
432 431
433 432 #-----------------------------------------------------------------------------
434 433 # MPIExec launchers
435 434 #-----------------------------------------------------------------------------
436 435
437 436
438 437 class MPIExecLauncher(LocalProcessLauncher):
439 438 """Launch an external process using mpiexec."""
440 439
441 440 mpi_cmd = List(['mpiexec'], config=True,
442 441 help="The mpiexec command to use in starting the process."
443 442 )
444 443 mpi_args = List([], config=True,
445 444 help="The command line arguments to pass to mpiexec."
446 445 )
447 446 program = List(['date'], config=True,
448 447 help="The program to start via mpiexec.")
449 448 program_args = List([], config=True,
450 449 help="The command line argument to the program."
451 450 )
452 451 n = Int(1)
453 452
454 453 def find_args(self):
455 454 """Build self.args using all the fields."""
456 455 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
457 456 self.program + self.program_args
458 457
459 458 def start(self, n):
460 459 """Start n instances of the program using mpiexec."""
461 460 self.n = n
462 461 return super(MPIExecLauncher, self).start()
463 462
464 463
465 464 class MPIExecControllerLauncher(MPIExecLauncher):
466 465 """Launch a controller using mpiexec."""
467 466
468 467 controller_cmd = List(ipcontroller_cmd_argv, config=True,
469 468 help="Popen command to launch the Contropper"
470 469 )
471 470 controller_args = List(['--log-to-file','--log-level=%i'%logging.INFO], config=True,
472 471 help="Command line arguments to pass to ipcontroller."
473 472 )
474 473 n = Int(1)
475 474
476 475 def start(self, profile_dir):
477 476 """Start the controller by profile_dir."""
478 477 self.controller_args.extend(['--profile-dir=%s'%profile_dir])
479 478 self.profile_dir = unicode(profile_dir)
480 479 self.log.info("Starting MPIExecControllerLauncher: %r" % self.args)
481 480 return super(MPIExecControllerLauncher, self).start(1)
482 481
483 482 def find_args(self):
484 483 return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
485 484 self.controller_cmd + self.controller_args
486 485
487 486
488 487 class MPIExecEngineSetLauncher(MPIExecLauncher):
489 488
490 489 program = List(ipengine_cmd_argv, config=True,
491 490 help="Popen command for ipengine"
492 491 )
493 492 program_args = List(
494 493 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
495 494 help="Command line arguments for ipengine."
496 495 )
497 496 n = Int(1)
498 497
499 498 def start(self, n, profile_dir):
500 499 """Start n engines by profile or profile_dir."""
501 500 self.program_args.extend(['--profile-dir=%s'%profile_dir])
502 501 self.profile_dir = unicode(profile_dir)
503 502 self.n = n
504 503 self.log.info('Starting MPIExecEngineSetLauncher: %r' % self.args)
505 504 return super(MPIExecEngineSetLauncher, self).start(n)
506 505
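The mpiexec launchers above are usually driven from an ipcluster config file. The following is a hedged sketch that uses only the trait names defined above; the hostfile argument is invented.

# Sketch of ipcluster_config.py settings for the MPIExec launchers above.
c = get_config()                      # provided by the IPython config loader
c.MPIExecLauncher.mpi_cmd = ['mpiexec']
c.MPIExecLauncher.mpi_args = ['--hostfile', 'hosts.txt']       # hypothetical mpiexec options
c.MPIExecEngineSetLauncher.program_args = ['--log-level=20']   # forwarded to ipengine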
507 506 #-----------------------------------------------------------------------------
508 507 # SSH launchers
509 508 #-----------------------------------------------------------------------------
510 509
511 510 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
512 511
513 512 class SSHLauncher(LocalProcessLauncher):
514 513 """A minimal launcher for ssh.
515 514
516 515 To be useful this will probably have to be extended to use the ``sshx``
517 516 idea for environment variables. There could be other things this needs
518 517 as well.
519 518 """
520 519
521 520 ssh_cmd = List(['ssh'], config=True,
522 521 help="command for starting ssh")
523 522 ssh_args = List(['-tt'], config=True,
524 523 help="args to pass to ssh")
525 524 program = List(['date'], config=True,
526 525 help="Program to launch via ssh")
527 526 program_args = List([], config=True,
528 527 help="args to pass to remote program")
529 528 hostname = Unicode('', config=True,
530 529 help="hostname on which to launch the program")
531 530 user = Unicode('', config=True,
532 531 help="username for ssh")
533 532 location = Unicode('', config=True,
534 533 help="user@hostname location for ssh in one setting")
535 534
536 535 def _hostname_changed(self, name, old, new):
537 536 if self.user:
538 537 self.location = u'%s@%s' % (self.user, new)
539 538 else:
540 539 self.location = new
541 540
542 541 def _user_changed(self, name, old, new):
543 542 self.location = u'%s@%s' % (new, self.hostname)
544 543
545 544 def find_args(self):
546 545 return self.ssh_cmd + self.ssh_args + [self.location] + \
547 546 self.program + self.program_args
548 547
549 548 def start(self, profile_dir, hostname=None, user=None):
550 549 self.profile_dir = unicode(profile_dir)
551 550 if hostname is not None:
552 551 self.hostname = hostname
553 552 if user is not None:
554 553 self.user = user
555 554
556 555 return super(SSHLauncher, self).start()
557 556
558 557 def signal(self, sig):
559 558 if self.state == 'running':
560 559 # send escaped ssh connection-closer
561 560 self.process.stdin.write('~.')
562 561 self.process.stdin.flush()
563 562
564 563
565 564
566 565 class SSHControllerLauncher(SSHLauncher):
567 566
568 567 program = List(ipcontroller_cmd_argv, config=True,
569 568 help="remote ipcontroller command.")
570 569 program_args = List(['--reuse-files', '--log-to-file','--log-level=%i'%logging.INFO], config=True,
571 570 help="Command line arguments to ipcontroller.")
572 571
573 572
574 573 class SSHEngineLauncher(SSHLauncher):
575 574 program = List(ipengine_cmd_argv, config=True,
576 575 help="remote ipengine command.")
577 576 # Command line arguments for ipengine.
578 577 program_args = List(
579 578 ['--log-to-file','--log-level=%i'%logging.INFO], config=True,
580 579 help="Command line arguments to ipengine."
581 580 )
582 581
583 582 class SSHEngineSetLauncher(LocalEngineSetLauncher):
584 583 launcher_class = SSHEngineLauncher
585 584 engines = Dict(config=True,
586 585 help="""dict of engines to launch. This is a dict by hostname of ints,
587 586 corresponding to the number of engines to start on that host.""")
588 587
589 588 def start(self, n, profile_dir):
590 589 """Start engines by profile or profile_dir.
591 590 `n` is ignored, and the `engines` config property is used instead.
592 591 """
593 592
594 593 self.profile_dir = unicode(profile_dir)
595 594 dlist = []
596 595 for host, n in self.engines.iteritems():
597 596 if isinstance(n, (tuple, list)):
598 597 n, args = n
599 598 else:
600 599 args = copy.deepcopy(self.engine_args)
601 600
602 601 if '@' in host:
603 602 user,host = host.split('@',1)
604 603 else:
605 604 user=None
606 605 for i in range(n):
607 606 el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log)
608 607
609 608 # Copy the engine args over to each engine launcher.
610 609
611 610 el.program_args = args
612 611 el.on_stop(self._notice_engine_stopped)
613 612 d = el.start(profile_dir, user=user, hostname=host)
614 613 if i==0:
615 614 self.log.info("Starting SSHEngineSetLauncher: %r" % el.args)
616 615 self.launchers[host+str(i)] = el
617 616 dlist.append(d)
618 617 self.notify_start(dlist)
619 618 return dlist
620 619
621 620
622 621
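The `engines` dict read by SSHEngineSetLauncher.start above maps a hostname (optionally user@host) to either an engine count or a (count, args) pair. A configuration sketch with made-up hosts:

# Sketch: SSHEngineSetLauncher.engines, as interpreted by start() above.
c = get_config()
c.SSHEngineSetLauncher.engines = {
    'host1.example.com': 2,                               # two engines, default engine_args
    'alice@host2.example.com': (4, ['--log-level=20']),   # four engines with explicit args
}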
623 622 #-----------------------------------------------------------------------------
624 623 # Windows HPC Server 2008 scheduler launchers
625 624 #-----------------------------------------------------------------------------
626 625
627 626
628 627 # This is only used on Windows.
629 628 def find_job_cmd():
630 629 if WINDOWS:
631 630 try:
632 631 return find_cmd('job')
633 632 except (FindCmdError, ImportError):
634 633 # ImportError will be raised if win32api is not installed
635 634 return 'job'
636 635 else:
637 636 return 'job'
638 637
639 638
640 639 class WindowsHPCLauncher(BaseLauncher):
641 640
642 641 job_id_regexp = Unicode(r'\d+', config=True,
643 642 help="""A regular expression used to get the job id from the output of the
644 643 submit_command. """
645 644 )
646 645 job_file_name = Unicode(u'ipython_job.xml', config=True,
647 646 help="The filename of the instantiated job script.")
648 647 # The full path to the instantiated job script. This gets made dynamically
649 648 # by combining the work_dir with the job_file_name.
650 649 job_file = Unicode(u'')
651 650 scheduler = Unicode('', config=True,
652 651 help="The hostname of the scheduler to submit the job to.")
653 652 job_cmd = Unicode(find_job_cmd(), config=True,
654 653 help="The command for submitting jobs.")
655 654
656 655 def __init__(self, work_dir=u'.', config=None, **kwargs):
657 656 super(WindowsHPCLauncher, self).__init__(
658 657 work_dir=work_dir, config=config, **kwargs
659 658 )
660 659
661 660 @property
662 661 def job_file(self):
663 662 return os.path.join(self.work_dir, self.job_file_name)
664 663
665 664 def write_job_file(self, n):
666 665 raise NotImplementedError("Implement write_job_file in a subclass.")
667 666
668 667 def find_args(self):
669 668 return [u'job.exe']
670 669
671 670 def parse_job_id(self, output):
672 671 """Take the output of the submit command and return the job id."""
673 672 m = re.search(self.job_id_regexp, output)
674 673 if m is not None:
675 674 job_id = m.group()
676 675 else:
677 676 raise LauncherError("Job id couldn't be determined: %s" % output)
678 677 self.job_id = job_id
679 678 self.log.info('Job started with job id: %r' % job_id)
680 679 return job_id
681 680
682 681 def start(self, n):
683 682 """Start n copies of the process using the Win HPC job scheduler."""
684 683 self.write_job_file(n)
685 684 args = [
686 685 'submit',
687 686 '/jobfile:%s' % self.job_file,
688 687 '/scheduler:%s' % self.scheduler
689 688 ]
690 689 self.log.info("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
691 690
692 691 output = check_output([self.job_cmd]+args,
693 692 env=os.environ,
694 693 cwd=self.work_dir,
695 694 stderr=STDOUT
696 695 )
697 696 job_id = self.parse_job_id(output)
698 697 self.notify_start(job_id)
699 698 return job_id
700 699
701 700 def stop(self):
702 701 args = [
703 702 'cancel',
704 703 self.job_id,
705 704 '/scheduler:%s' % self.scheduler
706 705 ]
707 706 self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
708 707 try:
709 708 output = check_output([self.job_cmd]+args,
710 709 env=os.environ,
711 710 cwd=self.work_dir,
712 711 stderr=STDOUT
713 712 )
714 713 except:
715 714 output = 'The job already appears to be stopped: %r' % self.job_id
716 715 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
717 716 return output
718 717
719 718
720 719 class WindowsHPCControllerLauncher(WindowsHPCLauncher):
721 720
722 721 job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
723 722 help="WinHPC xml job file.")
724 723 extra_args = List([], config=False,
725 724 help="extra args to pass to ipcontroller")
726 725
727 726 def write_job_file(self, n):
728 727 job = IPControllerJob(config=self.config)
729 728
730 729 t = IPControllerTask(config=self.config)
731 730 # The task's work directory is *not* the actual work directory of
732 731 # the controller. It is used as the base path for the stdout/stderr
733 732 # files that the scheduler redirects to.
734 733 t.work_directory = self.profile_dir
735 734 # Add the extra args (e.g. --profile-dir) set by self.start().
736 735 t.controller_args.extend(self.extra_args)
737 736 job.add_task(t)
738 737
739 738 self.log.info("Writing job description file: %s" % self.job_file)
740 739 job.write(self.job_file)
741 740
742 741 @property
743 742 def job_file(self):
744 743 return os.path.join(self.profile_dir, self.job_file_name)
745 744
746 745 def start(self, profile_dir):
747 746 """Start the controller by profile_dir."""
748 747 self.extra_args = ['--profile-dir=%s'%profile_dir]
749 748 self.profile_dir = unicode(profile_dir)
750 749 return super(WindowsHPCControllerLauncher, self).start(1)
751 750
752 751
753 752 class WindowsHPCEngineSetLauncher(WindowsHPCLauncher):
754 753
755 754 job_file_name = Unicode(u'ipengineset_job.xml', config=True,
756 755 help="jobfile for ipengines job")
757 756 extra_args = List([], config=False,
758 757 help="extra args to pas to ipengine")
759 758
760 759 def write_job_file(self, n):
761 760 job = IPEngineSetJob(config=self.config)
762 761
763 762 for i in range(n):
764 763 t = IPEngineTask(config=self.config)
765 764 # The task's work directory is *not* the actual work directory of
766 765 # the engine. It is used as the base path for the stdout/stderr
767 766 # files that the scheduler redirects to.
768 767 t.work_directory = self.profile_dir
769 768 # Add the extra args (e.g. --profile-dir) set by self.start().
770 769 t.engine_args.extend(self.extra_args)
771 770 job.add_task(t)
772 771
773 772 self.log.info("Writing job description file: %s" % self.job_file)
774 773 job.write(self.job_file)
775 774
776 775 @property
777 776 def job_file(self):
778 777 return os.path.join(self.profile_dir, self.job_file_name)
779 778
780 779 def start(self, n, profile_dir):
781 780 """Start the controller by profile_dir."""
782 781 self.extra_args = ['--profile-dir=%s'%profile_dir]
783 782 self.profile_dir = unicode(profile_dir)
784 783 return super(WindowsHPCEngineSetLauncher, self).start(n)
785 784
786 785
787 786 #-----------------------------------------------------------------------------
788 787 # Batch (PBS) system launchers
789 788 #-----------------------------------------------------------------------------
790 789
791 790 class BatchSystemLauncher(BaseLauncher):
792 791 """Launch an external process using a batch system.
793 792
794 793 This class is designed to work with UNIX batch systems like PBS, LSF,
795 794 GridEngine, etc. The overall model is that there are different commands
796 795 like qsub, qdel, etc. that handle the starting and stopping of the process.
797 796
798 797 This class also has the notion of a batch script. The ``batch_template``
799 798 attribute can be set to a string that is a template for the batch script.
800 799 This template is instantiated using string formatting. Thus the template can
801 800 use {n} for the number of instances. Subclasses can add additional variables
802 801 to the template dict.
803 802 """
804 803
805 804 # Subclasses must fill these in. See PBSEngineSetLauncher.
806 805 submit_command = List([''], config=True,
807 806 help="The name of the command line program used to submit jobs.")
808 807 delete_command = List([''], config=True,
809 808 help="The name of the command line program used to delete jobs.")
810 809 job_id_regexp = Unicode('', config=True,
811 810 help="""A regular expression used to get the job id from the output of the
812 811 submit_command.""")
813 812 batch_template = Unicode('', config=True,
814 813 help="The string that is the batch script template itself.")
815 814 batch_template_file = Unicode(u'', config=True,
816 815 help="The file that contains the batch template.")
817 816 batch_file_name = Unicode(u'batch_script', config=True,
818 817 help="The filename of the instantiated batch script.")
819 818 queue = Unicode(u'', config=True,
820 819 help="The PBS Queue.")
821 820
822 821 # not configurable, override in subclasses
823 822 # PBS Job Array regex
824 823 job_array_regexp = Unicode('')
825 824 job_array_template = Unicode('')
826 825 # PBS Queue regex
827 826 queue_regexp = Unicode('')
828 827 queue_template = Unicode('')
829 828 # The default batch template, override in subclasses
830 829 default_template = Unicode('')
831 830 # The full path to the instantiated batch script.
832 831 batch_file = Unicode(u'')
833 832 # the format dict used with batch_template:
834 833 context = Dict()
835 834 # the Formatter instance for rendering the templates:
836 835 formatter = Instance(EvalFormatter, (), {})
837 836
838 837
839 838 def find_args(self):
840 839 return self.submit_command + [self.batch_file]
841 840
842 841 def __init__(self, work_dir=u'.', config=None, **kwargs):
843 842 super(BatchSystemLauncher, self).__init__(
844 843 work_dir=work_dir, config=config, **kwargs
845 844 )
846 845 self.batch_file = os.path.join(self.work_dir, self.batch_file_name)
847 846
848 847 def parse_job_id(self, output):
849 848 """Take the output of the submit command and return the job id."""
850 849 m = re.search(self.job_id_regexp, output)
851 850 if m is not None:
852 851 job_id = m.group()
853 852 else:
854 853 raise LauncherError("Job id couldn't be determined: %s" % output)
855 854 self.job_id = job_id
856 855 self.log.info('Job submitted with job id: %r' % job_id)
857 856 return job_id
858 857
859 858 def write_batch_script(self, n):
860 859 """Instantiate and write the batch script to the work_dir."""
861 860 self.context['n'] = n
862 861 self.context['queue'] = self.queue
863 862 # first priority is batch_template if set
864 863 if self.batch_template_file and not self.batch_template:
865 864 # second priority is batch_template_file
866 865 with open(self.batch_template_file) as f:
867 866 self.batch_template = f.read()
868 867 if not self.batch_template:
869 868 # third (last) priority is default_template
870 869 self.batch_template = self.default_template
871 870
872 871 # add job array or queue lines to the batch template if they are
873 872 # not already present, whichever template source was selected above.
874 873 regex = re.compile(self.job_array_regexp)
875 874 # print regex.search(self.batch_template)
876 875 if not regex.search(self.batch_template):
877 876 self.log.info("adding job array settings to batch script")
878 877 firstline, rest = self.batch_template.split('\n',1)
879 878 self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])
880 879
881 880 regex = re.compile(self.queue_regexp)
882 881 # print regex.search(self.batch_template)
883 882 if self.queue and not regex.search(self.batch_template):
884 883 self.log.info("adding PBS queue settings to batch script")
885 884 firstline, rest = self.batch_template.split('\n',1)
886 885 self.batch_template = u'\n'.join([firstline, self.queue_template, rest])
887 886
888 887 script_as_string = self.formatter.format(self.batch_template, **self.context)
889 888 self.log.info('Writing instantiated batch script: %s' % self.batch_file)
890 889
891 890 with open(self.batch_file, 'w') as f:
892 891 f.write(script_as_string)
893 892 os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)
894 893
895 894 def start(self, n, profile_dir):
896 895 """Start n copies of the process using a batch system."""
897 896 # Here we save profile_dir in the context so it
898 897 # can be used in the batch script template as {profile_dir}
899 898 self.context['profile_dir'] = profile_dir
900 899 self.profile_dir = unicode(profile_dir)
901 900 self.write_batch_script(n)
902 901 output = check_output(self.args, env=os.environ)
903 902
904 903 job_id = self.parse_job_id(output)
905 904 self.notify_start(job_id)
906 905 return job_id
907 906
908 907 def stop(self):
909 908 output = check_output(self.delete_command+[self.job_id], env=os.environ)
910 909 self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
911 910 return output
912 911
913 912
914 913 class PBSLauncher(BatchSystemLauncher):
915 914 """A BatchSystemLauncher subclass for PBS."""
916 915
917 916 submit_command = List(['qsub'], config=True,
918 917 help="The PBS submit command ['qsub']")
919 918 delete_command = List(['qdel'], config=True,
920 919 help="The PBS delete command ['qsub']")
921 920 job_id_regexp = Unicode(r'\d+', config=True,
922 921 help="Regular expresion for identifying the job ID [r'\d+']")
923 922
924 923 batch_file = Unicode(u'')
925 924 job_array_regexp = Unicode('#PBS\W+-t\W+[\w\d\-\$]+')
926 925 job_array_template = Unicode('#PBS -t 1-{n}')
927 926 queue_regexp = Unicode('#PBS\W+-q\W+\$?\w+')
928 927 queue_template = Unicode('#PBS -q {queue}')
929 928
930 929
931 930 class PBSControllerLauncher(PBSLauncher):
932 931 """Launch a controller using PBS."""
933 932
934 933 batch_file_name = Unicode(u'pbs_controller', config=True,
935 934 help="batch file name for the controller job.")
936 935 default_template= Unicode("""#!/bin/sh
937 936 #PBS -V
938 937 #PBS -N ipcontroller
939 938 %s --log-to-file --profile-dir={profile_dir}
940 939 """%(' '.join(ipcontroller_cmd_argv)))
941 940
942 941 def start(self, profile_dir):
943 942 """Start the controller by profile or profile_dir."""
944 943 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
945 944 return super(PBSControllerLauncher, self).start(1, profile_dir)
946 945
947 946
948 947 class PBSEngineSetLauncher(PBSLauncher):
949 948 """Launch Engines using PBS"""
950 949 batch_file_name = Unicode(u'pbs_engines', config=True,
951 950 help="batch file name for the engine(s) job.")
952 951 default_template= Unicode(u"""#!/bin/sh
953 952 #PBS -V
954 953 #PBS -N ipengine
955 954 %s --profile-dir={profile_dir}
956 955 """%(' '.join(ipengine_cmd_argv)))
957 956
958 957 def start(self, n, profile_dir):
959 958 """Start n engines by profile or profile_dir."""
960 959 self.log.info('Starting %i engines with PBSEngineSetLauncher: %r' % (n, self.args))
961 960 return super(PBSEngineSetLauncher, self).start(n, profile_dir)
962 961
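To make the templating above concrete: write_batch_script formats the chosen template with {n}, {queue} and {profile_dir}, so a custom template for the PBS engine launcher could be supplied as follows. The queue name and directive lines are placeholders, not part of this changeset.

# Sketch: user-supplied batch template for PBSEngineSetLauncher, using the
# context keys ({n}, {queue}, {profile_dir}) filled in by write_batch_script.
c = get_config()
c.PBSEngineSetLauncher.queue = u'shortq'                # hypothetical queue name
c.PBSEngineSetLauncher.batch_template = u"""#!/bin/sh
#PBS -V
#PBS -N ipengine
#PBS -q {queue}
#PBS -t 1-{n}
ipengine --profile-dir={profile_dir}
"""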
963 962 #SGE is very similar to PBS
964 963
965 964 class SGELauncher(PBSLauncher):
966 965 """Sun GridEngine is a PBS clone with slightly different syntax"""
967 966 job_array_regexp = Unicode('#\$\W+\-t')
968 967 job_array_template = Unicode('#$ -t 1-{n}')
969 968 queue_regexp = Unicode('#\$\W+-q\W+\$?\w+')
970 969 queue_template = Unicode('#$ -q {queue}')
971 970
972 971 class SGEControllerLauncher(SGELauncher):
973 972 """Launch a controller using SGE."""
974 973
975 974 batch_file_name = Unicode(u'sge_controller', config=True,
976 975 help="batch file name for the ipontroller job.")
977 976 default_template= Unicode(u"""#$ -V
978 977 #$ -S /bin/sh
979 978 #$ -N ipcontroller
980 979 %s --log-to-file --profile-dir={profile_dir}
981 980 """%(' '.join(ipcontroller_cmd_argv)))
982 981
983 982 def start(self, profile_dir):
984 983 """Start the controller by profile or profile_dir."""
985 984 self.log.info("Starting PBSControllerLauncher: %r" % self.args)
986 985 return super(SGEControllerLauncher, self).start(1, profile_dir)
987 986
988 987 class SGEEngineSetLauncher(SGELauncher):
989 988 """Launch Engines with SGE"""
990 989 batch_file_name = Unicode(u'sge_engines', config=True,
991 990 help="batch file name for the engine(s) job.")
992 991 default_template = Unicode("""#$ -V
993 992 #$ -S /bin/sh
994 993 #$ -N ipengine
995 994 %s --profile-dir={profile_dir}
996 995 """%(' '.join(ipengine_cmd_argv)))
997 996
998 997 def start(self, n, profile_dir):
999 998 """Start n engines by profile or profile_dir."""
1000 999 self.log.info('Starting %i engines with SGEEngineSetLauncher: %r' % (n, self.args))
1001 1000 return super(SGEEngineSetLauncher, self).start(n, profile_dir)
1002 1001
1003 1002
1004 1003 # LSF launchers
1005 1004
1006 1005 class LSFLauncher(BatchSystemLauncher):
1007 1006 """A BatchSystemLauncher subclass for LSF."""
1008 1007
1009 1008 submit_command = List(['bsub'], config=True,
1010 1009 help="The PBS submit command ['bsub']")
1011 1010 delete_command = List(['bkill'], config=True,
1012 1011 help="The PBS delete command ['bkill']")
1013 1012 job_id_regexp = Unicode(r'\d+', config=True,
1014 1013 help="Regular expresion for identifying the job ID [r'\d+']")
1015 1014
1016 1015 batch_file = Unicode(u'')
1017 1016 job_array_regexp = Unicode('#BSUB[ \t]-J+\w+\[\d+-\d+\]')
1018 1017 job_array_template = Unicode('#BSUB -J ipengine[1-{n}]')
1019 1018 queue_regexp = Unicode('#BSUB[ \t]+-q[ \t]+\w+')
1020 1019 queue_template = Unicode('#BSUB -q {queue}')
1021 1020
1022 1021 def start(self, n, profile_dir):
1023 1022 """Start n copies of the process using LSF batch system.
1024 1023 This can't inherit from the base class because bsub expects
1025 1024 to be piped a shell script in order to honor the #BSUB directives:
1026 1025 bsub < script
1027 1026 """
1029 1028 # Here we save profile_dir in the context so it
1029 1028 # can be used in the batch script template as {profile_dir}
1030 1029 self.context['profile_dir'] = profile_dir
1031 1030 self.profile_dir = unicode(profile_dir)
1032 1031 self.write_batch_script(n)
1033 1032 #output = check_output(self.args, env=os.environ)
1034 1033 piped_cmd = self.args[0]+'<\"'+self.args[1]+'\"'
1035 1034 p = Popen(piped_cmd, shell=True,env=os.environ,stdout=PIPE)
1036 1035 output,err = p.communicate()
1037 1036 job_id = self.parse_job_id(output)
1038 1037 self.notify_start(job_id)
1039 1038 return job_id
1040 1039
1041 1040
1042 1041 class LSFControllerLauncher(LSFLauncher):
1043 1042 """Launch a controller using LSF."""
1044 1043
1045 1044 batch_file_name = Unicode(u'lsf_controller', config=True,
1046 1045 help="batch file name for the controller job.")
1047 1046 default_template= Unicode("""#!/bin/sh
1048 1047 #BSUB -J ipcontroller
1049 1048 #BSUB -oo ipcontroller.o.%%J
1050 1049 #BSUB -eo ipcontroller.e.%%J
1051 1050 %s --log-to-file --profile-dir={profile_dir}
1052 1051 """%(' '.join(ipcontroller_cmd_argv)))
1053 1052
1054 1053 def start(self, profile_dir):
1055 1054 """Start the controller by profile or profile_dir."""
1056 1055 self.log.info("Starting LSFControllerLauncher: %r" % self.args)
1057 1056 return super(LSFControllerLauncher, self).start(1, profile_dir)
1058 1057
1059 1058
1060 1059 class LSFEngineSetLauncher(LSFLauncher):
1061 1060 """Launch Engines using LSF"""
1062 1061 batch_file_name = Unicode(u'lsf_engines', config=True,
1063 1062 help="batch file name for the engine(s) job.")
1064 1063 default_template= Unicode(u"""#!/bin/sh
1065 1064 #BSUB -oo ipengine.o.%%J
1066 1065 #BSUB -eo ipengine.e.%%J
1067 1066 %s --profile-dir={profile_dir}
1068 1067 """%(' '.join(ipengine_cmd_argv)))
1069 1068
1070 1069 def start(self, n, profile_dir):
1071 1070 """Start n engines by profile or profile_dir."""
1072 1071 self.log.info('Starting %i engines with LSFEngineSetLauncher: %r' % (n, self.args))
1073 1072 return super(LSFEngineSetLauncher, self).start(n, profile_dir)
1074 1073
1075 1074
1076 1075 #-----------------------------------------------------------------------------
1077 1076 # A launcher for ipcluster itself!
1078 1077 #-----------------------------------------------------------------------------
1079 1078
1080 1079
1081 1080 class IPClusterLauncher(LocalProcessLauncher):
1082 1081 """Launch the ipcluster program in an external process."""
1083 1082
1084 1083 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1085 1084 help="Popen command for ipcluster")
1086 1085 ipcluster_args = List(
1087 1086 ['--clean-logs', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1088 1087 help="Command line arguments to pass to ipcluster.")
1089 1088 ipcluster_subcommand = Unicode('start')
1090 1089 ipcluster_n = Int(2)
1091 1090
1092 1091 def find_args(self):
1093 1092 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1094 1093 ['--n=%i'%self.ipcluster_n] + self.ipcluster_args
1095 1094
1096 1095 def start(self):
1097 1096 self.log.info("Starting ipcluster: %r" % self.args)
1098 1097 return super(IPClusterLauncher, self).start()
1099 1098
1100 1099 #-----------------------------------------------------------------------------
1101 1100 # Collections of launchers
1102 1101 #-----------------------------------------------------------------------------
1103 1102
1104 1103 local_launchers = [
1105 1104 LocalControllerLauncher,
1106 1105 LocalEngineLauncher,
1107 1106 LocalEngineSetLauncher,
1108 1107 ]
1109 1108 mpi_launchers = [
1110 1109 MPIExecLauncher,
1111 1110 MPIExecControllerLauncher,
1112 1111 MPIExecEngineSetLauncher,
1113 1112 ]
1114 1113 ssh_launchers = [
1115 1114 SSHLauncher,
1116 1115 SSHControllerLauncher,
1117 1116 SSHEngineLauncher,
1118 1117 SSHEngineSetLauncher,
1119 1118 ]
1120 1119 winhpc_launchers = [
1121 1120 WindowsHPCLauncher,
1122 1121 WindowsHPCControllerLauncher,
1123 1122 WindowsHPCEngineSetLauncher,
1124 1123 ]
1125 1124 pbs_launchers = [
1126 1125 PBSLauncher,
1127 1126 PBSControllerLauncher,
1128 1127 PBSEngineSetLauncher,
1129 1128 ]
1130 1129 sge_launchers = [
1131 1130 SGELauncher,
1132 1131 SGEControllerLauncher,
1133 1132 SGEEngineSetLauncher,
1134 1133 ]
1135 1134 lsf_launchers = [
1136 1135 LSFLauncher,
1137 1136 LSFControllerLauncher,
1138 1137 LSFEngineSetLauncher,
1139 1138 ]
1140 1139 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1141 1140 + pbs_launchers + sge_launchers + lsf_launchers
1142 1141
@@ -1,115 +1,114 b''
1 #!/usr/bin/env python
2 1 """
3 2 A simple logger object that consolidates messages incoming from ipcluster processes.
4 3
5 4 Authors:
6 5
7 6 * MinRK
8 7
9 8 """
10 9
11 10 #-----------------------------------------------------------------------------
12 11 # Copyright (C) 2011 The IPython Development Team
13 12 #
14 13 # Distributed under the terms of the BSD License. The full license is in
15 14 # the file COPYING, distributed as part of this software.
16 15 #-----------------------------------------------------------------------------
17 16
18 17 #-----------------------------------------------------------------------------
19 18 # Imports
20 19 #-----------------------------------------------------------------------------
21 20
22 21
23 22 import logging
24 23 import sys
25 24
26 25 import zmq
27 26 from zmq.eventloop import ioloop, zmqstream
28 27
29 28 from IPython.config.configurable import LoggingConfigurable
30 29 from IPython.utils.traitlets import Int, Unicode, Instance, List
31 30
32 31 #-----------------------------------------------------------------------------
33 32 # Classes
34 33 #-----------------------------------------------------------------------------
35 34
36 35
37 36 class LogWatcher(LoggingConfigurable):
38 37 """A simple class that receives messages on a SUB socket, as published
39 38 by subclasses of `zmq.log.handlers.PUBHandler`, and logs them itself.
40 39
41 40 This can subscribe to multiple topics, but defaults to all topics.
42 41 """
43 42
44 43 # configurables
45 44 topics = List([''], config=True,
46 45 help="The ZMQ topics to subscribe to. Default is to subscribe to all messages")
47 46 url = Unicode('tcp://127.0.0.1:20202', config=True,
48 47 help="ZMQ url on which to listen for log messages")
49 48
50 49 # internals
51 50 stream = Instance('zmq.eventloop.zmqstream.ZMQStream')
52 51
53 52 context = Instance(zmq.Context)
54 53 def _context_default(self):
55 54 return zmq.Context.instance()
56 55
57 56 loop = Instance(zmq.eventloop.ioloop.IOLoop)
58 57 def _loop_default(self):
59 58 return ioloop.IOLoop.instance()
60 59
61 60 def __init__(self, **kwargs):
62 61 super(LogWatcher, self).__init__(**kwargs)
63 62 s = self.context.socket(zmq.SUB)
64 63 s.bind(self.url)
65 64 self.stream = zmqstream.ZMQStream(s, self.loop)
66 65 self.subscribe()
67 66 self.on_trait_change(self.subscribe, 'topics')
68 67
69 68 def start(self):
70 69 self.stream.on_recv(self.log_message)
71 70
72 71 def stop(self):
73 72 self.stream.stop_on_recv()
74 73
75 74 def subscribe(self):
76 75 """Update our SUB socket's subscriptions."""
77 76 self.stream.setsockopt(zmq.UNSUBSCRIBE, '')
78 77 if '' in self.topics:
79 78 self.log.debug("Subscribing to: everything")
80 79 self.stream.setsockopt(zmq.SUBSCRIBE, '')
81 80 else:
82 81 for topic in self.topics:
83 82 self.log.debug("Subscribing to: %r"%(topic))
84 83 self.stream.setsockopt(zmq.SUBSCRIBE, topic)
85 84
86 85 def _extract_level(self, topic_str):
87 86 """Turn 'engine.0.INFO.extra' into (logging.INFO, 'engine.0.extra')"""
88 87 topics = topic_str.split('.')
89 88 for idx,t in enumerate(topics):
90 89 level = getattr(logging, t, None)
91 90 if level is not None:
92 91 break
93 92
94 93 if level is None:
95 94 level = logging.INFO
96 95 else:
97 96 topics.pop(idx)
98 97
99 98 return level, '.'.join(topics)
100 99
101 100
102 101 def log_message(self, raw):
103 102 """receive and parse a message, then log it."""
104 103 if len(raw) != 2 or '.' not in raw[0]:
105 104 self.log.error("Invalid log message: %s"%raw)
106 105 return
107 106 else:
108 107 topic, msg = raw
109 108 # don't newline, since log messages always newline:
110 109 topic,level_name = topic.rsplit('.',1)
111 110 level,topic = self._extract_level(topic)
112 111 if msg[-1] == '\n':
113 112 msg = msg[:-1]
114 113 self.log.log(level, "[%s] %s" % (topic, msg))
115 114
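A short illustration of the topic handling above: PUBHandler publishes topics such as 'engine.0.INFO', and _extract_level pulls the level back out. The snippet is a sketch only; it assumes LogWatcher is importable and nothing else is bound to the default URL.

# Sketch: inspecting topics the way log_message()/_extract_level() do above.
import logging
from zmq.eventloop import ioloop

watcher = LogWatcher(url='tcp://127.0.0.1:20202')
level, topic = watcher._extract_level('engine.0.INFO.extra')
print level == logging.INFO, topic       # True engine.0.extra
watcher.start()                          # log every message arriving on the SUB socket
ioloop.IOLoop.instance().start()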
@@ -1,73 +1,72 b''
1 #!/usr/bin/env python
2 1 """Utility for forwarding file read events over a zmq socket.
3 2
4 3 This is necessary because select on Windows only supports sockets, not FDs.
5 4
6 5 Authors:
7 6
8 7 * MinRK
9 8
10 9 """
11 10
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2011 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 import uuid
24 23 import zmq
25 24
26 25 from threading import Thread
27 26
28 27 #-----------------------------------------------------------------------------
29 28 # Code
30 29 #-----------------------------------------------------------------------------
31 30
32 31 class ForwarderThread(Thread):
33 32 def __init__(self, sock, fd):
34 33 Thread.__init__(self)
35 34 self.daemon=True
36 35 self.sock = sock
37 36 self.fd = fd
38 37
39 38 def run(self):
40 39 """Loop through lines in self.fd, and send them over self.sock."""
41 40 line = self.fd.readline()
42 41 # allow for files opened in unicode mode
43 42 if isinstance(line, unicode):
44 43 send = self.sock.send_unicode
45 44 else:
46 45 send = self.sock.send
47 46 while line:
48 47 send(line)
49 48 line = self.fd.readline()
50 49 # line == '' means EOF
51 50 self.fd.close()
52 51 self.sock.close()
53 52
54 53 def forward_read_events(fd, context=None):
55 54 """Forward read events from an FD over a socket.
56 55
57 56 This function wraps a file in a socket pair, so it can
58 57 be polled for read events by select (specifically zmq.eventloop.ioloop)
59 58 """
60 59 if context is None:
61 60 context = zmq.Context.instance()
62 61 push = context.socket(zmq.PUSH)
63 62 push.setsockopt(zmq.LINGER, -1)
64 63 pull = context.socket(zmq.PULL)
65 64 addr='inproc://%s'%uuid.uuid4()
66 65 push.bind(addr)
67 66 pull.connect(addr)
68 67 forwarder = ForwarderThread(push, fd)
69 68 forwarder.start()
70 69 return pull
71 70
72 71
73 72 __all__ = ['forward_read_events']
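A usage sketch for forward_read_events above, mirroring how LocalProcessLauncher uses it on Windows; the spawned command and handler name are arbitrary.

# Sketch: polling a subprocess pipe through forward_read_events.
from subprocess import Popen, PIPE
from zmq.eventloop import ioloop

proc = Popen(['ping', 'localhost'], stdout=PIPE)
pull = forward_read_events(proc.stdout)          # PULL socket, readable per forwarded line

def handle_stdout(sock, events):
    print sock.recv()                            # one line of child output

loop = ioloop.IOLoop.instance()
loop.add_handler(pull, handle_stdout, loop.READ)
loop.start()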
@@ -1,320 +1,319 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Job and task components for writing .xml files that the Windows HPC Server
5 4 2008 can use to start jobs.
6 5
7 6 Authors:
8 7
9 8 * Brian Granger
10 9 * MinRK
11 10
12 11 """
13 12
14 13 #-----------------------------------------------------------------------------
15 14 # Copyright (C) 2008-2011 The IPython Development Team
16 15 #
17 16 # Distributed under the terms of the BSD License. The full license is in
18 17 # the file COPYING, distributed as part of this software.
19 18 #-----------------------------------------------------------------------------
20 19
21 20 #-----------------------------------------------------------------------------
22 21 # Imports
23 22 #-----------------------------------------------------------------------------
24 23
25 24 import os
26 25 import re
27 26 import uuid
28 27
29 28 from xml.etree import ElementTree as ET
30 29
31 30 from IPython.config.configurable import Configurable
32 31 from IPython.utils.traitlets import (
33 32 Unicode, Int, List, Instance,
34 33 Enum, Bool
35 34 )
36 35
37 36 #-----------------------------------------------------------------------------
38 37 # Job and Task classes
39 38 #-----------------------------------------------------------------------------
40 39
41 40
42 41 def as_str(value):
43 42 if isinstance(value, str):
44 43 return value
45 44 elif isinstance(value, bool):
46 45 if value:
47 46 return 'true'
48 47 else:
49 48 return 'false'
50 49 elif isinstance(value, (int, float)):
51 50 return repr(value)
52 51 else:
53 52 return value
54 53
55 54
56 55 def indent(elem, level=0):
57 56 i = "\n" + level*" "
58 57 if len(elem):
59 58 if not elem.text or not elem.text.strip():
60 59 elem.text = i + " "
61 60 if not elem.tail or not elem.tail.strip():
62 61 elem.tail = i
63 62 for elem in elem:
64 63 indent(elem, level+1)
65 64 if not elem.tail or not elem.tail.strip():
66 65 elem.tail = i
67 66 else:
68 67 if level and (not elem.tail or not elem.tail.strip()):
69 68 elem.tail = i
70 69
71 70
72 71 def find_username():
73 72 domain = os.environ.get('USERDOMAIN')
74 73 username = os.environ.get('USERNAME','')
75 74 if domain is None:
76 75 return username
77 76 else:
78 77 return '%s\\%s' % (domain, username)
79 78
80 79
81 80 class WinHPCJob(Configurable):
82 81
83 82 job_id = Unicode('')
84 83 job_name = Unicode('MyJob', config=True)
85 84 min_cores = Int(1, config=True)
86 85 max_cores = Int(1, config=True)
87 86 min_sockets = Int(1, config=True)
88 87 max_sockets = Int(1, config=True)
89 88 min_nodes = Int(1, config=True)
90 89 max_nodes = Int(1, config=True)
91 90 unit_type = Unicode("Core", config=True)
92 91 auto_calculate_min = Bool(True, config=True)
93 92 auto_calculate_max = Bool(True, config=True)
94 93 run_until_canceled = Bool(False, config=True)
95 94 is_exclusive = Bool(False, config=True)
96 95 username = Unicode(find_username(), config=True)
97 96 job_type = Unicode('Batch', config=True)
98 97 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
99 98 default_value='Highest', config=True)
100 99 requested_nodes = Unicode('', config=True)
101 100 project = Unicode('IPython', config=True)
102 101 xmlns = Unicode('http://schemas.microsoft.com/HPCS2008/scheduler/')
103 102 version = Unicode("2.000")
104 103 tasks = List([])
105 104
106 105 @property
107 106 def owner(self):
108 107 return self.username
109 108
110 109 def _write_attr(self, root, attr, key):
111 110 s = as_str(getattr(self, attr, ''))
112 111 if s:
113 112 root.set(key, s)
114 113
115 114 def as_element(self):
116 115 # We have to add _A_ type prefixes to get the attribute order that
117 116 # the MSFT XML parser expects.
118 117 root = ET.Element('Job')
119 118 self._write_attr(root, 'version', '_A_Version')
120 119 self._write_attr(root, 'job_name', '_B_Name')
121 120 self._write_attr(root, 'unit_type', '_C_UnitType')
122 121 self._write_attr(root, 'min_cores', '_D_MinCores')
123 122 self._write_attr(root, 'max_cores', '_E_MaxCores')
124 123 self._write_attr(root, 'min_sockets', '_F_MinSockets')
125 124 self._write_attr(root, 'max_sockets', '_G_MaxSockets')
126 125 self._write_attr(root, 'min_nodes', '_H_MinNodes')
127 126 self._write_attr(root, 'max_nodes', '_I_MaxNodes')
128 127 self._write_attr(root, 'run_until_canceled', '_J_RunUntilCanceled')
129 128 self._write_attr(root, 'is_exclusive', '_K_IsExclusive')
130 129 self._write_attr(root, 'username', '_L_UserName')
131 130 self._write_attr(root, 'job_type', '_M_JobType')
132 131 self._write_attr(root, 'priority', '_N_Priority')
133 132 self._write_attr(root, 'requested_nodes', '_O_RequestedNodes')
134 133 self._write_attr(root, 'auto_calculate_max', '_P_AutoCalculateMax')
135 134 self._write_attr(root, 'auto_calculate_min', '_Q_AutoCalculateMin')
136 135 self._write_attr(root, 'project', '_R_Project')
137 136 self._write_attr(root, 'owner', '_S_Owner')
138 137 self._write_attr(root, 'xmlns', '_T_xmlns')
139 138 dependencies = ET.SubElement(root, "Dependencies")
140 139 etasks = ET.SubElement(root, "Tasks")
141 140 for t in self.tasks:
142 141 etasks.append(t.as_element())
143 142 return root
144 143
145 144 def tostring(self):
146 145 """Return the string representation of the job description XML."""
147 146 root = self.as_element()
148 147 indent(root)
149 148 txt = ET.tostring(root, encoding="utf-8")
150 149 # Now remove the tokens used to order the attributes.
151 150 txt = re.sub(r'_[A-Z]_','',txt)
152 151 txt = '<?xml version="1.0" encoding="utf-8"?>\n' + txt
153 152 return txt
154 153
155 154 def write(self, filename):
156 155 """Write the XML job description to a file."""
157 156 txt = self.tostring()
158 157 with open(filename, 'w') as f:
159 158 f.write(txt)
160 159
161 160 def add_task(self, task):
162 161 """Add a task to the job.
163 162
164 163 Parameters
165 164 ----------
166 165 task : :class:`WinHPCTask`
167 166 The task object to add.
168 167 """
169 168 self.tasks.append(task)
170 169
171 170
172 171 class WinHPCTask(Configurable):
173 172
174 173 task_id = Unicode('')
175 174 task_name = Unicode('')
176 175 version = Unicode("2.000")
177 176 min_cores = Int(1, config=True)
178 177 max_cores = Int(1, config=True)
179 178 min_sockets = Int(1, config=True)
180 179 max_sockets = Int(1, config=True)
181 180 min_nodes = Int(1, config=True)
182 181 max_nodes = Int(1, config=True)
183 182 unit_type = Unicode("Core", config=True)
184 183 command_line = Unicode('', config=True)
185 184 work_directory = Unicode('', config=True)
186 185 is_rerunnaable = Bool(True, config=True)
187 186 std_out_file_path = Unicode('', config=True)
188 187 std_err_file_path = Unicode('', config=True)
189 188 is_parametric = Bool(False, config=True)
190 189 environment_variables = Instance(dict, args=(), config=True)
191 190
192 191 def _write_attr(self, root, attr, key):
193 192 s = as_str(getattr(self, attr, ''))
194 193 if s:
195 194 root.set(key, s)
196 195
197 196 def as_element(self):
198 197 root = ET.Element('Task')
199 198 self._write_attr(root, 'version', '_A_Version')
200 199 self._write_attr(root, 'task_name', '_B_Name')
201 200 self._write_attr(root, 'min_cores', '_C_MinCores')
202 201 self._write_attr(root, 'max_cores', '_D_MaxCores')
203 202 self._write_attr(root, 'min_sockets', '_E_MinSockets')
204 203 self._write_attr(root, 'max_sockets', '_F_MaxSockets')
205 204 self._write_attr(root, 'min_nodes', '_G_MinNodes')
206 205 self._write_attr(root, 'max_nodes', '_H_MaxNodes')
207 206 self._write_attr(root, 'command_line', '_I_CommandLine')
208 207 self._write_attr(root, 'work_directory', '_J_WorkDirectory')
209 208 self._write_attr(root, 'is_rerunnaable', '_K_IsRerunnable')
210 209 self._write_attr(root, 'std_out_file_path', '_L_StdOutFilePath')
211 210 self._write_attr(root, 'std_err_file_path', '_M_StdErrFilePath')
212 211 self._write_attr(root, 'is_parametric', '_N_IsParametric')
213 212 self._write_attr(root, 'unit_type', '_O_UnitType')
214 213 root.append(self.get_env_vars())
215 214 return root
216 215
217 216 def get_env_vars(self):
218 217 env_vars = ET.Element('EnvironmentVariables')
219 218 for k, v in self.environment_variables.iteritems():
220 219 variable = ET.SubElement(env_vars, "Variable")
221 220 name = ET.SubElement(variable, "Name")
222 221 name.text = k
223 222 value = ET.SubElement(variable, "Value")
224 223 value.text = v
225 224 return env_vars
226 225
227 226
228 227
229 228 # By declaring these, we can configure the controller and engine separately!
230 229
231 230 class IPControllerJob(WinHPCJob):
232 231 job_name = Unicode('IPController', config=False)
233 232 is_exclusive = Bool(False, config=True)
234 233 username = Unicode(find_username(), config=True)
235 234 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
236 235 default_value='Highest', config=True)
237 236 requested_nodes = Unicode('', config=True)
238 237 project = Unicode('IPython', config=True)
239 238
240 239
241 240 class IPEngineSetJob(WinHPCJob):
242 241 job_name = Unicode('IPEngineSet', config=False)
243 242 is_exclusive = Bool(False, config=True)
244 243 username = Unicode(find_username(), config=True)
245 244 priority = Enum(('Lowest','BelowNormal','Normal','AboveNormal','Highest'),
246 245 default_value='Highest', config=True)
247 246 requested_nodes = Unicode('', config=True)
248 247 project = Unicode('IPython', config=True)
249 248
250 249
251 250 class IPControllerTask(WinHPCTask):
252 251
253 252 task_name = Unicode('IPController', config=True)
254 253 controller_cmd = List(['ipcontroller.exe'], config=True)
255 254 controller_args = List(['--log-to-file', '--log-level=40'], config=True)
256 255 # I don't want these to be configurable
257 256 std_out_file_path = Unicode('', config=False)
258 257 std_err_file_path = Unicode('', config=False)
259 258 min_cores = Int(1, config=False)
260 259 max_cores = Int(1, config=False)
261 260 min_sockets = Int(1, config=False)
262 261 max_sockets = Int(1, config=False)
263 262 min_nodes = Int(1, config=False)
264 263 max_nodes = Int(1, config=False)
265 264 unit_type = Unicode("Core", config=False)
266 265 work_directory = Unicode('', config=False)
267 266
268 267 def __init__(self, config=None):
269 268 super(IPControllerTask, self).__init__(config=config)
270 269 the_uuid = uuid.uuid1()
271 270 self.std_out_file_path = os.path.join('log','ipcontroller-%s.out' % the_uuid)
272 271 self.std_err_file_path = os.path.join('log','ipcontroller-%s.err' % the_uuid)
273 272
274 273 @property
275 274 def command_line(self):
276 275 return ' '.join(self.controller_cmd + self.controller_args)
277 276
278 277
279 278 class IPEngineTask(WinHPCTask):
280 279
281 280 task_name = Unicode('IPEngine', config=True)
282 281 engine_cmd = List(['ipengine.exe'], config=True)
283 282 engine_args = List(['--log-to-file', '--log-level=40'], config=True)
284 283 # I don't want these to be configurable
285 284 std_out_file_path = Unicode('', config=False)
286 285 std_err_file_path = Unicode('', config=False)
287 286 min_cores = Int(1, config=False)
288 287 max_cores = Int(1, config=False)
289 288 min_sockets = Int(1, config=False)
290 289 max_sockets = Int(1, config=False)
291 290 min_nodes = Int(1, config=False)
292 291 max_nodes = Int(1, config=False)
293 292 unit_type = Unicode("Core", config=False)
294 293 work_directory = Unicode('', config=False)
295 294
296 295 def __init__(self, config=None):
297 296 super(IPEngineTask,self).__init__(config=config)
298 297 the_uuid = uuid.uuid1()
299 298 self.std_out_file_path = os.path.join('log','ipengine-%s.out' % the_uuid)
300 299 self.std_err_file_path = os.path.join('log','ipengine-%s.err' % the_uuid)
301 300
302 301 @property
303 302 def command_line(self):
304 303 return ' '.join(self.engine_cmd + self.engine_args)
305 304
306 305
307 306 # j = WinHPCJob(None)
308 307 # j.job_name = 'IPCluster'
309 308 # j.username = 'GNET\\bgranger'
310 309 # j.requested_nodes = 'GREEN'
311 310 #
312 311 # t = WinHPCTask(None)
313 312 # t.task_name = 'Controller'
314 313 # t.command_line = r"\\blue\domainusers$\bgranger\Python\Python25\Scripts\ipcontroller.exe --log-to-file -p default --log-level 10"
315 314 # t.work_directory = r"\\blue\domainusers$\bgranger\.ipython\cluster_default"
316 315 # t.std_out_file_path = 'controller-out.txt'
317 316 # t.std_err_file_path = 'controller-err.txt'
318 317 # t.environment_variables['PYTHONPATH'] = r"\\blue\domainusers$\bgranger\Python\Python25\Lib\site-packages"
319 318 # j.add_task(t)
320 319
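# Illustrative sketch (added for clarity, not part of the upstream file), using only
# the classes defined above; it assumes the module-level ElementTree import (ET) that
# WinHPCTask.as_element() relies on:
#
#   t = IPControllerTask(config=None)
#   t.environment_variables['PYTHONPATH'] = r'C:\Python27\Lib\site-packages'
#   xml = ET.tostring(t.as_element())
#   # xml now holds a <Task ...> element whose attributes include
#   # _I_CommandLine="ipcontroller.exe --log-to-file --log-level=40",
#   # plus an <EnvironmentVariables> child carrying the PYTHONPATH variable.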
1 NO CONTENT: modified file chmod 100644 => 100755
@@ -1,1291 +1,1290 b''
1 #!/usr/bin/env python
2 1 """The IPython Controller Hub with 0MQ
3 2 This is the master object that handles connections from engines and clients,
4 3 and monitors traffic through the various queues.
5 4
6 5 Authors:
7 6
8 7 * Min RK
9 8 """
10 9 #-----------------------------------------------------------------------------
11 10 # Copyright (C) 2010 The IPython Development Team
12 11 #
13 12 # Distributed under the terms of the BSD License. The full license is in
14 13 # the file COPYING, distributed as part of this software.
15 14 #-----------------------------------------------------------------------------
16 15
17 16 #-----------------------------------------------------------------------------
18 17 # Imports
19 18 #-----------------------------------------------------------------------------
20 19 from __future__ import print_function
21 20
22 21 import sys
23 22 import time
24 23 from datetime import datetime
25 24
26 25 import zmq
27 26 from zmq.eventloop import ioloop
28 27 from zmq.eventloop.zmqstream import ZMQStream
29 28
30 29 # internal:
31 30 from IPython.utils.importstring import import_item
32 31 from IPython.utils.traitlets import (
33 32 HasTraits, Instance, Int, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
34 33 )
35 34
36 35 from IPython.parallel import error, util
37 36 from IPython.parallel.factory import RegistrationFactory
38 37
39 38 from IPython.zmq.session import SessionFactory
40 39
41 40 from .heartmonitor import HeartMonitor
42 41
43 42 #-----------------------------------------------------------------------------
44 43 # Code
45 44 #-----------------------------------------------------------------------------
46 45
47 46 def _passer(*args, **kwargs):
48 47 return
49 48
50 49 def _printer(*args, **kwargs):
51 50 print (args)
52 51 print (kwargs)
53 52
54 53 def empty_record():
55 54 """Return an empty dict with all record keys."""
56 55 return {
57 56 'msg_id' : None,
58 57 'header' : None,
59 58 'content': None,
60 59 'buffers': None,
61 60 'submitted': None,
62 61 'client_uuid' : None,
63 62 'engine_uuid' : None,
64 63 'started': None,
65 64 'completed': None,
66 65 'resubmitted': None,
67 66 'result_header' : None,
68 67 'result_content' : None,
69 68 'result_buffers' : None,
70 69 'queue' : None,
71 70 'pyin' : None,
72 71 'pyout': None,
73 72 'pyerr': None,
74 73 'stdout': '',
75 74 'stderr': '',
76 75 }
77 76
78 77 def init_record(msg):
79 78 """Initialize a TaskRecord based on a request."""
80 79 header = msg['header']
81 80 return {
82 81 'msg_id' : header['msg_id'],
83 82 'header' : header,
84 83 'content': msg['content'],
85 84 'buffers': msg['buffers'],
86 85 'submitted': header['date'],
87 86 'client_uuid' : None,
88 87 'engine_uuid' : None,
89 88 'started': None,
90 89 'completed': None,
91 90 'resubmitted': None,
92 91 'result_header' : None,
93 92 'result_content' : None,
94 93 'result_buffers' : None,
95 94 'queue' : None,
96 95 'pyin' : None,
97 96 'pyout': None,
98 97 'pyerr': None,
99 98 'stdout': '',
100 99 'stderr': '',
101 100 }
102 101
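# Illustrative sketch (added for clarity, not part of the upstream file): init_record()
# only needs the header, content and buffers of an unpacked message, e.g.
#
#   msg = {'header': {'msg_id': 'abc123', 'date': datetime.now()},
#          'content': {}, 'buffers': []}
#   rec = init_record(msg)
#   # rec['msg_id'] == 'abc123', rec['submitted'] is the header date,
#   # and every result_* field starts out as None.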
103 102
104 103 class EngineConnector(HasTraits):
105 104 """A simple object for accessing the various zmq connections of an engine.
106 105 Attributes are:
107 106 id (int): engine ID
108 107 uuid (str): uuid (unused?)
109 108 queue (str): identity of queue's XREQ socket
110 109 registration (str): identity of registration XREQ socket
111 110 heartbeat (str): identity of heartbeat XREQ socket
112 111 """
113 112 id=Int(0)
114 113 queue=CBytes()
115 114 control=CBytes()
116 115 registration=CBytes()
117 116 heartbeat=CBytes()
118 117 pending=Set()
119 118
120 119 class HubFactory(RegistrationFactory):
121 120 """The Configurable for setting up a Hub."""
122 121
123 122 # port-pairs for monitoredqueues:
124 123 hb = Tuple(Int,Int,config=True,
125 124 help="""XREQ/SUB Port pair for Engine heartbeats""")
126 125 def _hb_default(self):
127 126 return tuple(util.select_random_ports(2))
128 127
129 128 mux = Tuple(Int,Int,config=True,
130 129 help="""Engine/Client Port pair for MUX queue""")
131 130
132 131 def _mux_default(self):
133 132 return tuple(util.select_random_ports(2))
134 133
135 134 task = Tuple(Int,Int,config=True,
136 135 help="""Engine/Client Port pair for Task queue""")
137 136 def _task_default(self):
138 137 return tuple(util.select_random_ports(2))
139 138
140 139 control = Tuple(Int,Int,config=True,
141 140 help="""Engine/Client Port pair for Control queue""")
142 141
143 142 def _control_default(self):
144 143 return tuple(util.select_random_ports(2))
145 144
146 145 iopub = Tuple(Int,Int,config=True,
147 146 help="""Engine/Client Port pair for IOPub relay""")
148 147
149 148 def _iopub_default(self):
150 149 return tuple(util.select_random_ports(2))
151 150
152 151 # single ports:
153 152 mon_port = Int(config=True,
154 153 help="""Monitor (SUB) port for queue traffic""")
155 154
156 155 def _mon_port_default(self):
157 156 return util.select_random_ports(1)[0]
158 157
159 158 notifier_port = Int(config=True,
160 159 help="""PUB port for sending engine status notifications""")
161 160
162 161 def _notifier_port_default(self):
163 162 return util.select_random_ports(1)[0]
164 163
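# Illustrative note (added for clarity, not part of the upstream file): the port traits
# above all default to randomly selected ports (see the _*_default methods); a config
# file can pin them instead, e.g.
#
#   c.HubFactory.hb = (5555, 5556)    # heartbeat port pair
#   c.HubFactory.mux = (6000, 6001)   # engine/client MUX pair
#   c.HubFactory.notifier_port = 7000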
165 164 engine_ip = Unicode('127.0.0.1', config=True,
166 165 help="IP on which to listen for engine connections. [default: loopback]")
167 166 engine_transport = Unicode('tcp', config=True,
168 167 help="0MQ transport for engine connections. [default: tcp]")
169 168
170 169 client_ip = Unicode('127.0.0.1', config=True,
171 170 help="IP on which to listen for client connections. [default: loopback]")
172 171 client_transport = Unicode('tcp', config=True,
173 172 help="0MQ transport for client connections. [default : tcp]")
174 173
175 174 monitor_ip = Unicode('127.0.0.1', config=True,
176 175 help="IP on which to listen for monitor messages. [default: loopback]")
177 176 monitor_transport = Unicode('tcp', config=True,
178 177 help="0MQ transport for monitor messages. [default : tcp]")
179 178
180 179 monitor_url = Unicode('')
181 180
182 181 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
183 182 config=True, help="""The class to use for the DB backend""")
184 183
185 184 # not configurable
186 185 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
187 186 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
188 187
189 188 def _ip_changed(self, name, old, new):
190 189 self.engine_ip = new
191 190 self.client_ip = new
192 191 self.monitor_ip = new
193 192 self._update_monitor_url()
194 193
195 194 def _update_monitor_url(self):
196 195 self.monitor_url = "%s://%s:%i"%(self.monitor_transport, self.monitor_ip, self.mon_port)
197 196
198 197 def _transport_changed(self, name, old, new):
199 198 self.engine_transport = new
200 199 self.client_transport = new
201 200 self.monitor_transport = new
202 201 self._update_monitor_url()
203 202
204 203 def __init__(self, **kwargs):
205 204 super(HubFactory, self).__init__(**kwargs)
206 205 self._update_monitor_url()
207 206
208 207
209 208 def construct(self):
210 209 self.init_hub()
211 210
212 211 def start(self):
213 212 self.heartmonitor.start()
214 213 self.log.info("Heartmonitor started")
215 214
216 215 def init_hub(self):
217 216 """Construct the Hub's sockets, HeartMonitor, DB backend and connection info, then create the Hub."""
218 217 client_iface = "%s://%s:"%(self.client_transport, self.client_ip) + "%i"
219 218 engine_iface = "%s://%s:"%(self.engine_transport, self.engine_ip) + "%i"
220 219
221 220 ctx = self.context
222 221 loop = self.loop
223 222
224 223 # Registrar socket
225 224 q = ZMQStream(ctx.socket(zmq.XREP), loop)
226 225 q.bind(client_iface % self.regport)
227 226 self.log.info("Hub listening on %s for registration."%(client_iface%self.regport))
228 227 if self.client_ip != self.engine_ip:
229 228 q.bind(engine_iface % self.regport)
230 229 self.log.info("Hub listening on %s for registration."%(engine_iface%self.regport))
231 230
232 231 ### Engine connections ###
233 232
234 233 # heartbeat
235 234 hpub = ctx.socket(zmq.PUB)
236 235 hpub.bind(engine_iface % self.hb[0])
237 236 hrep = ctx.socket(zmq.XREP)
238 237 hrep.bind(engine_iface % self.hb[1])
239 238 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
240 239 pingstream=ZMQStream(hpub,loop),
241 240 pongstream=ZMQStream(hrep,loop)
242 241 )
243 242
244 243 ### Client connections ###
245 244 # Notifier socket
246 245 n = ZMQStream(ctx.socket(zmq.PUB), loop)
247 246 n.bind(client_iface%self.notifier_port)
248 247
249 248 ### build and launch the queues ###
250 249
251 250 # monitor socket
252 251 sub = ctx.socket(zmq.SUB)
253 252 sub.setsockopt(zmq.SUBSCRIBE, b"")
254 253 sub.bind(self.monitor_url)
255 254 sub.bind('inproc://monitor')
256 255 sub = ZMQStream(sub, loop)
257 256
258 257 # connect the db
259 258 self.log.info('Hub using DB backend: %r'%(self.db_class.split('.')[-1]))
260 259 # cdir = self.config.Global.cluster_dir
261 260 self.db = import_item(str(self.db_class))(session=self.session.session,
262 261 config=self.config, log=self.log)
263 262 time.sleep(.25)
264 263 try:
265 264 scheme = self.config.TaskScheduler.scheme_name
266 265 except AttributeError:
267 266 from .scheduler import TaskScheduler
268 267 scheme = TaskScheduler.scheme_name.get_default_value()
269 268 # build connection dicts
270 269 self.engine_info = {
271 270 'control' : engine_iface%self.control[1],
272 271 'mux': engine_iface%self.mux[1],
273 272 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
274 273 'task' : engine_iface%self.task[1],
275 274 'iopub' : engine_iface%self.iopub[1],
276 275 # 'monitor' : engine_iface%self.mon_port,
277 276 }
278 277
279 278 self.client_info = {
280 279 'control' : client_iface%self.control[0],
281 280 'mux': client_iface%self.mux[0],
282 281 'task' : (scheme, client_iface%self.task[0]),
283 282 'iopub' : client_iface%self.iopub[0],
284 283 'notification': client_iface%self.notifier_port
285 284 }
286 285 self.log.debug("Hub engine addrs: %s"%self.engine_info)
287 286 self.log.debug("Hub client addrs: %s"%self.client_info)
288 287
289 288 # resubmit stream
290 289 r = ZMQStream(ctx.socket(zmq.XREQ), loop)
291 290 url = util.disambiguate_url(self.client_info['task'][-1])
292 291 r.setsockopt(zmq.IDENTITY, util.asbytes(self.session.session))
293 292 r.connect(url)
294 293
295 294 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
296 295 query=q, notifier=n, resubmit=r, db=self.db,
297 296 engine_info=self.engine_info, client_info=self.client_info,
298 297 log=self.log)
299 298
300 299
301 300 class Hub(SessionFactory):
302 301 """The IPython Controller Hub with 0MQ connections
303 302
304 303 Parameters
305 304 ==========
306 305 loop: zmq IOLoop instance
307 306 session: Session object
308 307 <removed> context: zmq context for creating new connections (?)
309 308 queue: ZMQStream for monitoring the command queue (SUB)
310 309 query: ZMQStream for engine registration and client query requests (XREP)
311 310 heartbeat: HeartMonitor object checking the pulse of the engines
312 311 notifier: ZMQStream for broadcasting engine registration changes (PUB)
313 312 db: connection to db for out of memory logging of commands
314 313 NotImplemented
315 314 engine_info: dict of zmq connection information for engines to connect
316 315 to the queues.
317 316 client_info: dict of zmq connection information for clients to connect
318 317 to the queues.
319 318 """
320 319 # internal data structures:
321 320 ids=Set() # engine IDs
322 321 keytable=Dict()
323 322 by_ident=Dict()
324 323 engines=Dict()
325 324 clients=Dict()
326 325 hearts=Dict()
327 326 pending=Set()
328 327 queues=Dict() # pending msg_ids keyed by engine_id
329 328 tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
330 329 completed=Dict() # completed msg_ids keyed by engine_id
331 330 all_completed=Set() # set of all completed msg_ids
332 331 dead_engines=Set() # set of uuids of engines that have unregistered or died
333 332 unassigned=Set() # set of task msg_ids not yet assigned a destination
334 333 incoming_registrations=Dict()
335 334 registration_timeout=Int()
336 335 _idcounter=Int(0)
337 336
338 337 # objects from constructor:
339 338 query=Instance(ZMQStream)
340 339 monitor=Instance(ZMQStream)
341 340 notifier=Instance(ZMQStream)
342 341 resubmit=Instance(ZMQStream)
343 342 heartmonitor=Instance(HeartMonitor)
344 343 db=Instance(object)
345 344 client_info=Dict()
346 345 engine_info=Dict()
347 346
348 347
349 348 def __init__(self, **kwargs):
350 349 """
351 350 # universal:
352 351 loop: IOLoop for creating future connections
353 352 session: streamsession for sending serialized data
354 353 # engine:
355 354 queue: ZMQStream for monitoring queue messages
356 355 query: ZMQStream for engine+client registration and client requests
357 356 heartbeat: HeartMonitor object for tracking engines
358 357 # extra:
359 358 db: ZMQStream for db connection (NotImplemented)
360 359 engine_info: zmq address/protocol dict for engine connections
361 360 client_info: zmq address/protocol dict for client connections
362 361 """
363 362
364 363 super(Hub, self).__init__(**kwargs)
365 364 self.registration_timeout = max(5000, 2*self.heartmonitor.period)
366 365
367 366 # validate connection dicts:
368 367 for k,v in self.client_info.iteritems():
369 368 if k == 'task':
370 369 util.validate_url_container(v[1])
371 370 else:
372 371 util.validate_url_container(v)
373 372 # util.validate_url_container(self.client_info)
374 373 util.validate_url_container(self.engine_info)
375 374
376 375 # register our callbacks
377 376 self.query.on_recv(self.dispatch_query)
378 377 self.monitor.on_recv(self.dispatch_monitor_traffic)
379 378
380 379 self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
381 380 self.heartmonitor.add_new_heart_handler(self.handle_new_heart)
382 381
383 382 self.monitor_handlers = {b'in' : self.save_queue_request,
384 383 b'out': self.save_queue_result,
385 384 b'intask': self.save_task_request,
386 385 b'outtask': self.save_task_result,
387 386 b'tracktask': self.save_task_destination,
388 387 b'incontrol': _passer,
389 388 b'outcontrol': _passer,
390 389 b'iopub': self.save_iopub_message,
391 390 }
392 391
393 392 self.query_handlers = {'queue_request': self.queue_status,
394 393 'result_request': self.get_results,
395 394 'history_request': self.get_history,
396 395 'db_request': self.db_query,
397 396 'purge_request': self.purge_results,
398 397 'load_request': self.check_load,
399 398 'resubmit_request': self.resubmit_task,
400 399 'shutdown_request': self.shutdown_request,
401 400 'registration_request' : self.register_engine,
402 401 'unregistration_request' : self.unregister_engine,
403 402 'connection_request': self.connection_request,
404 403 }
405 404
406 405 # ignore resubmit replies
407 406 self.resubmit.on_recv(lambda msg: None, copy=False)
408 407
409 408 self.log.info("hub::created hub")
410 409
411 410 @property
412 411 def _next_id(self):
413 412 """generate a new ID.
414 413
415 414 No longer reuse old ids, just count from 0."""
416 415 newid = self._idcounter
417 416 self._idcounter += 1
418 417 return newid
419 418 # newid = 0
420 419 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
421 420 # # print newid, self.ids, self.incoming_registrations
422 421 # while newid in self.ids or newid in incoming:
423 422 # newid += 1
424 423 # return newid
425 424
426 425 #-----------------------------------------------------------------------------
427 426 # message validation
428 427 #-----------------------------------------------------------------------------
429 428
430 429 def _validate_targets(self, targets):
431 430 """turn any valid targets argument into a list of integer ids"""
432 431 if targets is None:
433 432 # default to all
434 433 targets = self.ids
435 434
436 435 if isinstance(targets, (int,str,unicode)):
437 436 # only one target specified
438 437 targets = [targets]
439 438 _targets = []
440 439 for t in targets:
441 440 # map raw identities to ids
442 441 if isinstance(t, (str,unicode)):
443 442 t = self.by_ident.get(t, t)
444 443 _targets.append(t)
445 444 targets = _targets
446 445 bad_targets = [ t for t in targets if t not in self.ids ]
447 446 if bad_targets:
448 447 raise IndexError("No Such Engine: %r"%bad_targets)
449 448 if not targets:
450 449 raise IndexError("No Engines Registered")
451 450 return targets
452 451
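# Illustrative note (added for clarity, not part of the upstream file): examples of how
# _validate_targets() normalizes its argument, assuming engine 0 is registered:
#
#   self._validate_targets(None)   # -> list of all registered engine ids
#   self._validate_targets(0)      # -> [0]
#   self._validate_targets([99])   # -> raises IndexError("No Such Engine: [99]")
#   # raw string identities are first mapped to integer ids via self.by_ident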
453 452 #-----------------------------------------------------------------------------
454 453 # dispatch methods (1 per stream)
455 454 #-----------------------------------------------------------------------------
456 455
457 456
458 457 def dispatch_monitor_traffic(self, msg):
459 458 """all ME and Task queue messages come through here, as well as
460 459 IOPub traffic."""
461 460 self.log.debug("monitor traffic: %r"%msg[:2])
462 461 switch = msg[0]
463 462 try:
464 463 idents, msg = self.session.feed_identities(msg[1:])
465 464 except ValueError:
466 465 idents=[]
467 466 if not idents:
468 467 self.log.error("Bad Monitor Message: %r"%msg)
469 468 return
470 469 handler = self.monitor_handlers.get(switch, None)
471 470 if handler is not None:
472 471 handler(idents, msg)
473 472 else:
474 473 self.log.error("Invalid monitor topic: %r"%switch)
475 474
476 475
477 476 def dispatch_query(self, msg):
478 477 """Route registration requests and queries from clients."""
479 478 try:
480 479 idents, msg = self.session.feed_identities(msg)
481 480 except ValueError:
482 481 idents = []
483 482 if not idents:
484 483 self.log.error("Bad Query Message: %r"%msg)
485 484 return
486 485 client_id = idents[0]
487 486 try:
488 487 msg = self.session.unpack_message(msg, content=True)
489 488 except Exception:
490 489 content = error.wrap_exception()
491 490 self.log.error("Bad Query Message: %r"%msg, exc_info=True)
492 491 self.session.send(self.query, "hub_error", ident=client_id,
493 492 content=content)
494 493 return
495 494 # print client_id, header, parent, content
496 495 #switch on message type:
497 496 msg_type = msg['msg_type']
498 497 self.log.info("client::client %r requested %r"%(client_id, msg_type))
499 498 handler = self.query_handlers.get(msg_type, None)
500 499 try:
501 500 assert handler is not None, "Bad Message Type: %r"%msg_type
502 501 except:
503 502 content = error.wrap_exception()
504 503 self.log.error("Bad Message Type: %r"%msg_type, exc_info=True)
505 504 self.session.send(self.query, "hub_error", ident=client_id,
506 505 content=content)
507 506 return
508 507
509 508 else:
510 509 handler(idents, msg)
511 510
512 511 def dispatch_db(self, msg):
513 512 """"""
514 513 raise NotImplementedError
515 514
516 515 #---------------------------------------------------------------------------
517 516 # handler methods (1 per event)
518 517 #---------------------------------------------------------------------------
519 518
520 519 #----------------------- Heartbeat --------------------------------------
521 520
522 521 def handle_new_heart(self, heart):
523 522 """handler to attach to heartbeater.
524 523 Called when a new heart starts to beat.
525 524 Triggers completion of registration."""
526 525 self.log.debug("heartbeat::handle_new_heart(%r)"%heart)
527 526 if heart not in self.incoming_registrations:
528 527 self.log.info("heartbeat::ignoring new heart: %r"%heart)
529 528 else:
530 529 self.finish_registration(heart)
531 530
532 531
533 532 def handle_heart_failure(self, heart):
534 533 """handler to attach to heartbeater.
535 534 called when a previously registered heart fails to respond to beat request.
536 535 triggers unregistration"""
537 536 self.log.debug("heartbeat::handle_heart_failure(%r)"%heart)
538 537 eid = self.hearts.get(heart, None)
539 538 queue = self.engines[eid].queue if eid is not None else None
540 539 if eid is None:
541 540 self.log.info("heartbeat::ignoring heart failure %r"%heart)
542 541 else:
543 542 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
544 543
545 544 #----------------------- MUX Queue Traffic ------------------------------
546 545
547 546 def save_queue_request(self, idents, msg):
548 547 if len(idents) < 2:
549 548 self.log.error("invalid identity prefix: %r"%idents)
550 549 return
551 550 queue_id, client_id = idents[:2]
552 551 try:
553 552 msg = self.session.unpack_message(msg)
554 553 except Exception:
555 554 self.log.error("queue::client %r sent invalid message to %r: %r"%(client_id, queue_id, msg), exc_info=True)
556 555 return
557 556
558 557 eid = self.by_ident.get(queue_id, None)
559 558 if eid is None:
560 559 self.log.error("queue::target %r not registered"%queue_id)
561 560 self.log.debug("queue:: valid are: %r"%(self.by_ident.keys()))
562 561 return
563 562 record = init_record(msg)
564 563 msg_id = record['msg_id']
565 564 # Unicode in records
566 565 record['engine_uuid'] = queue_id.decode('ascii')
567 566 record['client_uuid'] = client_id.decode('ascii')
568 567 record['queue'] = 'mux'
569 568
570 569 try:
571 570 # it's possible iopub arrived first:
572 571 existing = self.db.get_record(msg_id)
573 572 for key,evalue in existing.iteritems():
574 573 rvalue = record.get(key, None)
575 574 if evalue and rvalue and evalue != rvalue:
576 575 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
577 576 elif evalue and not rvalue:
578 577 record[key] = evalue
579 578 try:
580 579 self.db.update_record(msg_id, record)
581 580 except Exception:
582 581 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
583 582 except KeyError:
584 583 try:
585 584 self.db.add_record(msg_id, record)
586 585 except Exception:
587 586 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
588 587
589 588
590 589 self.pending.add(msg_id)
591 590 self.queues[eid].append(msg_id)
592 591
593 592 def save_queue_result(self, idents, msg):
594 593 if len(idents) < 2:
595 594 self.log.error("invalid identity prefix: %r"%idents)
596 595 return
597 596
598 597 client_id, queue_id = idents[:2]
599 598 try:
600 599 msg = self.session.unpack_message(msg)
601 600 except Exception:
602 601 self.log.error("queue::engine %r sent invalid message to %r: %r"%(
603 602 queue_id,client_id, msg), exc_info=True)
604 603 return
605 604
606 605 eid = self.by_ident.get(queue_id, None)
607 606 if eid is None:
608 607 self.log.error("queue::unknown engine %r is sending a reply: "%queue_id)
609 608 return
610 609
611 610 parent = msg['parent_header']
612 611 if not parent:
613 612 return
614 613 msg_id = parent['msg_id']
615 614 if msg_id in self.pending:
616 615 self.pending.remove(msg_id)
617 616 self.all_completed.add(msg_id)
618 617 self.queues[eid].remove(msg_id)
619 618 self.completed[eid].append(msg_id)
620 619 elif msg_id not in self.all_completed:
621 620 # it could be a result from a dead engine that died before delivering the
622 621 # result
623 622 self.log.warn("queue:: unknown msg finished %r"%msg_id)
624 623 return
625 624 # update record anyway, because the unregistration could have been premature
626 625 rheader = msg['header']
627 626 completed = rheader['date']
628 627 started = rheader.get('started', None)
629 628 result = {
630 629 'result_header' : rheader,
631 630 'result_content': msg['content'],
632 631 'started' : started,
633 632 'completed' : completed
634 633 }
635 634
636 635 result['result_buffers'] = msg['buffers']
637 636 try:
638 637 self.db.update_record(msg_id, result)
639 638 except Exception:
640 639 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
641 640
642 641
643 642 #--------------------- Task Queue Traffic ------------------------------
644 643
645 644 def save_task_request(self, idents, msg):
646 645 """Save the submission of a task."""
647 646 client_id = idents[0]
648 647
649 648 try:
650 649 msg = self.session.unpack_message(msg)
651 650 except Exception:
652 651 self.log.error("task::client %r sent invalid task message: %r"%(
653 652 client_id, msg), exc_info=True)
654 653 return
655 654 record = init_record(msg)
656 655
657 656 record['client_uuid'] = client_id
658 657 record['queue'] = 'task'
659 658 header = msg['header']
660 659 msg_id = header['msg_id']
661 660 self.pending.add(msg_id)
662 661 self.unassigned.add(msg_id)
663 662 try:
664 663 # it's possible iopub arrived first:
665 664 existing = self.db.get_record(msg_id)
666 665 if existing['resubmitted']:
667 666 for key in ('submitted', 'client_uuid', 'buffers'):
668 667 # don't clobber these keys on resubmit
669 668 # submitted and client_uuid should be different
670 669 # and buffers might be big, and shouldn't have changed
671 670 record.pop(key)
672 671 # still check content and header, which should not change
673 672 # but are not as expensive to compare as buffers
674 673
675 674 for key,evalue in existing.iteritems():
676 675 if key.endswith('buffers'):
677 676 # don't compare buffers
678 677 continue
679 678 rvalue = record.get(key, None)
680 679 if evalue and rvalue and evalue != rvalue:
681 680 self.log.warn("conflicting initial state for record: %r:%r <%r> %r"%(msg_id, rvalue, key, evalue))
682 681 elif evalue and not rvalue:
683 682 record[key] = evalue
684 683 try:
685 684 self.db.update_record(msg_id, record)
686 685 except Exception:
687 686 self.log.error("DB Error updating record %r"%msg_id, exc_info=True)
688 687 except KeyError:
689 688 try:
690 689 self.db.add_record(msg_id, record)
691 690 except Exception:
692 691 self.log.error("DB Error adding record %r"%msg_id, exc_info=True)
693 692 except Exception:
694 693 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
695 694
696 695 def save_task_result(self, idents, msg):
697 696 """save the result of a completed task."""
698 697 client_id = idents[0]
699 698 try:
700 699 msg = self.session.unpack_message(msg)
701 700 except Exception:
702 701 self.log.error("task::invalid task result message sent to %r: %r"%(
703 702 client_id, msg), exc_info=True)
704 703 return
705 704
706 705 parent = msg['parent_header']
707 706 if not parent:
708 707 # print msg
709 708 self.log.warn("Task %r had no parent!"%msg)
710 709 return
711 710 msg_id = parent['msg_id']
712 711 if msg_id in self.unassigned:
713 712 self.unassigned.remove(msg_id)
714 713
715 714 header = msg['header']
716 715 engine_uuid = header.get('engine', None)
717 716 eid = self.by_ident.get(engine_uuid, None)
718 717
719 718 if msg_id in self.pending:
720 719 self.pending.remove(msg_id)
721 720 self.all_completed.add(msg_id)
722 721 if eid is not None:
723 722 self.completed[eid].append(msg_id)
724 723 if msg_id in self.tasks[eid]:
725 724 self.tasks[eid].remove(msg_id)
726 725 completed = header['date']
727 726 started = header.get('started', None)
728 727 result = {
729 728 'result_header' : header,
730 729 'result_content': msg['content'],
731 730 'started' : started,
732 731 'completed' : completed,
733 732 'engine_uuid': engine_uuid
734 733 }
735 734
736 735 result['result_buffers'] = msg['buffers']
737 736 try:
738 737 self.db.update_record(msg_id, result)
739 738 except Exception:
740 739 self.log.error("DB Error saving task request %r"%msg_id, exc_info=True)
741 740
742 741 else:
743 742 self.log.debug("task::unknown task %r finished"%msg_id)
744 743
745 744 def save_task_destination(self, idents, msg):
746 745 try:
747 746 msg = self.session.unpack_message(msg, content=True)
748 747 except Exception:
749 748 self.log.error("task::invalid task tracking message", exc_info=True)
750 749 return
751 750 content = msg['content']
752 751 # print (content)
753 752 msg_id = content['msg_id']
754 753 engine_uuid = content['engine_id']
755 754 eid = self.by_ident[util.asbytes(engine_uuid)]
756 755
757 756 self.log.info("task::task %r arrived on %r"%(msg_id, eid))
758 757 if msg_id in self.unassigned:
759 758 self.unassigned.remove(msg_id)
760 759 # else:
761 760 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
762 761
763 762 self.tasks[eid].append(msg_id)
764 763 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
765 764 try:
766 765 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
767 766 except Exception:
768 767 self.log.error("DB Error saving task destination %r"%msg_id, exc_info=True)
769 768
770 769
771 770 def mia_task_request(self, idents, msg):
772 771 raise NotImplementedError
773 772 client_id = idents[0]
774 773 # content = dict(mia=self.mia,status='ok')
775 774 # self.session.send('mia_reply', content=content, idents=client_id)
776 775
777 776
778 777 #--------------------- IOPub Traffic ------------------------------
779 778
780 779 def save_iopub_message(self, topics, msg):
781 780 """save an iopub message into the db"""
782 781 # print (topics)
783 782 try:
784 783 msg = self.session.unpack_message(msg, content=True)
785 784 except Exception:
786 785 self.log.error("iopub::invalid IOPub message", exc_info=True)
787 786 return
788 787
789 788 parent = msg['parent_header']
790 789 if not parent:
791 790 self.log.error("iopub::invalid IOPub message: %r"%msg)
792 791 return
793 792 msg_id = parent['msg_id']
794 793 msg_type = msg['msg_type']
795 794 content = msg['content']
796 795
797 796 # ensure msg_id is in db
798 797 try:
799 798 rec = self.db.get_record(msg_id)
800 799 except KeyError:
801 800 rec = empty_record()
802 801 rec['msg_id'] = msg_id
803 802 self.db.add_record(msg_id, rec)
804 803 # stream
805 804 d = {}
806 805 if msg_type == 'stream':
807 806 name = content['name']
808 807 s = rec[name] or ''
809 808 d[name] = s + content['data']
810 809
811 810 elif msg_type == 'pyerr':
812 811 d['pyerr'] = content
813 812 elif msg_type == 'pyin':
814 813 d['pyin'] = content['code']
815 814 else:
816 815 d[msg_type] = content.get('data', '')
817 816
818 817 try:
819 818 self.db.update_record(msg_id, d)
820 819 except Exception:
821 820 self.log.error("DB Error saving iopub message %r"%msg_id, exc_info=True)
822 821
823 822
824 823
825 824 #-------------------------------------------------------------------------
826 825 # Registration requests
827 826 #-------------------------------------------------------------------------
828 827
829 828 def connection_request(self, client_id, msg):
830 829 """Reply with connection addresses for clients."""
831 830 self.log.info("client::client %r connected"%client_id)
832 831 content = dict(status='ok')
833 832 content.update(self.client_info)
834 833 jsonable = {}
835 834 for k,v in self.keytable.iteritems():
836 835 if v not in self.dead_engines:
837 836 jsonable[str(k)] = v.decode('ascii')
838 837 content['engines'] = jsonable
839 838 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
840 839
841 840 def register_engine(self, reg, msg):
842 841 """Register a new engine."""
843 842 content = msg['content']
844 843 try:
845 844 queue = util.asbytes(content['queue'])
846 845 except KeyError:
847 846 self.log.error("registration::queue not specified", exc_info=True)
848 847 return
849 848 heart = content.get('heartbeat', None)
850 849 if heart:
851 850 heart = util.asbytes(heart)
852 851 # register a new engine, and create the socket(s) necessary
853 852 eid = self._next_id
854 853 # print (eid, queue, reg, heart)
855 854
856 855 self.log.debug("registration::register_engine(%i, %r, %r, %r)"%(eid, queue, reg, heart))
857 856
858 857 content = dict(id=eid,status='ok')
859 858 content.update(self.engine_info)
860 859 # check that the queue and heartbeat identities are not already in use:
861 860 if queue in self.by_ident:
862 861 try:
863 862 raise KeyError("queue_id %r in use"%queue)
864 863 except:
865 864 content = error.wrap_exception()
866 865 self.log.error("queue_id %r in use"%queue, exc_info=True)
867 866 elif heart in self.hearts: # need to check unique hearts?
868 867 try:
869 868 raise KeyError("heart_id %r in use"%heart)
870 869 except:
871 870 self.log.error("heart_id %r in use"%heart, exc_info=True)
872 871 content = error.wrap_exception()
873 872 else:
874 873 for h, pack in self.incoming_registrations.iteritems():
875 874 if heart == h:
876 875 try:
877 876 raise KeyError("heart_id %r in use"%heart)
878 877 except:
879 878 self.log.error("heart_id %r in use"%heart, exc_info=True)
880 879 content = error.wrap_exception()
881 880 break
882 881 elif queue == pack[1]:
883 882 try:
884 883 raise KeyError("queue_id %r in use"%queue)
885 884 except:
886 885 self.log.error("queue_id %r in use"%queue, exc_info=True)
887 886 content = error.wrap_exception()
888 887 break
889 888
890 889 msg = self.session.send(self.query, "registration_reply",
891 890 content=content,
892 891 ident=reg)
893 892
894 893 if content['status'] == 'ok':
895 894 if heart in self.heartmonitor.hearts:
896 895 # already beating
897 896 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
898 897 self.finish_registration(heart)
899 898 else:
900 899 purge = lambda : self._purge_stalled_registration(heart)
901 900 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
902 901 dc.start()
903 902 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
904 903 else:
905 904 self.log.error("registration::registration %i failed: %r"%(eid, content['evalue']))
906 905 return eid
907 906
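# Illustrative note (added for clarity, not part of the upstream file): the message
# shapes involved in the registration handled above, as sent by EngineFactory.register():
#
#   registration_request content : {'queue': <ident>, 'heartbeat': <ident>, 'control': <ident>}
#   registration_reply  (success): {'id': <eid>, 'status': 'ok'} plus the engine_info addresses
#   registration_reply  (failure): a wrapped exception dict whose 'evalue' is logged above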
908 907 def unregister_engine(self, ident, msg):
909 908 """Unregister an engine that explicitly requested to leave."""
910 909 try:
911 910 eid = msg['content']['id']
912 911 except:
913 912 self.log.error("registration::bad engine id for unregistration: %r"%ident, exc_info=True)
914 913 return
915 914 self.log.info("registration::unregister_engine(%r)"%eid)
916 915 # print (eid)
917 916 uuid = self.keytable[eid]
918 917 content=dict(id=eid, queue=uuid.decode('ascii'))
919 918 self.dead_engines.add(uuid)
920 919 # self.ids.remove(eid)
921 920 # uuid = self.keytable.pop(eid)
922 921 #
923 922 # ec = self.engines.pop(eid)
924 923 # self.hearts.pop(ec.heartbeat)
925 924 # self.by_ident.pop(ec.queue)
926 925 # self.completed.pop(eid)
927 926 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
928 927 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
929 928 dc.start()
930 929 ############## TODO: HANDLE IT ################
931 930
932 931 if self.notifier:
933 932 self.session.send(self.notifier, "unregistration_notification", content=content)
934 933
935 934 def _handle_stranded_msgs(self, eid, uuid):
936 935 """Handle messages known to be on an engine when the engine unregisters.
937 936
938 937 It is possible that this will fire prematurely - that is, an engine will
939 938 go down after completing a result, and the client will be notified
940 939 that the result failed and later receive the actual result.
941 940 """
942 941
943 942 outstanding = self.queues[eid]
944 943
945 944 for msg_id in outstanding:
946 945 self.pending.remove(msg_id)
947 946 self.all_completed.add(msg_id)
948 947 try:
949 948 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
950 949 except:
951 950 content = error.wrap_exception()
952 951 # build a fake header:
953 952 header = {}
954 953 header['engine'] = uuid
955 954 header['date'] = datetime.now()
956 955 rec = dict(result_content=content, result_header=header, result_buffers=[])
957 956 rec['completed'] = header['date']
958 957 rec['engine_uuid'] = uuid
959 958 try:
960 959 self.db.update_record(msg_id, rec)
961 960 except Exception:
962 961 self.log.error("DB Error handling stranded msg %r"%msg_id, exc_info=True)
963 962
964 963
965 964 def finish_registration(self, heart):
966 965 """Second half of engine registration, called after our HeartMonitor
967 966 has received a beat from the Engine's Heart."""
968 967 try:
969 968 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
970 969 except KeyError:
971 970 self.log.error("registration::tried to finish nonexistent registration", exc_info=True)
972 971 return
973 972 self.log.info("registration::finished registering engine %i:%r"%(eid,queue))
974 973 if purge is not None:
975 974 purge.stop()
976 975 control = queue
977 976 self.ids.add(eid)
978 977 self.keytable[eid] = queue
979 978 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
980 979 control=control, heartbeat=heart)
981 980 self.by_ident[queue] = eid
982 981 self.queues[eid] = list()
983 982 self.tasks[eid] = list()
984 983 self.completed[eid] = list()
985 984 self.hearts[heart] = eid
986 985 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
987 986 if self.notifier:
988 987 self.session.send(self.notifier, "registration_notification", content=content)
989 988 self.log.info("engine::Engine Connected: %i"%eid)
990 989
991 990 def _purge_stalled_registration(self, heart):
992 991 if heart in self.incoming_registrations:
993 992 eid = self.incoming_registrations.pop(heart)[0]
994 993 self.log.info("registration::purging stalled registration: %i"%eid)
995 994 else:
996 995 pass
997 996
998 997 #-------------------------------------------------------------------------
999 998 # Client Requests
1000 999 #-------------------------------------------------------------------------
1001 1000
1002 1001 def shutdown_request(self, client_id, msg):
1003 1002 """handle shutdown request."""
1004 1003 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1005 1004 # also notify other clients of shutdown
1006 1005 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1007 1006 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1008 1007 dc.start()
1009 1008
1010 1009 def _shutdown(self):
1011 1010 self.log.info("hub::hub shutting down.")
1012 1011 time.sleep(0.1)
1013 1012 sys.exit(0)
1014 1013
1015 1014
1016 1015 def check_load(self, client_id, msg):
1017 1016 content = msg['content']
1018 1017 try:
1019 1018 targets = content['targets']
1020 1019 targets = self._validate_targets(targets)
1021 1020 except:
1022 1021 content = error.wrap_exception()
1023 1022 self.session.send(self.query, "hub_error",
1024 1023 content=content, ident=client_id)
1025 1024 return
1026 1025
1027 1026 content = dict(status='ok')
1028 1027 # loads = {}
1029 1028 for t in targets:
1030 1029 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1031 1030 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1032 1031
1033 1032
1034 1033 def queue_status(self, client_id, msg):
1035 1034 """Return the Queue status of one or more targets.
1036 1035 if verbose: return the msg_ids
1037 1036 else: return len of each type.
1038 1037 keys: queue (pending MUX jobs)
1039 1038 tasks (pending Task jobs)
1040 1039 completed (finished jobs from both queues)"""
1041 1040 content = msg['content']
1042 1041 targets = content['targets']
1043 1042 try:
1044 1043 targets = self._validate_targets(targets)
1045 1044 except:
1046 1045 content = error.wrap_exception()
1047 1046 self.session.send(self.query, "hub_error",
1048 1047 content=content, ident=client_id)
1049 1048 return
1050 1049 verbose = content.get('verbose', False)
1051 1050 content = dict(status='ok')
1052 1051 for t in targets:
1053 1052 queue = self.queues[t]
1054 1053 completed = self.completed[t]
1055 1054 tasks = self.tasks[t]
1056 1055 if not verbose:
1057 1056 queue = len(queue)
1058 1057 completed = len(completed)
1059 1058 tasks = len(tasks)
1060 1059 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1061 1060 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1062 1061 # print (content)
1063 1062 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1064 1063
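# Illustrative note (added for clarity, not part of the upstream file): a non-verbose
# queue_reply for two engines might carry content such as
#
#   {'status': 'ok',
#    '0': {'queue': 2, 'completed': 10, 'tasks': 1},
#    '1': {'queue': 0, 'completed': 12, 'tasks': 3},
#    'unassigned': 4}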
1065 1064 def purge_results(self, client_id, msg):
1066 1065 """Purge results from memory. This method is more valuable before we move
1067 1066 to a DB-based message storage mechanism."""
1068 1067 content = msg['content']
1069 1068 self.log.info("Dropping records with %s", content)
1070 1069 msg_ids = content.get('msg_ids', [])
1071 1070 reply = dict(status='ok')
1072 1071 if msg_ids == 'all':
1073 1072 try:
1074 1073 self.db.drop_matching_records(dict(completed={'$ne':None}))
1075 1074 except Exception:
1076 1075 reply = error.wrap_exception()
1077 1076 else:
1078 1077 pending = filter(lambda m: m in self.pending, msg_ids)
1079 1078 if pending:
1080 1079 try:
1081 1080 raise IndexError("msg pending: %r"%pending[0])
1082 1081 except:
1083 1082 reply = error.wrap_exception()
1084 1083 else:
1085 1084 try:
1086 1085 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1087 1086 except Exception:
1088 1087 reply = error.wrap_exception()
1089 1088
1090 1089 if reply['status'] == 'ok':
1091 1090 eids = content.get('engine_ids', [])
1092 1091 for eid in eids:
1093 1092 if eid not in self.engines:
1094 1093 try:
1095 1094 raise IndexError("No such engine: %i"%eid)
1096 1095 except:
1097 1096 reply = error.wrap_exception()
1098 1097 break
1099 1098 uid = self.engines[eid].queue
1100 1099 try:
1101 1100 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1102 1101 except Exception:
1103 1102 reply = error.wrap_exception()
1104 1103 break
1105 1104
1106 1105 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1107 1106
1108 1107 def resubmit_task(self, client_id, msg):
1109 1108 """Resubmit one or more tasks."""
1110 1109 def finish(reply):
1111 1110 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1112 1111
1113 1112 content = msg['content']
1114 1113 msg_ids = content['msg_ids']
1115 1114 reply = dict(status='ok')
1116 1115 try:
1117 1116 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1118 1117 'header', 'content', 'buffers'])
1119 1118 except Exception:
1120 1119 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1121 1120 return finish(error.wrap_exception())
1122 1121
1123 1122 # validate msg_ids
1124 1123 found_ids = [ rec['msg_id'] for rec in records ]
1125 1124 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1126 1125 if len(records) > len(msg_ids):
1127 1126 try:
1128 1127 raise RuntimeError("DB appears to be in an inconsistent state. "
1129 1128 "More matching records were found than should exist")
1130 1129 except Exception:
1131 1130 return finish(error.wrap_exception())
1132 1131 elif len(records) < len(msg_ids):
1133 1132 missing = [ m for m in msg_ids if m not in found_ids ]
1134 1133 try:
1135 1134 raise KeyError("No such msg(s): %r"%missing)
1136 1135 except KeyError:
1137 1136 return finish(error.wrap_exception())
1138 1137 elif invalid_ids:
1139 1138 msg_id = invalid_ids[0]
1140 1139 try:
1141 1140 raise ValueError("Task %r appears to be inflight"%(msg_id))
1142 1141 except Exception:
1143 1142 return finish(error.wrap_exception())
1144 1143
1145 1144 # clear the existing records
1146 1145 now = datetime.now()
1147 1146 rec = empty_record()
1148 1147 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1149 1148 rec['resubmitted'] = now
1150 1149 rec['queue'] = 'task'
1151 1150 rec['client_uuid'] = client_id[0]
1152 1151 try:
1153 1152 for msg_id in msg_ids:
1154 1153 self.all_completed.discard(msg_id)
1155 1154 self.db.update_record(msg_id, rec)
1156 1155 except Exception:
1157 1156 self.log.error('db::db error updating record', exc_info=True)
1158 1157 reply = error.wrap_exception()
1159 1158 else:
1160 1159 # send the messages
1161 1160 for rec in records:
1162 1161 header = rec['header']
1163 1162 # include resubmitted in header to prevent digest collision
1164 1163 header['resubmitted'] = now
1165 1164 msg = self.session.msg(header['msg_type'])
1166 1165 msg['content'] = rec['content']
1167 1166 msg['header'] = header
1168 1167 msg['msg_id'] = rec['msg_id']
1169 1168 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1170 1169
1171 1170 finish(dict(status='ok'))
1172 1171
1173 1172
1174 1173 def _extract_record(self, rec):
1175 1174 """decompose a TaskRecord dict into subsection of reply for get_result"""
1176 1175 io_dict = {}
1177 1176 for key in 'pyin pyout pyerr stdout stderr'.split():
1178 1177 io_dict[key] = rec[key]
1179 1178 content = { 'result_content': rec['result_content'],
1180 1179 'header': rec['header'],
1181 1180 'result_header' : rec['result_header'],
1182 1181 'io' : io_dict,
1183 1182 }
1184 1183 if rec['result_buffers']:
1185 1184 buffers = map(bytes, rec['result_buffers'])
1186 1185 else:
1187 1186 buffers = []
1188 1187
1189 1188 return content, buffers
1190 1189
1191 1190 def get_results(self, client_id, msg):
1192 1191 """Get the result of 1 or more messages."""
1193 1192 content = msg['content']
1194 1193 msg_ids = sorted(set(content['msg_ids']))
1195 1194 statusonly = content.get('status_only', False)
1196 1195 pending = []
1197 1196 completed = []
1198 1197 content = dict(status='ok')
1199 1198 content['pending'] = pending
1200 1199 content['completed'] = completed
1201 1200 buffers = []
1202 1201 if not statusonly:
1203 1202 try:
1204 1203 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1205 1204 # turn match list into dict, for faster lookup
1206 1205 records = {}
1207 1206 for rec in matches:
1208 1207 records[rec['msg_id']] = rec
1209 1208 except Exception:
1210 1209 content = error.wrap_exception()
1211 1210 self.session.send(self.query, "result_reply", content=content,
1212 1211 parent=msg, ident=client_id)
1213 1212 return
1214 1213 else:
1215 1214 records = {}
1216 1215 for msg_id in msg_ids:
1217 1216 if msg_id in self.pending:
1218 1217 pending.append(msg_id)
1219 1218 elif msg_id in self.all_completed:
1220 1219 completed.append(msg_id)
1221 1220 if not statusonly:
1222 1221 c,bufs = self._extract_record(records[msg_id])
1223 1222 content[msg_id] = c
1224 1223 buffers.extend(bufs)
1225 1224 elif msg_id in records:
1226 1225 if records[msg_id]['completed']:
1227 1226 completed.append(msg_id)
1228 1227 c,bufs = self._extract_record(records[msg_id])
1229 1228 content[msg_id] = c
1230 1229 buffers.extend(bufs)
1231 1230 else:
1232 1231 pending.append(msg_id)
1233 1232 else:
1234 1233 try:
1235 1234 raise KeyError('No such message: '+msg_id)
1236 1235 except:
1237 1236 content = error.wrap_exception()
1238 1237 break
1239 1238 self.session.send(self.query, "result_reply", content=content,
1240 1239 parent=msg, ident=client_id,
1241 1240 buffers=buffers)
1242 1241
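# Illustrative note (added for clarity, not part of the upstream file): a result_reply
# for one finished and one still-pending task carries content roughly like
#
#   {'status': 'ok', 'pending': ['msg-2'], 'completed': ['msg-1'],
#    'msg-1': {'result_content': {...}, 'header': {...}, 'result_header': {...}, 'io': {...}}}
#
# with any result buffers appended to the outgoing multipart message.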
1243 1242 def get_history(self, client_id, msg):
1244 1243 """Get a list of all msg_ids in our DB records"""
1245 1244 try:
1246 1245 msg_ids = self.db.get_history()
1247 1246 except Exception as e:
1248 1247 content = error.wrap_exception()
1249 1248 else:
1250 1249 content = dict(status='ok', history=msg_ids)
1251 1250
1252 1251 self.session.send(self.query, "history_reply", content=content,
1253 1252 parent=msg, ident=client_id)
1254 1253
1255 1254 def db_query(self, client_id, msg):
1256 1255 """Perform a raw query on the task record database."""
1257 1256 content = msg['content']
1258 1257 query = content.get('query', {})
1259 1258 keys = content.get('keys', None)
1260 1259 buffers = []
1261 1260 empty = list()
1262 1261 try:
1263 1262 records = self.db.find_records(query, keys)
1264 1263 except Exception as e:
1265 1264 content = error.wrap_exception()
1266 1265 else:
1267 1266 # extract buffers from reply content:
1268 1267 if keys is not None:
1269 1268 buffer_lens = [] if 'buffers' in keys else None
1270 1269 result_buffer_lens = [] if 'result_buffers' in keys else None
1271 1270 else:
1272 1271 buffer_lens = []
1273 1272 result_buffer_lens = []
1274 1273
1275 1274 for rec in records:
1276 1275 # buffers may be None, so double check
1277 1276 if buffer_lens is not None:
1278 1277 b = rec.pop('buffers', empty) or empty
1279 1278 buffer_lens.append(len(b))
1280 1279 buffers.extend(b)
1281 1280 if result_buffer_lens is not None:
1282 1281 rb = rec.pop('result_buffers', empty) or empty
1283 1282 result_buffer_lens.append(len(rb))
1284 1283 buffers.extend(rb)
1285 1284 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1286 1285 result_buffer_lens=result_buffer_lens)
1287 1286 # self.log.debug (content)
1288 1287 self.session.send(self.query, "db_reply", content=content,
1289 1288 parent=msg, ident=client_id,
1290 1289 buffers=buffers)
1291 1290
@@ -1,174 +1,173 b''
1 #!/usr/bin/env python
2 1 """A simple engine that talks to a controller over 0MQ.
3 2 It handles registration, etc. and launches a kernel
4 3 connected to the Controller's Schedulers.
5 4
6 5 Authors:
7 6
8 7 * Min RK
9 8 """
10 9 #-----------------------------------------------------------------------------
11 10 # Copyright (C) 2010-2011 The IPython Development Team
12 11 #
13 12 # Distributed under the terms of the BSD License. The full license is in
14 13 # the file COPYING, distributed as part of this software.
15 14 #-----------------------------------------------------------------------------
16 15
17 16 from __future__ import print_function
18 17
19 18 import sys
20 19 import time
21 20
22 21 import zmq
23 22 from zmq.eventloop import ioloop, zmqstream
24 23
25 24 # internal
26 25 from IPython.utils.traitlets import Instance, Dict, Int, Type, CFloat, Unicode, CBytes
27 26 # from IPython.utils.localinterfaces import LOCALHOST
28 27
29 28 from IPython.parallel.controller.heartmonitor import Heart
30 29 from IPython.parallel.factory import RegistrationFactory
31 30 from IPython.parallel.util import disambiguate_url, asbytes
32 31
33 32 from IPython.zmq.session import Message
34 33
35 34 from .streamkernel import Kernel
36 35
37 36 class EngineFactory(RegistrationFactory):
38 37 """IPython engine"""
39 38
40 39 # configurables:
41 40 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
42 41 help="""The OutStream for handling stdout/err.
43 42 Typically 'IPython.zmq.iostream.OutStream'""")
44 43 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
45 44 help="""The class for handling displayhook.
46 45 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
47 46 location=Unicode(config=True,
48 47 help="""The location (an IP address) of the controller. This is
49 48 used for disambiguating URLs, to determine whether
50 49 loopback should be used to connect or the public address.""")
51 50 timeout=CFloat(2,config=True,
52 51 help="""The time (in seconds) to wait for the Controller to respond
53 52 to registration requests before giving up.""")
54 53
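# Illustrative note (added for clarity, not part of the upstream file): the
# configurables above can be set from a config file, e.g.
#
#   c.EngineFactory.location = '10.0.0.5'  # public IP of the controller host
#   c.EngineFactory.timeout = 5.0          # allow more time for registration replies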
55 54 # not configurable:
56 55 user_ns=Dict()
57 56 id=Int(allow_none=True)
58 57 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
59 58 kernel=Instance(Kernel)
60 59
61 60 bident = CBytes()
62 61 ident = Unicode()
63 62 def _ident_changed(self, name, old, new):
64 63 self.bident = asbytes(new)
65 64
66 65
67 66 def __init__(self, **kwargs):
68 67 super(EngineFactory, self).__init__(**kwargs)
69 68 self.ident = self.session.session
70 69 ctx = self.context
71 70
72 71 reg = ctx.socket(zmq.XREQ)
73 72 reg.setsockopt(zmq.IDENTITY, self.bident)
74 73 reg.connect(self.url)
75 74 self.registrar = zmqstream.ZMQStream(reg, self.loop)
76 75
77 76 def register(self):
78 77 """send the registration_request"""
79 78
80 79 self.log.info("Registering with controller at %s"%self.url)
81 80 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
82 81 self.registrar.on_recv(self.complete_registration)
83 82 # print (self.session.key)
84 83 self.session.send(self.registrar, "registration_request",content=content)
85 84
86 85 def complete_registration(self, msg):
87 86 # print msg
88 87 self._abort_dc.stop()
89 88 ctx = self.context
90 89 loop = self.loop
91 90 identity = self.bident
92 91 idents,msg = self.session.feed_identities(msg)
93 92 msg = Message(self.session.unpack_message(msg))
94 93
95 94 if msg.content.status == 'ok':
96 95 self.id = int(msg.content.id)
97 96
98 97 # create Shell Streams (MUX, Task, etc.):
99 98 queue_addr = msg.content.mux
100 99 shell_addrs = [ str(queue_addr) ]
101 100 task_addr = msg.content.task
102 101 if task_addr:
103 102 shell_addrs.append(str(task_addr))
104 103
105 104 # Uncomment this to go back to two-socket model
106 105 # shell_streams = []
107 106 # for addr in shell_addrs:
108 107 # stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
109 108 # stream.setsockopt(zmq.IDENTITY, identity)
110 109 # stream.connect(disambiguate_url(addr, self.location))
111 110 # shell_streams.append(stream)
112 111
113 112 # Now use only one shell stream for mux and tasks
114 113 stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
115 114 stream.setsockopt(zmq.IDENTITY, identity)
116 115 shell_streams = [stream]
117 116 for addr in shell_addrs:
118 117 stream.connect(disambiguate_url(addr, self.location))
119 118 # end single stream-socket
120 119
121 120 # control stream:
122 121 control_addr = str(msg.content.control)
123 122 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.XREP), loop)
124 123 control_stream.setsockopt(zmq.IDENTITY, identity)
125 124 control_stream.connect(disambiguate_url(control_addr, self.location))
126 125
127 126 # create iopub stream:
128 127 iopub_addr = msg.content.iopub
129 128 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
130 129 iopub_stream.setsockopt(zmq.IDENTITY, identity)
131 130 iopub_stream.connect(disambiguate_url(iopub_addr, self.location))
132 131
133 132 # launch heartbeat
134 133 hb_addrs = msg.content.heartbeat
135 134 # print (hb_addrs)
136 135
137 136 # Redirect input streams and set a display hook.
138 137 if self.out_stream_factory:
139 138 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
140 139 sys.stdout.topic = 'engine.%i.stdout'%self.id
141 140 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
142 141 sys.stderr.topic = 'engine.%i.stderr'%self.id
143 142 if self.display_hook_factory:
144 143 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
145 144 sys.displayhook.topic = 'engine.%i.pyout'%self.id
146 145
147 146 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
148 147 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
149 148 loop=loop, user_ns = self.user_ns, log=self.log)
150 149 self.kernel.start()
151 150 hb_addrs = [ disambiguate_url(addr, self.location) for addr in hb_addrs ]
152 151 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 152 heart.start()
154 153
155 154
156 155 else:
157 156 self.log.fatal("Registration Failed: %s"%msg)
158 157 raise Exception("Registration Failed: %s"%msg)
159 158
160 159 self.log.info("Completed registration with id %i"%self.id)
161 160
162 161
163 162 def abort(self):
164 163 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
165 164 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
166 165 time.sleep(1)
167 166 sys.exit(255)
168 167
169 168 def start(self):
170 169 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
171 170 dc.start()
172 171 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
173 172 self._abort_dc.start()
174 173
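The `start()` method above fires `register()` immediately and arms an abort timer with `ioloop.DelayedCallback`; on a successful reply, `complete_registration()` stops that timer via `self._abort_dc.stop()`. A minimal, self-contained sketch of the same scheduling pattern, using toy callbacks instead of a real controller or ZMQ sockets (assumes pyzmq's `zmq.eventloop` is importable)::

    from zmq.eventloop import ioloop

    loop = ioloop.IOLoop.instance()
    timeout = 2.0  # seconds, mirroring the `timeout` trait above

    def register():
        # stands in for EngineFactory.register(): send the request right away
        print("sending registration_request ...")

    def abort():
        # stands in for EngineFactory.abort(): give up once the timeout expires
        print("registration timed out")
        loop.stop()

    ioloop.DelayedCallback(register, 0, loop).start()
    abort_dc = ioloop.DelayedCallback(abort, int(timeout * 1000), loop)
    abort_dc.start()   # the real code keeps this handle and stops it on success
    loop.start()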
@@ -1,438 +1,437 b''
1 #!/usr/bin/env python
2 1 """
3 2 Kernel adapted from kernel.py to use ZMQ Streams
4 3
5 4 Authors:
6 5
7 6 * Min RK
8 7 * Brian Granger
9 8 * Fernando Perez
10 9 * Evan Patterson
11 10 """
12 11 #-----------------------------------------------------------------------------
13 12 # Copyright (C) 2010-2011 The IPython Development Team
14 13 #
15 14 # Distributed under the terms of the BSD License. The full license is in
16 15 # the file COPYING, distributed as part of this software.
17 16 #-----------------------------------------------------------------------------
18 17
19 18 #-----------------------------------------------------------------------------
20 19 # Imports
21 20 #-----------------------------------------------------------------------------
22 21
23 22 # Standard library imports.
24 23 from __future__ import print_function
25 24
26 25 import sys
27 26 import time
28 27
29 28 from code import CommandCompiler
30 29 from datetime import datetime
31 30 from pprint import pprint
32 31
33 32 # System library imports.
34 33 import zmq
35 34 from zmq.eventloop import ioloop, zmqstream
36 35
37 36 # Local imports.
38 37 from IPython.utils.traitlets import Instance, List, Int, Dict, Set, Unicode, CBytes
39 38 from IPython.zmq.completer import KernelCompleter
40 39
41 40 from IPython.parallel.error import wrap_exception
42 41 from IPython.parallel.factory import SessionFactory
43 42 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
44 43
45 44 def printer(*args):
46 45 pprint(args, stream=sys.__stdout__)
47 46
48 47
49 48 class _Passer(zmqstream.ZMQStream):
50 49 """Empty class that implements `send()` that does nothing.
51 50
52 51 Subclass ZMQStream for Session typechecking
53 52
54 53 """
55 54 def __init__(self, *args, **kwargs):
56 55 pass
57 56
58 57 def send(self, *args, **kwargs):
59 58 pass
60 59 send_multipart = send
61 60
62 61
63 62 #-----------------------------------------------------------------------------
64 63 # Main kernel class
65 64 #-----------------------------------------------------------------------------
66 65
67 66 class Kernel(SessionFactory):
68 67
69 68 #---------------------------------------------------------------------------
70 69 # Kernel interface
71 70 #---------------------------------------------------------------------------
72 71
73 72 # kwargs:
74 73 exec_lines = List(Unicode, config=True,
75 74 help="List of lines to execute")
76 75
77 76 # identities:
78 77 int_id = Int(-1)
79 78 bident = CBytes()
80 79 ident = Unicode()
81 80 def _ident_changed(self, name, old, new):
82 81 self.bident = asbytes(new)
83 82
84 83 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
85 84
86 85 control_stream = Instance(zmqstream.ZMQStream)
87 86 task_stream = Instance(zmqstream.ZMQStream)
88 87 iopub_stream = Instance(zmqstream.ZMQStream)
89 88 client = Instance('IPython.parallel.Client')
90 89
91 90 # internals
92 91 shell_streams = List()
93 92 compiler = Instance(CommandCompiler, (), {})
94 93 completer = Instance(KernelCompleter)
95 94
96 95 aborted = Set()
97 96 shell_handlers = Dict()
98 97 control_handlers = Dict()
99 98
100 99 def _set_prefix(self):
101 100 self.prefix = "engine.%s"%self.int_id
102 101
103 102 def _connect_completer(self):
104 103 self.completer = KernelCompleter(self.user_ns)
105 104
106 105 def __init__(self, **kwargs):
107 106 super(Kernel, self).__init__(**kwargs)
108 107 self._set_prefix()
109 108 self._connect_completer()
110 109
111 110 self.on_trait_change(self._set_prefix, 'id')
112 111 self.on_trait_change(self._connect_completer, 'user_ns')
113 112
114 113 # Build dict of handlers for message types
115 114 for msg_type in ['execute_request', 'complete_request', 'apply_request',
116 115 'clear_request']:
117 116 self.shell_handlers[msg_type] = getattr(self, msg_type)
118 117
119 118 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
120 119 self.control_handlers[msg_type] = getattr(self, msg_type)
121 120
122 121 self._initial_exec_lines()
123 122
124 123 def _wrap_exception(self, method=None):
125 124 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
126 125 content=wrap_exception(e_info)
127 126 return content
128 127
129 128 def _initial_exec_lines(self):
130 129 s = _Passer()
131 130 content = dict(silent=True, user_variable=[],user_expressions=[])
132 131 for line in self.exec_lines:
133 132 self.log.debug("executing initialization: %s"%line)
134 133 content.update({'code':line})
135 134 msg = self.session.msg('execute_request', content)
136 135 self.execute_request(s, [], msg)
137 136
138 137
139 138 #-------------------- control handlers -----------------------------
140 139 def abort_queues(self):
141 140 for stream in self.shell_streams:
142 141 if stream:
143 142 self.abort_queue(stream)
144 143
145 144 def abort_queue(self, stream):
146 145 while True:
147 146 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
148 147 if msg is None:
149 148 return
150 149
151 150 self.log.info("Aborting:")
152 151 self.log.info(str(msg))
153 152 msg_type = msg['msg_type']
154 153 reply_type = msg_type.split('_')[0] + '_reply'
155 154 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
156 155 # self.reply_socket.send(ident,zmq.SNDMORE)
157 156 # self.reply_socket.send_json(reply_msg)
158 157 reply_msg = self.session.send(stream, reply_type,
159 158 content={'status' : 'aborted'}, parent=msg, ident=idents)
160 159 self.log.debug(str(reply_msg))
161 160 # We need to wait a bit for requests to come in. This can probably
162 161 # be set shorter for true asynchronous clients.
163 162 time.sleep(0.05)
164 163
165 164 def abort_request(self, stream, ident, parent):
166 165 """abort a specifig msg by id"""
167 166 msg_ids = parent['content'].get('msg_ids', None)
168 167 if isinstance(msg_ids, basestring):
169 168 msg_ids = [msg_ids]
170 169 if not msg_ids:
171 170 self.abort_queues()
172 171 for mid in msg_ids:
173 172 self.aborted.add(str(mid))
174 173
175 174 content = dict(status='ok')
176 175 reply_msg = self.session.send(stream, 'abort_reply', content=content,
177 176 parent=parent, ident=ident)
178 177 self.log.debug(str(reply_msg))
179 178
180 179 def shutdown_request(self, stream, ident, parent):
181 180 """kill ourself. This should really be handled in an external process"""
182 181 try:
183 182 self.abort_queues()
184 183 except:
185 184 content = self._wrap_exception('shutdown')
186 185 else:
187 186 content = dict(parent['content'])
188 187 content['status'] = 'ok'
189 188 msg = self.session.send(stream, 'shutdown_reply',
190 189 content=content, parent=parent, ident=ident)
191 190 self.log.debug(str(msg))
192 191 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
193 192 dc.start()
194 193
195 194 def dispatch_control(self, msg):
196 195 idents,msg = self.session.feed_identities(msg, copy=False)
197 196 try:
198 197 msg = self.session.unpack_message(msg, content=True, copy=False)
199 198 except:
200 199 self.log.error("Invalid Message", exc_info=True)
201 200 return
202 201 else:
203 202 self.log.debug("Control received, %s", msg)
204 203
205 204 header = msg['header']
206 205 msg_id = header['msg_id']
207 206
208 207 handler = self.control_handlers.get(msg['msg_type'], None)
209 208 if handler is None:
210 209 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg['msg_type'])
211 210 else:
212 211 handler(self.control_stream, idents, msg)
213 212
214 213
215 214 #-------------------- queue helpers ------------------------------
216 215
217 216 def check_dependencies(self, dependencies):
218 217 if not dependencies:
219 218 return True
220 219 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
221 220 anyorall = dependencies[0]
222 221 dependencies = dependencies[1]
223 222 else:
224 223 anyorall = 'all'
225 224 results = self.client.get_results(dependencies,status_only=True)
226 225 if results['status'] != 'ok':
227 226 return False
228 227
229 228 if anyorall == 'any':
230 229 if not results['completed']:
231 230 return False
232 231 else:
233 232 if results['pending']:
234 233 return False
235 234
236 235 return True
237 236
238 237 def check_aborted(self, msg_id):
239 238 return msg_id in self.aborted
240 239
241 240 #-------------------- queue handlers -----------------------------
242 241
243 242 def clear_request(self, stream, idents, parent):
244 243 """Clear our namespace."""
245 244 self.user_ns = {}
246 245 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
247 246 content = dict(status='ok'))
248 247 self._initial_exec_lines()
249 248
250 249 def execute_request(self, stream, ident, parent):
251 250 self.log.debug('execute request %s'%parent)
252 251 try:
253 252 code = parent[u'content'][u'code']
254 253 except:
255 254 self.log.error("Got bad msg: %s"%parent, exc_info=True)
256 255 return
257 256 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
258 257 ident=asbytes('%s.pyin'%self.prefix))
259 258 started = datetime.now()
260 259 try:
261 260 comp_code = self.compiler(code, '<zmq-kernel>')
262 261 # allow for not overriding displayhook
263 262 if hasattr(sys.displayhook, 'set_parent'):
264 263 sys.displayhook.set_parent(parent)
265 264 sys.stdout.set_parent(parent)
266 265 sys.stderr.set_parent(parent)
267 266 exec comp_code in self.user_ns, self.user_ns
268 267 except:
269 268 exc_content = self._wrap_exception('execute')
270 269 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
271 270 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
272 271 ident=asbytes('%s.pyerr'%self.prefix))
273 272 reply_content = exc_content
274 273 else:
275 274 reply_content = {'status' : 'ok'}
276 275
277 276 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
278 277 ident=ident, subheader = dict(started=started))
279 278 self.log.debug(str(reply_msg))
280 279 if reply_msg['content']['status'] == u'error':
281 280 self.abort_queues()
282 281
283 282 def complete_request(self, stream, ident, parent):
284 283 matches = {'matches' : self.complete(parent),
285 284 'status' : 'ok'}
286 285 completion_msg = self.session.send(stream, 'complete_reply',
287 286 matches, parent, ident)
288 287 # print >> sys.__stdout__, completion_msg
289 288
290 289 def complete(self, msg):
291 290 return self.completer.complete(msg.content.line, msg.content.text)
292 291
293 292 def apply_request(self, stream, ident, parent):
294 293 # flush previous reply, so this request won't block it
295 294 stream.flush(zmq.POLLOUT)
296 295 try:
297 296 content = parent[u'content']
298 297 bufs = parent[u'buffers']
299 298 msg_id = parent['header']['msg_id']
300 299 # bound = parent['header'].get('bound', False)
301 300 except:
302 301 self.log.error("Got bad msg: %s"%parent, exc_info=True)
303 302 return
304 303 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
305 304 # self.iopub_stream.send(pyin_msg)
306 305 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
307 306 sub = {'dependencies_met' : True, 'engine' : self.ident,
308 307 'started': datetime.now()}
309 308 try:
310 309 # allow for not overriding displayhook
311 310 if hasattr(sys.displayhook, 'set_parent'):
312 311 sys.displayhook.set_parent(parent)
313 312 sys.stdout.set_parent(parent)
314 313 sys.stderr.set_parent(parent)
315 314 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
316 315 working = self.user_ns
317 316 # suffix =
318 317 prefix = "_"+str(msg_id).replace("-","")+"_"
319 318
320 319 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
321 320 # if bound:
322 321 # bound_ns = Namespace(working)
323 322 # args = [bound_ns]+list(args)
324 323
325 324 fname = getattr(f, '__name__', 'f')
326 325
327 326 fname = prefix+"f"
328 327 argname = prefix+"args"
329 328 kwargname = prefix+"kwargs"
330 329 resultname = prefix+"result"
331 330
332 331 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
333 332 # print ns
334 333 working.update(ns)
335 334 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
336 335 try:
337 336 exec code in working,working
338 337 result = working.get(resultname)
339 338 finally:
340 339 for key in ns.iterkeys():
341 340 working.pop(key)
342 341 # if bound:
343 342 # working.update(bound_ns)
344 343
345 344 packed_result,buf = serialize_object(result)
346 345 result_buf = [packed_result]+buf
347 346 except:
348 347 exc_content = self._wrap_exception('apply')
349 348 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
350 349 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
351 350 ident=asbytes('%s.pyerr'%self.prefix))
352 351 reply_content = exc_content
353 352 result_buf = []
354 353
355 354 if exc_content['ename'] == 'UnmetDependency':
356 355 sub['dependencies_met'] = False
357 356 else:
358 357 reply_content = {'status' : 'ok'}
359 358
360 359 # put 'ok'/'error' status in header, for scheduler introspection:
361 360 sub['status'] = reply_content['status']
362 361
363 362 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
364 363 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
365 364
366 365 # flush i/o
367 366 # should this be before reply_msg is sent, like in the single-kernel code,
368 367 # or should nothing get in the way of real results?
369 368 sys.stdout.flush()
370 369 sys.stderr.flush()
371 370
372 371 def dispatch_queue(self, stream, msg):
373 372 self.control_stream.flush()
374 373 idents,msg = self.session.feed_identities(msg, copy=False)
375 374 try:
376 375 msg = self.session.unpack_message(msg, content=True, copy=False)
377 376 except:
378 377 self.log.error("Invalid Message", exc_info=True)
379 378 return
380 379 else:
381 380 self.log.debug("Message received, %s", msg)
382 381
383 382
384 383 header = msg['header']
385 384 msg_id = header['msg_id']
386 385 if self.check_aborted(msg_id):
387 386 self.aborted.remove(msg_id)
388 387 # is it safe to assume a msg_id will not be resubmitted?
389 388 reply_type = msg['msg_type'].split('_')[0] + '_reply'
390 389 status = {'status' : 'aborted'}
391 390 reply_msg = self.session.send(stream, reply_type, subheader=status,
392 391 content=status, parent=msg, ident=idents)
393 392 return
394 393 handler = self.shell_handlers.get(msg['msg_type'], None)
395 394 if handler is None:
396 395 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg['msg_type'])
397 396 else:
398 397 handler(stream, idents, msg)
399 398
400 399 def start(self):
401 400 #### stream mode:
402 401 if self.control_stream:
403 402 self.control_stream.on_recv(self.dispatch_control, copy=False)
404 403 self.control_stream.on_err(printer)
405 404
406 405 def make_dispatcher(stream):
407 406 def dispatcher(msg):
408 407 return self.dispatch_queue(stream, msg)
409 408 return dispatcher
410 409
411 410 for s in self.shell_streams:
412 411 s.on_recv(make_dispatcher(s), copy=False)
413 412 s.on_err(printer)
414 413
415 414 if self.iopub_stream:
416 415 self.iopub_stream.on_err(printer)
417 416
418 417 #### while True mode:
419 418 # while True:
420 419 # idle = True
421 420 # try:
422 421 # msg = self.shell_stream.socket.recv_multipart(
423 422 # zmq.NOBLOCK, copy=False)
424 423 # except zmq.ZMQError, e:
425 424 # if e.errno != zmq.EAGAIN:
426 425 # raise e
427 426 # else:
428 427 # idle=False
429 428 # self.dispatch_queue(self.shell_stream, msg)
430 429 #
431 430 # if not self.task_stream.empty():
432 431 # idle=False
433 432 # msg = self.task_stream.recv_multipart()
434 433 # self.dispatch_queue(self.task_stream, msg)
435 434 # if idle:
436 435 # # don't busywait
437 436 # time.sleep(1e-3)
438 437
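The heart of `apply_request` above is a namespace trick: the function and its arguments are rebound under collision-safe, msg_id-derived names inside the user namespace, a one-line call is exec'd there, and the temporaries are popped again in a `finally` block. A stripped-down sketch of just that step, with toy values and no serialization or messaging::

    working = {}                        # stands in for self.user_ns
    msg_id = "abc-123"                  # hypothetical message id
    prefix = "_" + str(msg_id).replace("-", "") + "_"

    f, args, kwargs = max, (3, 7), {}   # what unpack_apply_message would return
    fname, argname = prefix + "f", prefix + "args"
    kwargname, resultname = prefix + "kwargs", prefix + "result"

    ns = {fname: f, argname: args, kwargname: kwargs, resultname: None}
    working.update(ns)
    code = "%s=%s(*%s,**%s)" % (resultname, fname, argname, kwargname)
    try:
        exec(code, working, working)    # equivalent to `exec code in working, working`
        result = working.get(resultname)
    finally:
        for key in ns:                  # remove the temporary bindings again
            working.pop(key)
    print(result)                       # -> 7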
@@ -1,38 +1,35 b''
1 #!/usr/bin/env python
2
3
4 1 """
5 2 Add %global magic for GNU Global usage.
6 3
7 4 http://www.gnu.org/software/global/
8 5
9 6 """
10 7
11 8 from IPython.core import ipapi
12 9 ip = ipapi.get()
13 10 import os
14 11
15 12 # alter to your liking
16 13 global_bin = 'd:/opt/global/bin/global'
17 14
18 15 def global_f(self,cmdline):
19 16 simple = 0
20 17 if '-' not in cmdline:
21 18 cmdline = '-rx ' + cmdline
22 19 simple = 1
23 20
24 21 lines = [l.rstrip() for l in os.popen( global_bin + ' ' + cmdline ).readlines()]
25 22
26 23 if simple:
27 24 parts = [l.split(None,3) for l in lines]
28 25 lines = ['%s [%s]\n%s' % (p[2].rjust(70),p[1],p[3].rstrip()) for p in parts]
29 26 print "\n".join(lines)
30 27
31 28 ip.define_magic('global', global_f)
32 29
33 30 def global_completer(self,event):
34 31 compl = [l.rstrip() for l in os.popen(global_bin + ' -c ' + event.symbol).readlines()]
35 32 return compl
36 33
37 34 ip.set_hook('complete_command', global_completer, str_key = '%global')
38 35
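Assuming GNU Global is installed and `global_bin` points at the real binary, a session with the magic defined above might look like this (hypothetical output shape, driven by your tag database)::

    In [1]: %global main        # no '-' in the line, so it runs 'global -rx main' and pretty-prints the hits
    In [2]: %global -x main     # explicit flags are passed through and printed as-is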
@@ -1,68 +1,65 b''
1 #!/usr/bin/env python
2
3 1 """ IPython extension: Render templates from variables and paste to clipbard """
4 2
5 3 from IPython.core import ipapi
6 4
7 5 ip = ipapi.get()
8 6
9 7 from string import Template
10 8 import sys,os
11 9
12 10 from IPython.external.Itpl import itplns
13 11
14 12 def toclip_w32(s):
15 13 """ Places contents of s to clipboard
16 14
17 15 Needs pywin32 to work:
18 16 http://sourceforge.net/projects/pywin32/
19 17 """
20 18 import win32clipboard as cl
21 19 import win32con
22 20 cl.OpenClipboard()
23 21 cl.EmptyClipboard()
24 22 cl.SetClipboardText( s.replace('\n','\r\n' ))
25 23 cl.CloseClipboard()
26 24
27 25 try:
28 26 import win32clipboard
29 27 toclip = toclip_w32
30 28 except ImportError:
31 29 def toclip(s): pass
32 30
33 31
34 32 def render(tmpl):
35 33 """ Render a template (Itpl format) from ipython variables
36 34
37 35 Example:
38 36
39 37 $ import ipy_render
40 38 $ my_name = 'Bob' # %store this for convenience
41 39 $ t_submission_form = "Submission report, author: $my_name" # %store also
42 40 $ render t_submission_form
43 41
44 42 => returns "Submission report, author: Bob" and copies to clipboard on win32
45 43
46 44 # if template exists as a file, read it. Note: ;f hei vaan => f("hei vaan")
47 45 $ ;render c:/templates/greeting.txt
48 46
49 47 Template examples (Ka-Ping Yee's Itpl library):
50 48
51 49 Here is a $string.
52 50 Here is a $module.member.
53 51 Here is an $object.member.
54 52 Here is a $functioncall(with, arguments).
55 53 Here is an ${arbitrary + expression}.
56 54 Here is an $array[3] member.
57 55 Here is a $dictionary['member'].
58 56 """
59 57
60 58 if os.path.isfile(tmpl):
61 59 tmpl = open(tmpl).read()
62 60
63 61 res = itplns(tmpl, ip.user_ns)
64 62 toclip(res)
65 63 return res
66 64
67 65 ip.push('render')
68 No newline at end of file
@@ -1,43 +1,41 b''
1 #!/usr/bin/env python
2
3 1 from IPython.core import ipapi
4 2 ip = ipapi.get()
5 3
6 4 import os, subprocess
7 5
8 6 workdir = None
9 7 def workdir_f(ip,line):
10 8 """ Exceute commands residing in cwd elsewhere
11 9
12 10 Example::
13 11
14 12 workdir /myfiles
15 13 cd bin
16 14 workdir myscript.py
17 15
18 16 executes myscript.py (stored in bin, but not in path) in /myfiles
19 17 """
20 18 global workdir
21 19 dummy,cmd = line.split(None,1)
22 20 if os.path.isdir(cmd):
23 21 workdir = os.path.abspath(cmd)
24 22 print "Set workdir",workdir
25 23 elif workdir is None:
26 24 print "Please set workdir first by doing e.g. 'workdir q:/'"
27 25 else:
28 26 sp = cmd.split(None,1)
29 27 if len(sp) == 1:
30 28 head, tail = cmd, ''
31 29 else:
32 30 head, tail = sp
33 31 if os.path.isfile(head):
34 32 cmd = os.path.abspath(head) + ' ' + tail
35 33 print "Execute command '" + cmd+ "' in",workdir
36 34 olddir = os.getcwdu()
37 35 os.chdir(workdir)
38 36 try:
39 37 os.system(cmd)
40 38 finally:
41 39 os.chdir(olddir)
42 40
43 41 ip.define_alias("workdir",workdir_f)
1 NO CONTENT: modified file chmod 100644 => 100755
@@ -1,74 +1,73 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Tests for testing.tools
5 4 """
6 5
7 6 #-----------------------------------------------------------------------------
8 7 # Copyright (C) 2008-2009 The IPython Development Team
9 8 #
10 9 # Distributed under the terms of the BSD License. The full license is in
11 10 # the file COPYING, distributed as part of this software.
12 11 #-----------------------------------------------------------------------------
13 12
14 13 #-----------------------------------------------------------------------------
15 14 # Imports
16 15 #-----------------------------------------------------------------------------
17 16 from __future__ import with_statement
18 17
19 18 import os
20 19 import sys
21 20
22 21 import nose.tools as nt
23 22
24 23 from IPython.testing import decorators as dec
25 24 from IPython.testing import tools as tt
26 25
27 26 #-----------------------------------------------------------------------------
28 27 # Tests
29 28 #-----------------------------------------------------------------------------
30 29
31 30 @dec.skip_win32
32 31 def test_full_path_posix():
33 32 spath = '/foo/bar.py'
34 33 result = tt.full_path(spath,['a.txt','b.txt'])
35 34 nt.assert_equal(result, ['/foo/a.txt', '/foo/b.txt'])
36 35 spath = '/foo'
37 36 result = tt.full_path(spath,['a.txt','b.txt'])
38 37 nt.assert_equal(result, ['/a.txt', '/b.txt'])
39 38 result = tt.full_path(spath,'a.txt')
40 39 nt.assert_equal(result, ['/a.txt'])
41 40
42 41
43 42 @dec.skip_if_not_win32
44 43 def test_full_path_win32():
45 44 spath = 'c:\\foo\\bar.py'
46 45 result = tt.full_path(spath,['a.txt','b.txt'])
47 46 nt.assert_equal(result, ['c:\\foo\\a.txt', 'c:\\foo\\b.txt'])
48 47 spath = 'c:\\foo'
49 48 result = tt.full_path(spath,['a.txt','b.txt'])
50 49 nt.assert_equal(result, ['c:\\a.txt', 'c:\\b.txt'])
51 50 result = tt.full_path(spath,'a.txt')
52 51 nt.assert_equal(result, ['c:\\a.txt'])
53 52
54 53
55 54 @dec.parametric
56 55 def test_parser():
57 56 err = ("FAILED (errors=1)", 1, 0)
58 57 fail = ("FAILED (failures=1)", 0, 1)
59 58 both = ("FAILED (errors=1, failures=1)", 1, 1)
60 59 for txt, nerr, nfail in [err, fail, both]:
61 60 nerr1, nfail1 = tt.parse_test_output(txt)
62 61 yield nt.assert_equal(nerr, nerr1)
63 62 yield nt.assert_equal(nfail, nfail1)
64 63
65 64
66 65 @dec.parametric
67 66 def test_temp_pyfile():
68 67 src = 'pass\n'
69 68 fname, fh = tt.temp_pyfile(src)
70 69 yield nt.assert_true(os.path.isfile(fname))
71 70 fh.close()
72 71 with open(fname) as fh2:
73 72 src2 = fh2.read()
74 73 yield nt.assert_equal(src2, src)
@@ -1,396 +1,395 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """A dict subclass that supports attribute style access.
4 3
5 4 Authors:
6 5
7 6 * Fernando Perez (original)
8 7 * Brian Granger (refactoring to a dict subclass)
9 8 """
10 9
11 10 #-----------------------------------------------------------------------------
12 11 # Copyright (C) 2008-2009 The IPython Development Team
13 12 #
14 13 # Distributed under the terms of the BSD License. The full license is in
15 14 # the file COPYING, distributed as part of this software.
16 15 #-----------------------------------------------------------------------------
17 16
18 17 #-----------------------------------------------------------------------------
19 18 # Imports
20 19 #-----------------------------------------------------------------------------
21 20
22 21 from IPython.utils.data import list2dict2
23 22
24 23 __all__ = ['Struct']
25 24
26 25 #-----------------------------------------------------------------------------
27 26 # Code
28 27 #-----------------------------------------------------------------------------
29 28
30 29
31 30 class Struct(dict):
32 31 """A dict subclass with attribute style access.
33 32
34 33 This dict subclass has a few extra features:
35 34
36 35 * Attribute style access.
37 36 * Protection of class members (like keys, items) when using attribute
38 37 style access.
39 38 * The ability to restrict assignment to only existing keys.
40 39 * Intelligent merging.
41 40 * Overloaded operators.
42 41 """
43 42 _allownew = True
44 43 def __init__(self, *args, **kw):
45 44 """Initialize with a dictionary, another Struct, or data.
46 45
47 46 Parameters
48 47 ----------
49 48 args : dict, Struct
50 49 Initialize with one dict or Struct
51 50 kw : dict
52 51 Initialize with key, value pairs.
53 52
54 53 Examples
55 54 --------
56 55
57 56 >>> s = Struct(a=10,b=30)
58 57 >>> s.a
59 58 10
60 59 >>> s.b
61 60 30
62 61 >>> s2 = Struct(s,c=30)
63 62 >>> s2.keys()
64 63 ['a', 'c', 'b']
65 64 """
66 65 object.__setattr__(self, '_allownew', True)
67 66 dict.__init__(self, *args, **kw)
68 67
69 68 def __setitem__(self, key, value):
70 69 """Set an item with check for allownew.
71 70
72 71 Examples
73 72 --------
74 73
75 74 >>> s = Struct()
76 75 >>> s['a'] = 10
77 76 >>> s.allow_new_attr(False)
78 77 >>> s['a'] = 10
79 78 >>> s['a']
80 79 10
81 80 >>> try:
82 81 ... s['b'] = 20
83 82 ... except KeyError:
84 83 ... print 'this is not allowed'
85 84 ...
86 85 this is not allowed
87 86 """
88 87 if not self._allownew and not self.has_key(key):
89 88 raise KeyError(
90 89 "can't create new attribute %s when allow_new_attr(False)" % key)
91 90 dict.__setitem__(self, key, value)
92 91
93 92 def __setattr__(self, key, value):
94 93 """Set an attr with protection of class members.
95 94
96 95 This calls :meth:`self.__setitem__` but convert :exc:`KeyError` to
97 96 :exc:`AttributeError`.
98 97
99 98 Examples
100 99 --------
101 100
102 101 >>> s = Struct()
103 102 >>> s.a = 10
104 103 >>> s.a
105 104 10
106 105 >>> try:
107 106 ... s.get = 10
108 107 ... except AttributeError:
109 108 ... print "you can't set a class member"
110 109 ...
111 110 you can't set a class member
112 111 """
113 112 # If key is an str it might be a class member or instance var
114 113 if isinstance(key, str):
115 114 # I can't simply call hasattr here because it calls getattr, which
116 115 # calls self.__getattr__, which returns True for keys in
117 116 # self._data. But I only want keys in the class and in
118 117 # self.__dict__
119 118 if key in self.__dict__ or hasattr(Struct, key):
120 119 raise AttributeError(
121 120 'attr %s is a protected member of class Struct.' % key
122 121 )
123 122 try:
124 123 self.__setitem__(key, value)
125 124 except KeyError, e:
126 125 raise AttributeError(e)
127 126
128 127 def __getattr__(self, key):
129 128 """Get an attr by calling :meth:`dict.__getitem__`.
130 129
131 130 Like :meth:`__setattr__`, this method converts :exc:`KeyError` to
132 131 :exc:`AttributeError`.
133 132
134 133 Examples
135 134 --------
136 135
137 136 >>> s = Struct(a=10)
138 137 >>> s.a
139 138 10
140 139 >>> type(s.get)
141 140 <type 'builtin_function_or_method'>
142 141 >>> try:
143 142 ... s.b
144 143 ... except AttributeError:
145 144 ... print "I don't have that key"
146 145 ...
147 146 I don't have that key
148 147 """
149 148 try:
150 149 result = self[key]
151 150 except KeyError:
152 151 raise AttributeError(key)
153 152 else:
154 153 return result
155 154
156 155 def __iadd__(self, other):
157 156 """s += s2 is a shorthand for s.merge(s2).
158 157
159 158 Examples
160 159 --------
161 160
162 161 >>> s = Struct(a=10,b=30)
163 162 >>> s2 = Struct(a=20,c=40)
164 163 >>> s += s2
165 164 >>> s
166 165 {'a': 10, 'c': 40, 'b': 30}
167 166 """
168 167 self.merge(other)
169 168 return self
170 169
171 170 def __add__(self,other):
172 171 """s + s2 -> New Struct made from s.merge(s2).
173 172
174 173 Examples
175 174 --------
176 175
177 176 >>> s1 = Struct(a=10,b=30)
178 177 >>> s2 = Struct(a=20,c=40)
179 178 >>> s = s1 + s2
180 179 >>> s
181 180 {'a': 10, 'c': 40, 'b': 30}
182 181 """
183 182 sout = self.copy()
184 183 sout.merge(other)
185 184 return sout
186 185
187 186 def __sub__(self,other):
188 187 """s1 - s2 -> remove keys in s2 from s1.
189 188
190 189 Examples
191 190 --------
192 191
193 192 >>> s1 = Struct(a=10,b=30)
194 193 >>> s2 = Struct(a=40)
195 194 >>> s = s1 - s2
196 195 >>> s
197 196 {'b': 30}
198 197 """
199 198 sout = self.copy()
200 199 sout -= other
201 200 return sout
202 201
203 202 def __isub__(self,other):
204 203 """Inplace remove keys from self that are in other.
205 204
206 205 Examples
207 206 --------
208 207
209 208 >>> s1 = Struct(a=10,b=30)
210 209 >>> s2 = Struct(a=40)
211 210 >>> s1 -= s2
212 211 >>> s1
213 212 {'b': 30}
214 213 """
215 214 for k in other.keys():
216 215 if self.has_key(k):
217 216 del self[k]
218 217 return self
219 218
220 219 def __dict_invert(self, data):
221 220 """Helper function for merge.
222 221
223 222 Takes a dictionary whose values are lists and returns a dict with
224 223 the elements of each list as keys and the original keys as values.
225 224 """
226 225 outdict = {}
227 226 for k,lst in data.items():
228 227 if isinstance(lst, str):
229 228 lst = lst.split()
230 229 for entry in lst:
231 230 outdict[entry] = k
232 231 return outdict
233 232
234 233 def dict(self):
235 234 return self
236 235
237 236 def copy(self):
238 237 """Return a copy as a Struct.
239 238
240 239 Examples
241 240 --------
242 241
243 242 >>> s = Struct(a=10,b=30)
244 243 >>> s2 = s.copy()
245 244 >>> s2
246 245 {'a': 10, 'b': 30}
247 246 >>> type(s2).__name__
248 247 'Struct'
249 248 """
250 249 return Struct(dict.copy(self))
251 250
252 251 def hasattr(self, key):
253 252 """hasattr function available as a method.
254 253
255 254 Implemented like has_key.
256 255
257 256 Examples
258 257 --------
259 258
260 259 >>> s = Struct(a=10)
261 260 >>> s.hasattr('a')
262 261 True
263 262 >>> s.hasattr('b')
264 263 False
265 264 >>> s.hasattr('get')
266 265 False
267 266 """
268 267 return self.has_key(key)
269 268
270 269 def allow_new_attr(self, allow = True):
271 270 """Set whether new attributes can be created in this Struct.
272 271
273 272 This can be used to catch typos by verifying that the attribute the user
274 273 tries to change already exists in this Struct.
275 274 """
276 275 object.__setattr__(self, '_allownew', allow)
277 276
278 277 def merge(self, __loc_data__=None, __conflict_solve=None, **kw):
279 278 """Merge two Structs with customizable conflict resolution.
280 279
281 280 This is similar to :meth:`update`, but much more flexible. First, a
282 281 dict is made from data+key=value pairs. When merging this dict with
283 282 the Struct S, the optional dictionary 'conflict' is used to decide
284 283 what to do.
285 284
286 285 If conflict is not given, the default behavior is to preserve any keys
287 286 with their current value (the opposite of the :meth:`update` method's
288 287 behavior).
289 288
290 289 Parameters
291 290 ----------
292 291 __loc_data : dict, Struct
293 292 The data to merge into self
294 293 __conflict_solve : dict
295 294 The conflict policy dict. The keys are binary functions used to
296 295 resolve the conflict and the values are lists of strings naming
297 296 the keys the conflict resolution function applies to. Instead of
298 297 a list of strings a space separated string can be used, like
299 298 'a b c'.
300 299 kw : dict
301 300 Additional key, value pairs to merge in
302 301
303 302 Notes
304 303 -----
305 304
306 305 The `__conflict_solve` dict is a dictionary of binary functions which will be used to
307 306 solve key conflicts. Here is an example::
308 307
309 308 __conflict_solve = dict(
310 309 func1=['a','b','c'],
311 310 func2=['d','e']
312 311 )
313 312
314 313 In this case, the function :func:`func1` will be used to resolve
315 314 keys 'a', 'b' and 'c' and the function :func:`func2` will be used for
316 315 keys 'd' and 'e'. This could also be written as::
317 316
318 317 __conflict_solve = dict(func1='a b c',func2='d e')
319 318
320 319 These functions will be called for each key they apply to with the
321 320 form::
322 321
323 322 func1(self['a'], other['a'])
324 323
325 324 The return value is used as the final merged value.
326 325
327 326 As a convenience, merge() provides five (the most commonly needed)
328 327 pre-defined policies: preserve, update, add, add_flip and add_s. The
329 328 easiest explanation is their implementation::
330 329
331 330 preserve = lambda old,new: old
332 331 update = lambda old,new: new
333 332 add = lambda old,new: old + new
334 333 add_flip = lambda old,new: new + old # note change of order!
335 334 add_s = lambda old,new: old + ' ' + new # only for str!
336 335
337 336 You can use those five words (as strings) as keys instead
338 337 of defining them as functions, and the merge method will substitute
339 338 the appropriate functions for you.
340 339
341 340 For more complicated conflict resolution policies, you still need to
342 341 construct your own functions.
343 342
344 343 Examples
345 344 --------
346 345
347 346 This show the default policy:
348 347
349 348 >>> s = Struct(a=10,b=30)
350 349 >>> s2 = Struct(a=20,c=40)
351 350 >>> s.merge(s2)
352 351 >>> s
353 352 {'a': 10, 'c': 40, 'b': 30}
354 353
355 354 Now, show how to specify a conflict dict:
356 355
357 356 >>> s = Struct(a=10,b=30)
358 357 >>> s2 = Struct(a=20,b=40)
359 358 >>> conflict = {'update':'a','add':'b'}
360 359 >>> s.merge(s2,conflict)
361 360 >>> s
362 361 {'a': 20, 'b': 70}
363 362 """
364 363
365 364 data_dict = dict(__loc_data__,**kw)
366 365
367 366 # policies for conflict resolution: two argument functions which return
368 367 # the value that will go in the new struct
369 368 preserve = lambda old,new: old
370 369 update = lambda old,new: new
371 370 add = lambda old,new: old + new
372 371 add_flip = lambda old,new: new + old # note change of order!
373 372 add_s = lambda old,new: old + ' ' + new
374 373
375 374 # default policy is to keep current keys when there's a conflict
376 375 conflict_solve = list2dict2(self.keys(), default = preserve)
377 376
378 377 # the conflict_solve dictionary is given by the user 'inverted': we
379 378 # need a name-function mapping, it comes as a function -> names
380 379 # dict. Make a local copy (b/c we'll make changes), replace user
381 380 # strings for the five builtin policies and invert it.
382 381 if __conflict_solve:
383 382 inv_conflict_solve_user = __conflict_solve.copy()
384 383 for name, func in [('preserve',preserve), ('update',update),
385 384 ('add',add), ('add_flip',add_flip),
386 385 ('add_s',add_s)]:
387 386 if name in inv_conflict_solve_user.keys():
388 387 inv_conflict_solve_user[func] = inv_conflict_solve_user[name]
389 388 del inv_conflict_solve_user[name]
390 389 conflict_solve.update(self.__dict_invert(inv_conflict_solve_user))
391 390 for key in data_dict:
392 391 if key not in self:
393 392 self[key] = data_dict[key]
394 393 else:
395 394 self[key] = conflict_solve[key](self[key],data_dict[key])
396 395
@@ -1,143 +1,142 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 The IPython Core Notification Center.
5 4
6 5 See docs/source/development/notification_blueprint.txt for an overview of the
7 6 notification module.
8 7
9 8 Authors:
10 9
11 10 * Barry Wark
12 11 * Brian Granger
13 12 """
14 13
15 14 #-----------------------------------------------------------------------------
16 15 # Copyright (C) 2008-2009 The IPython Development Team
17 16 #
18 17 # Distributed under the terms of the BSD License. The full license is in
19 18 # the file COPYING, distributed as part of this software.
20 19 #-----------------------------------------------------------------------------
21 20
22 21 #-----------------------------------------------------------------------------
23 22 # Code
24 23 #-----------------------------------------------------------------------------
25 24
26 25
27 26 class NotificationError(Exception):
28 27 pass
29 28
30 29
31 30 class NotificationCenter(object):
32 31 """Synchronous notification center.
33 32
34 33 Examples
35 34 --------
36 35 Here is a simple example of how to use this::
37 36
38 37 import IPython.utils.notification as notification
39 38 def callback(ntype, theSender, args={}):
40 39 print ntype,theSender,args
41 40
42 41 notification.shared_center.add_observer(callback, 'NOTIFICATION_TYPE', None)
43 42 notification.shared_center.post_notification('NOTIFICATION_TYPE', object()) # doctest:+ELLIPSIS
44 43 NOTIFICATION_TYPE ...
45 44 """
46 45 def __init__(self):
47 46 super(NotificationCenter, self).__init__()
48 47 self._init_observers()
49 48
50 49 def _init_observers(self):
51 50 """Initialize observer storage"""
52 51
53 52 self.registered_types = set() #set of types that are observed
54 53 self.registered_senders = set() #set of senders that are observed
55 54 self.observers = {} #map (type,sender) => callback (callable)
56 55
57 56 def post_notification(self, ntype, sender, *args, **kwargs):
58 57 """Post notification to all registered observers.
59 58
60 59 The registered callback will be called as::
61 60
62 61 callback(ntype, sender, *args, **kwargs)
63 62
64 63 Parameters
65 64 ----------
66 65 ntype : hashable
67 66 The notification type.
68 67 sender : hashable
69 68 The object sending the notification.
70 69 *args : tuple
71 70 The positional arguments to be passed to the callback.
72 71 **kwargs : dict
73 72 The keyword argument to be passed to the callback.
74 73
75 74 Notes
76 75 -----
77 76 * If no registered observers, performance is O(1).
78 77 * Notification order is undefined.
79 78 * Notifications are posted synchronously.
80 79 """
81 80
82 81 if(ntype==None or sender==None):
83 82 raise NotificationError(
84 83 "Notification type and sender are required.")
85 84
86 85 # If there are no registered observers for the type/sender pair
87 86 if((ntype not in self.registered_types and
88 87 None not in self.registered_types) or
89 88 (sender not in self.registered_senders and
90 89 None not in self.registered_senders)):
91 90 return
92 91
93 92 for o in self._observers_for_notification(ntype, sender):
94 93 o(ntype, sender, *args, **kwargs)
95 94
96 95 def _observers_for_notification(self, ntype, sender):
97 96 """Find all registered observers that should recieve notification"""
98 97
99 98 keys = (
100 99 (ntype,sender),
101 100 (ntype, None),
102 101 (None, sender),
103 102 (None,None)
104 103 )
105 104
106 105 obs = set()
107 106 for k in keys:
108 107 obs.update(self.observers.get(k, set()))
109 108
110 109 return obs
111 110
112 111 def add_observer(self, callback, ntype, sender):
113 112 """Add an observer callback to this notification center.
114 113
115 114 The given callback will be called upon posting of notifications of
116 115 the given type/sender and will receive any additional arguments passed
117 116 to post_notification.
118 117
119 118 Parameters
120 119 ----------
121 120 callback : callable
122 121 The callable that will be called by :meth:`post_notification`
123 122 as ``callback(ntype, sender, *args, **kwargs)``
124 123 ntype : hashable
125 124 The notification type. If None, all notifications from sender
126 125 will be posted.
127 126 sender : hashable
128 127 The notification sender. If None, all notifications of ntype
129 128 will be posted.
130 129 """
131 130 assert(callback != None)
132 131 self.registered_types.add(ntype)
133 132 self.registered_senders.add(sender)
134 133 self.observers.setdefault((ntype,sender), set()).add(callback)
135 134
136 135 def remove_all_observers(self):
137 136 """Removes all observers from this notification center"""
138 137
139 138 self._init_observers()
140 139
141 140
142 141
143 142 shared_center = NotificationCenter()
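Building on the class above, here is a small sketch of sender-specific filtering; the names and payload are illustrative only::

    center = NotificationCenter()

    def on_save(ntype, sender, **kwargs):
        print("%s from %r with %r" % (ntype, sender, kwargs))

    editor = object()
    # deliver only 'document_saved' notifications coming from this particular sender
    center.add_observer(on_save, 'document_saved', editor)
    center.post_notification('document_saved', editor, path='/tmp/notes.txt')
    center.post_notification('document_saved', object())  # different sender: silently ignored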
@@ -1,70 +1,69 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Context managers for adding things to sys.path temporarily.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 """
10 9
11 10 #-----------------------------------------------------------------------------
12 11 # Copyright (C) 2008-2009 The IPython Development Team
13 12 #
14 13 # Distributed under the terms of the BSD License. The full license is in
15 14 # the file COPYING, distributed as part of this software.
16 15 #-----------------------------------------------------------------------------
17 16
18 17 #-----------------------------------------------------------------------------
19 18 # Imports
20 19 #-----------------------------------------------------------------------------
21 20
22 21 import sys
23 22
24 23 #-----------------------------------------------------------------------------
25 24 # Code
26 25 #-----------------------------------------------------------------------------
27 26
28 27 class appended_to_syspath(object):
29 28 """A context for appending a directory to sys.path for a second."""
30 29
31 30 def __init__(self, dir):
32 31 self.dir = dir
33 32
34 33 def __enter__(self):
35 34 if self.dir not in sys.path:
36 35 sys.path.append(self.dir)
37 36 self.added = True
38 37 else:
39 38 self.added = False
40 39
41 40 def __exit__(self, type, value, traceback):
42 41 if self.added:
43 42 try:
44 43 sys.path.remove(self.dir)
45 44 except ValueError:
46 45 pass
47 46 # Returning False causes any exceptions to be re-raised.
48 47 return False
49 48
50 49 class prepended_to_syspath(object):
51 50 """A context for prepending a directory to sys.path for a second."""
52 51
53 52 def __init__(self, dir):
54 53 self.dir = dir
55 54
56 55 def __enter__(self):
57 56 if self.dir not in sys.path:
58 57 sys.path.insert(0,self.dir)
59 58 self.added = True
60 59 else:
61 60 self.added = False
62 61
63 62 def __exit__(self, type, value, traceback):
64 63 if self.added:
65 64 try:
66 65 sys.path.remove(self.dir)
67 66 except ValueError:
68 67 pass
69 68 # Returning False causes any exceptions to be re-raised.
70 69 return False
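A brief usage sketch for the two context managers above; the directory name is hypothetical and the import assumes the module is reachable as `IPython.utils.syspathcontext`::

    import sys
    from IPython.utils.syspathcontext import prepended_to_syspath

    with prepended_to_syspath('/tmp/my_plugins'):
        # the directory is importable only while the block is active
        assert sys.path[0] == '/tmp/my_plugins'
    assert '/tmp/my_plugins' not in sys.path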
@@ -1,847 +1,846 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 Tests for IPython.utils.traitlets.
5 4
6 5 Authors:
7 6
8 7 * Brian Granger
9 8 * Enthought, Inc. Some of the code in this file comes from enthought.traits
10 9 and is licensed under the BSD license. Also, many of the ideas also come
11 10 from enthought.traits even though our implementation is very different.
12 11 """
13 12
14 13 #-----------------------------------------------------------------------------
15 14 # Copyright (C) 2008-2009 The IPython Development Team
16 15 #
17 16 # Distributed under the terms of the BSD License. The full license is in
18 17 # the file COPYING, distributed as part of this software.
19 18 #-----------------------------------------------------------------------------
20 19
21 20 #-----------------------------------------------------------------------------
22 21 # Imports
23 22 #-----------------------------------------------------------------------------
24 23
25 24 import sys
26 25 from unittest import TestCase
27 26
28 27 from IPython.utils.traitlets import (
29 28 HasTraits, MetaHasTraits, TraitType, Any, CBytes,
30 29 Int, Long, Float, Complex, Bytes, Unicode, TraitError,
31 30 Undefined, Type, This, Instance, TCPAddress, List, Tuple,
32 31 ObjectName, DottedObjectName
33 32 )
34 33
35 34
36 35 #-----------------------------------------------------------------------------
37 36 # Helper classes for testing
38 37 #-----------------------------------------------------------------------------
39 38
40 39
41 40 class HasTraitsStub(HasTraits):
42 41
43 42 def _notify_trait(self, name, old, new):
44 43 self._notify_name = name
45 44 self._notify_old = old
46 45 self._notify_new = new
47 46
48 47
49 48 #-----------------------------------------------------------------------------
50 49 # Test classes
51 50 #-----------------------------------------------------------------------------
52 51
53 52
54 53 class TestTraitType(TestCase):
55 54
56 55 def test_get_undefined(self):
57 56 class A(HasTraits):
58 57 a = TraitType
59 58 a = A()
60 59 self.assertEquals(a.a, Undefined)
61 60
62 61 def test_set(self):
63 62 class A(HasTraitsStub):
64 63 a = TraitType
65 64
66 65 a = A()
67 66 a.a = 10
68 67 self.assertEquals(a.a, 10)
69 68 self.assertEquals(a._notify_name, 'a')
70 69 self.assertEquals(a._notify_old, Undefined)
71 70 self.assertEquals(a._notify_new, 10)
72 71
73 72 def test_validate(self):
74 73 class MyTT(TraitType):
75 74 def validate(self, inst, value):
76 75 return -1
77 76 class A(HasTraitsStub):
78 77 tt = MyTT
79 78
80 79 a = A()
81 80 a.tt = 10
82 81 self.assertEquals(a.tt, -1)
83 82
84 83 def test_default_validate(self):
85 84 class MyIntTT(TraitType):
86 85 def validate(self, obj, value):
87 86 if isinstance(value, int):
88 87 return value
89 88 self.error(obj, value)
90 89 class A(HasTraits):
91 90 tt = MyIntTT(10)
92 91 a = A()
93 92 self.assertEquals(a.tt, 10)
94 93
95 94 # Defaults are validated when the HasTraits is instantiated
96 95 class B(HasTraits):
97 96 tt = MyIntTT('bad default')
98 97 self.assertRaises(TraitError, B)
99 98
100 99 def test_is_valid_for(self):
101 100 class MyTT(TraitType):
102 101 def is_valid_for(self, value):
103 102 return True
104 103 class A(HasTraits):
105 104 tt = MyTT
106 105
107 106 a = A()
108 107 a.tt = 10
109 108 self.assertEquals(a.tt, 10)
110 109
111 110 def test_value_for(self):
112 111 class MyTT(TraitType):
113 112 def value_for(self, value):
114 113 return 20
115 114 class A(HasTraits):
116 115 tt = MyTT
117 116
118 117 a = A()
119 118 a.tt = 10
120 119 self.assertEquals(a.tt, 20)
121 120
122 121 def test_info(self):
123 122 class A(HasTraits):
124 123 tt = TraitType
125 124 a = A()
126 125 self.assertEquals(A.tt.info(), 'any value')
127 126
128 127 def test_error(self):
129 128 class A(HasTraits):
130 129 tt = TraitType
131 130 a = A()
132 131 self.assertRaises(TraitError, A.tt.error, a, 10)
133 132
134 133 def test_dynamic_initializer(self):
135 134 class A(HasTraits):
136 135 x = Int(10)
137 136 def _x_default(self):
138 137 return 11
139 138 class B(A):
140 139 x = Int(20)
141 140 class C(A):
142 141 def _x_default(self):
143 142 return 21
144 143
145 144 a = A()
146 145 self.assertEquals(a._trait_values, {})
147 146 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
148 147 self.assertEquals(a.x, 11)
149 148 self.assertEquals(a._trait_values, {'x': 11})
150 149 b = B()
151 150 self.assertEquals(b._trait_values, {'x': 20})
152 151 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
153 152 self.assertEquals(b.x, 20)
154 153 c = C()
155 154 self.assertEquals(c._trait_values, {})
156 155 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
157 156 self.assertEquals(c.x, 21)
158 157 self.assertEquals(c._trait_values, {'x': 21})
159 158 # Ensure that the base class remains unmolested when the _default
160 159 # initializer gets overridden in a subclass.
161 160 a = A()
162 161 c = C()
163 162 self.assertEquals(a._trait_values, {})
164 163 self.assertEquals(a._trait_dyn_inits.keys(), ['x'])
165 164 self.assertEquals(a.x, 11)
166 165 self.assertEquals(a._trait_values, {'x': 11})
167 166
168 167
169 168
170 169 class TestHasTraitsMeta(TestCase):
171 170
172 171 def test_metaclass(self):
173 172 self.assertEquals(type(HasTraits), MetaHasTraits)
174 173
175 174 class A(HasTraits):
176 175 a = Int
177 176
178 177 a = A()
179 178 self.assertEquals(type(a.__class__), MetaHasTraits)
180 179 self.assertEquals(a.a,0)
181 180 a.a = 10
182 181 self.assertEquals(a.a,10)
183 182
184 183 class B(HasTraits):
185 184 b = Int()
186 185
187 186 b = B()
188 187 self.assertEquals(b.b,0)
189 188 b.b = 10
190 189 self.assertEquals(b.b,10)
191 190
192 191 class C(HasTraits):
193 192 c = Int(30)
194 193
195 194 c = C()
196 195 self.assertEquals(c.c,30)
197 196 c.c = 10
198 197 self.assertEquals(c.c,10)
199 198
200 199 def test_this_class(self):
201 200 class A(HasTraits):
202 201 t = This()
203 202 tt = This()
204 203 class B(A):
205 204 tt = This()
206 205 ttt = This()
207 206 self.assertEquals(A.t.this_class, A)
208 207 self.assertEquals(B.t.this_class, A)
209 208 self.assertEquals(B.tt.this_class, B)
210 209 self.assertEquals(B.ttt.this_class, B)
211 210
212 211 class TestHasTraitsNotify(TestCase):
213 212
214 213 def setUp(self):
215 214 self._notify1 = []
216 215 self._notify2 = []
217 216
218 217 def notify1(self, name, old, new):
219 218 self._notify1.append((name, old, new))
220 219
221 220 def notify2(self, name, old, new):
222 221 self._notify2.append((name, old, new))
223 222
224 223 def test_notify_all(self):
225 224
226 225 class A(HasTraits):
227 226 a = Int
228 227 b = Float
229 228
230 229 a = A()
231 230 a.on_trait_change(self.notify1)
232 231 a.a = 0
233 232 self.assertEquals(len(self._notify1),0)
234 233 a.b = 0.0
235 234 self.assertEquals(len(self._notify1),0)
236 235 a.a = 10
237 236 self.assert_(('a',0,10) in self._notify1)
238 237 a.b = 10.0
239 238 self.assert_(('b',0.0,10.0) in self._notify1)
240 239 self.assertRaises(TraitError,setattr,a,'a','bad string')
241 240 self.assertRaises(TraitError,setattr,a,'b','bad string')
242 241 self._notify1 = []
243 242 a.on_trait_change(self.notify1,remove=True)
244 243 a.a = 20
245 244 a.b = 20.0
246 245 self.assertEquals(len(self._notify1),0)
247 246
248 247 def test_notify_one(self):
249 248
250 249 class A(HasTraits):
251 250 a = Int
252 251 b = Float
253 252
254 253 a = A()
255 254 a.on_trait_change(self.notify1, 'a')
256 255 a.a = 0
257 256 self.assertEquals(len(self._notify1),0)
258 257 a.a = 10
259 258 self.assert_(('a',0,10) in self._notify1)
260 259 self.assertRaises(TraitError,setattr,a,'a','bad string')
261 260
262 261 def test_subclass(self):
263 262
264 263 class A(HasTraits):
265 264 a = Int
266 265
267 266 class B(A):
268 267 b = Float
269 268
270 269 b = B()
271 270 self.assertEquals(b.a,0)
272 271 self.assertEquals(b.b,0.0)
273 272 b.a = 100
274 273 b.b = 100.0
275 274 self.assertEquals(b.a,100)
276 275 self.assertEquals(b.b,100.0)
277 276
278 277 def test_notify_subclass(self):
279 278
280 279 class A(HasTraits):
281 280 a = Int
282 281
283 282 class B(A):
284 283 b = Float
285 284
286 285 b = B()
287 286 b.on_trait_change(self.notify1, 'a')
288 287 b.on_trait_change(self.notify2, 'b')
289 288 b.a = 0
290 289 b.b = 0.0
291 290 self.assertEquals(len(self._notify1),0)
292 291 self.assertEquals(len(self._notify2),0)
293 292 b.a = 10
294 293 b.b = 10.0
295 294 self.assert_(('a',0,10) in self._notify1)
296 295 self.assert_(('b',0.0,10.0) in self._notify2)
297 296
298 297 def test_static_notify(self):
299 298
300 299 class A(HasTraits):
301 300 a = Int
302 301 _notify1 = []
303 302 def _a_changed(self, name, old, new):
304 303 self._notify1.append((name, old, new))
305 304
306 305 a = A()
307 306 a.a = 0
308 307 # This is broken!!!
309 308 self.assertEquals(len(a._notify1),0)
310 309 a.a = 10
311 310 self.assert_(('a',0,10) in a._notify1)
312 311
313 312 class B(A):
314 313 b = Float
315 314 _notify2 = []
316 315 def _b_changed(self, name, old, new):
317 316 self._notify2.append((name, old, new))
318 317
319 318 b = B()
320 319 b.a = 10
321 320 b.b = 10.0
322 321 self.assert_(('a',0,10) in b._notify1)
323 322 self.assert_(('b',0.0,10.0) in b._notify2)
324 323
325 324 def test_notify_args(self):
326 325
327 326 def callback0():
328 327 self.cb = ()
329 328 def callback1(name):
330 329 self.cb = (name,)
331 330 def callback2(name, new):
332 331 self.cb = (name, new)
333 332 def callback3(name, old, new):
334 333 self.cb = (name, old, new)
335 334
336 335 class A(HasTraits):
337 336 a = Int
338 337
339 338 a = A()
340 339 a.on_trait_change(callback0, 'a')
341 340 a.a = 10
342 341 self.assertEquals(self.cb,())
343 342 a.on_trait_change(callback0, 'a', remove=True)
344 343
345 344 a.on_trait_change(callback1, 'a')
346 345 a.a = 100
347 346 self.assertEquals(self.cb,('a',))
348 347 a.on_trait_change(callback1, 'a', remove=True)
349 348
350 349 a.on_trait_change(callback2, 'a')
351 350 a.a = 1000
352 351 self.assertEquals(self.cb,('a',1000))
353 352 a.on_trait_change(callback2, 'a', remove=True)
354 353
355 354 a.on_trait_change(callback3, 'a')
356 355 a.a = 10000
357 356 self.assertEquals(self.cb,('a',1000,10000))
358 357 a.on_trait_change(callback3, 'a', remove=True)
359 358
360 359 self.assertEquals(len(a._trait_notifiers['a']),0)
361 360
362 361
363 362 class TestHasTraits(TestCase):
364 363
365 364 def test_trait_names(self):
366 365 class A(HasTraits):
367 366 i = Int
368 367 f = Float
369 368 a = A()
370 369 self.assertEquals(a.trait_names(),['i','f'])
371 370 self.assertEquals(A.class_trait_names(),['i','f'])
372 371
373 372 def test_trait_metadata(self):
374 373 class A(HasTraits):
375 374 i = Int(config_key='MY_VALUE')
376 375 a = A()
377 376 self.assertEquals(a.trait_metadata('i','config_key'), 'MY_VALUE')
378 377
379 378 def test_traits(self):
380 379 class A(HasTraits):
381 380 i = Int
382 381 f = Float
383 382 a = A()
384 383 self.assertEquals(a.traits(), dict(i=A.i, f=A.f))
385 384 self.assertEquals(A.class_traits(), dict(i=A.i, f=A.f))
386 385
387 386 def test_traits_metadata(self):
388 387 class A(HasTraits):
389 388 i = Int(config_key='VALUE1', other_thing='VALUE2')
390 389 f = Float(config_key='VALUE3', other_thing='VALUE2')
391 390 j = Int(0)
392 391 a = A()
393 392 self.assertEquals(a.traits(), dict(i=A.i, f=A.f, j=A.j))
394 393 traits = a.traits(config_key='VALUE1', other_thing='VALUE2')
395 394 self.assertEquals(traits, dict(i=A.i))
396 395
397 396 # This passes, but it shouldn't because I am replicating a bug in
398 397 # traits.
399 398 traits = a.traits(config_key=lambda v: True)
400 399 self.assertEquals(traits, dict(i=A.i, f=A.f, j=A.j))
401 400
402 401 def test_init(self):
403 402 class A(HasTraits):
404 403 i = Int()
405 404 x = Float()
406 405 a = A(i=1, x=10.0)
407 406 self.assertEquals(a.i, 1)
408 407 self.assertEquals(a.x, 10.0)
409 408
410 409 #-----------------------------------------------------------------------------
411 410 # Tests for specific trait types
412 411 #-----------------------------------------------------------------------------
413 412
414 413
415 414 class TestType(TestCase):
416 415
417 416 def test_default(self):
418 417
419 418 class B(object): pass
420 419 class A(HasTraits):
421 420 klass = Type
422 421
423 422 a = A()
424 423 self.assertEquals(a.klass, None)
425 424
426 425 a.klass = B
427 426 self.assertEquals(a.klass, B)
428 427 self.assertRaises(TraitError, setattr, a, 'klass', 10)
429 428
430 429 def test_value(self):
431 430
432 431 class B(object): pass
433 432 class C(object): pass
434 433 class A(HasTraits):
435 434 klass = Type(B)
436 435
437 436 a = A()
438 437 self.assertEquals(a.klass, B)
439 438 self.assertRaises(TraitError, setattr, a, 'klass', C)
440 439 self.assertRaises(TraitError, setattr, a, 'klass', object)
441 440 a.klass = B
442 441
443 442 def test_allow_none(self):
444 443
445 444 class B(object): pass
446 445 class C(B): pass
447 446 class A(HasTraits):
448 447 klass = Type(B, allow_none=False)
449 448
450 449 a = A()
451 450 self.assertEquals(a.klass, B)
452 451 self.assertRaises(TraitError, setattr, a, 'klass', None)
453 452 a.klass = C
454 453 self.assertEquals(a.klass, C)
455 454
456 455 def test_validate_klass(self):
457 456
458 457 class A(HasTraits):
459 458 klass = Type('no strings allowed')
460 459
461 460 self.assertRaises(ImportError, A)
462 461
463 462 class A(HasTraits):
464 463 klass = Type('rub.adub.Duck')
465 464
466 465 self.assertRaises(ImportError, A)
467 466
468 467 def test_validate_default(self):
469 468
470 469 class B(object): pass
471 470 class A(HasTraits):
472 471 klass = Type('bad default', B)
473 472
474 473 self.assertRaises(ImportError, A)
475 474
476 475 class C(HasTraits):
477 476 klass = Type(None, B, allow_none=False)
478 477
479 478 self.assertRaises(TraitError, C)
480 479
481 480 def test_str_klass(self):
482 481
483 482 class A(HasTraits):
484 483 klass = Type('IPython.utils.ipstruct.Struct')
485 484
486 485 from IPython.utils.ipstruct import Struct
487 486 a = A()
488 487 a.klass = Struct
489 488 self.assertEquals(a.klass, Struct)
490 489
491 490 self.assertRaises(TraitError, setattr, a, 'klass', 10)
492 491
493 492 class TestInstance(TestCase):
494 493
495 494 def test_basic(self):
496 495 class Foo(object): pass
497 496 class Bar(Foo): pass
498 497 class Bah(object): pass
499 498
500 499 class A(HasTraits):
501 500 inst = Instance(Foo)
502 501
503 502 a = A()
504 503 self.assert_(a.inst is None)
505 504 a.inst = Foo()
506 505 self.assert_(isinstance(a.inst, Foo))
507 506 a.inst = Bar()
508 507 self.assert_(isinstance(a.inst, Foo))
509 508 self.assertRaises(TraitError, setattr, a, 'inst', Foo)
510 509 self.assertRaises(TraitError, setattr, a, 'inst', Bar)
511 510 self.assertRaises(TraitError, setattr, a, 'inst', Bah())
512 511
513 512 def test_unique_default_value(self):
514 513 class Foo(object): pass
515 514 class A(HasTraits):
516 515 inst = Instance(Foo,(),{})
517 516
518 517 a = A()
519 518 b = A()
520 519 self.assert_(a.inst is not b.inst)
521 520
522 521 def test_args_kw(self):
523 522 class Foo(object):
524 523 def __init__(self, c): self.c = c
525 524 class Bar(object): pass
526 525 class Bah(object):
527 526 def __init__(self, c, d):
528 527 self.c = c; self.d = d
529 528
530 529 class A(HasTraits):
531 530 inst = Instance(Foo, (10,))
532 531 a = A()
533 532 self.assertEquals(a.inst.c, 10)
534 533
535 534 class B(HasTraits):
536 535 inst = Instance(Bah, args=(10,), kw=dict(d=20))
537 536 b = B()
538 537 self.assertEquals(b.inst.c, 10)
539 538 self.assertEquals(b.inst.d, 20)
540 539
541 540 class C(HasTraits):
542 541 inst = Instance(Foo)
543 542 c = C()
544 543 self.assert_(c.inst is None)
545 544
546 545 def test_bad_default(self):
547 546 class Foo(object): pass
548 547
549 548 class A(HasTraits):
550 549 inst = Instance(Foo, allow_none=False)
551 550
552 551 self.assertRaises(TraitError, A)
553 552
554 553 def test_instance(self):
555 554 class Foo(object): pass
556 555
557 556 def inner():
558 557 class A(HasTraits):
559 558 inst = Instance(Foo())
560 559
561 560 self.assertRaises(TraitError, inner)
562 561
563 562
564 563 class TestThis(TestCase):
565 564
566 565 def test_this_class(self):
567 566 class Foo(HasTraits):
568 567 this = This
569 568
570 569 f = Foo()
571 570 self.assertEquals(f.this, None)
572 571 g = Foo()
573 572 f.this = g
574 573 self.assertEquals(f.this, g)
575 574 self.assertRaises(TraitError, setattr, f, 'this', 10)
576 575
577 576 def test_this_inst(self):
578 577 class Foo(HasTraits):
579 578 this = This()
580 579
581 580 f = Foo()
582 581 f.this = Foo()
583 582 self.assert_(isinstance(f.this, Foo))
584 583
585 584 def test_subclass(self):
586 585 class Foo(HasTraits):
587 586 t = This()
588 587 class Bar(Foo):
589 588 pass
590 589 f = Foo()
591 590 b = Bar()
592 591 f.t = b
593 592 b.t = f
594 593 self.assertEquals(f.t, b)
595 594 self.assertEquals(b.t, f)
596 595
597 596 def test_subclass_override(self):
598 597 class Foo(HasTraits):
599 598 t = This()
600 599 class Bar(Foo):
601 600 t = This()
602 601 f = Foo()
603 602 b = Bar()
604 603 f.t = b
605 604 self.assertEquals(f.t, b)
606 605 self.assertRaises(TraitError, setattr, b, 't', f)
607 606
608 607 class TraitTestBase(TestCase):
609 608 """A best testing class for basic trait types."""
610 609
611 610 def assign(self, value):
612 611 self.obj.value = value
613 612
614 613 def coerce(self, value):
615 614 return value
616 615
617 616 def test_good_values(self):
618 617 if hasattr(self, '_good_values'):
619 618 for value in self._good_values:
620 619 self.assign(value)
621 620 self.assertEquals(self.obj.value, self.coerce(value))
622 621
623 622 def test_bad_values(self):
624 623 if hasattr(self, '_bad_values'):
625 624 for value in self._bad_values:
626 625 self.assertRaises(TraitError, self.assign, value)
627 626
628 627 def test_default_value(self):
629 628 if hasattr(self, '_default_value'):
630 629 self.assertEquals(self._default_value, self.obj.value)
631 630
632 631
633 632 class AnyTrait(HasTraits):
634 633
635 634 value = Any
636 635
637 636 class AnyTraitTest(TraitTestBase):
638 637
639 638 obj = AnyTrait()
640 639
641 640 _default_value = None
642 641 _good_values = [10.0, 'ten', u'ten', [10], {'ten': 10},(10,), None, 1j]
643 642 _bad_values = []
644 643
645 644
646 645 class IntTrait(HasTraits):
647 646
648 647 value = Int(99)
649 648
650 649 class TestInt(TraitTestBase):
651 650
652 651 obj = IntTrait()
653 652 _default_value = 99
654 653 _good_values = [10, -10]
655 654 _bad_values = ['ten', u'ten', [10], {'ten': 10},(10,), None, 1j, 10L,
656 655 -10L, 10.1, -10.1, '10L', '-10L', '10.1', '-10.1', u'10L',
657 656 u'-10L', u'10.1', u'-10.1', '10', '-10', u'10', u'-10']
658 657
659 658
660 659 class LongTrait(HasTraits):
661 660
662 661 value = Long(99L)
663 662
664 663 class TestLong(TraitTestBase):
665 664
666 665 obj = LongTrait()
667 666
668 667 _default_value = 99L
669 668 _good_values = [10, -10, 10L, -10L]
670 669 _bad_values = ['ten', u'ten', [10], [10l], {'ten': 10},(10,),(10L,),
671 670 None, 1j, 10.1, -10.1, '10', '-10', '10L', '-10L', '10.1',
672 671 '-10.1', u'10', u'-10', u'10L', u'-10L', u'10.1',
673 672 u'-10.1']
674 673
675 674
676 675 class FloatTrait(HasTraits):
677 676
678 677 value = Float(99.0)
679 678
680 679 class TestFloat(TraitTestBase):
681 680
682 681 obj = FloatTrait()
683 682
684 683 _default_value = 99.0
685 684 _good_values = [10, -10, 10.1, -10.1]
686 685 _bad_values = [10L, -10L, 'ten', u'ten', [10], {'ten': 10},(10,), None,
687 686 1j, '10', '-10', '10L', '-10L', '10.1', '-10.1', u'10',
688 687 u'-10', u'10L', u'-10L', u'10.1', u'-10.1']
689 688
690 689
691 690 class ComplexTrait(HasTraits):
692 691
693 692 value = Complex(99.0-99.0j)
694 693
695 694 class TestComplex(TraitTestBase):
696 695
697 696 obj = ComplexTrait()
698 697
699 698 _default_value = 99.0-99.0j
700 699 _good_values = [10, -10, 10.1, -10.1, 10j, 10+10j, 10-10j,
701 700 10.1j, 10.1+10.1j, 10.1-10.1j]
702 701 _bad_values = [10L, -10L, u'10L', u'-10L', 'ten', [10], {'ten': 10},(10,), None]
703 702
704 703
705 704 class BytesTrait(HasTraits):
706 705
707 706 value = Bytes(b'string')
708 707
709 708 class TestBytes(TraitTestBase):
710 709
711 710 obj = BytesTrait()
712 711
713 712 _default_value = b'string'
714 713 _good_values = [b'10', b'-10', b'10L',
715 714 b'-10L', b'10.1', b'-10.1', b'string']
716 715 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j, [10],
717 716 ['ten'],{'ten': 10},(10,), None, u'string']
718 717
719 718
720 719 class UnicodeTrait(HasTraits):
721 720
722 721 value = Unicode(u'unicode')
723 722
724 723 class TestUnicode(TraitTestBase):
725 724
726 725 obj = UnicodeTrait()
727 726
728 727 _default_value = u'unicode'
729 728 _good_values = ['10', '-10', '10L', '-10L', '10.1',
730 729 '-10.1', '', u'', 'string', u'string', u"€"]
731 730 _bad_values = [10, -10, 10L, -10L, 10.1, -10.1, 1j,
732 731 [10], ['ten'], [u'ten'], {'ten': 10},(10,), None]
733 732
734 733
735 734 class ObjectNameTrait(HasTraits):
736 735 value = ObjectName("abc")
737 736
738 737 class TestObjectName(TraitTestBase):
739 738 obj = ObjectNameTrait()
740 739
741 740 _default_value = "abc"
742 741 _good_values = ["a", "gh", "g9", "g_", "_G", u"a345_"]
743 742 _bad_values = [1, "", u"€", "9g", "!", "#abc", "aj@", "a.b", "a()", "a[0]",
744 743 object(), object]
745 744 if sys.version_info[0] < 3:
746 745 _bad_values.append(u"þ")
747 746 else:
748 747 _good_values.append(u"þ") # þ=1 is valid in Python 3 (PEP 3131).
749 748
750 749
751 750 class DottedObjectNameTrait(HasTraits):
752 751 value = DottedObjectName("a.b")
753 752
754 753 class TestDottedObjectName(TraitTestBase):
755 754 obj = DottedObjectNameTrait()
756 755
757 756 _default_value = "a.b"
758 757 _good_values = ["A", "y.t", "y765.__repr__", "os.path.join", u"os.path.join"]
759 758 _bad_values = [1, u"abc.€", "_.@", ".", ".abc", "abc.", ".abc."]
760 759 if sys.version_info[0] < 3:
761 760 _bad_values.append(u"t.þ")
762 761 else:
763 762 _good_values.append(u"t.þ")
764 763
765 764
766 765 class TCPAddressTrait(HasTraits):
767 766
768 767 value = TCPAddress()
769 768
770 769 class TestTCPAddress(TraitTestBase):
771 770
772 771 obj = TCPAddressTrait()
773 772
774 773 _default_value = ('127.0.0.1',0)
775 774 _good_values = [('localhost',0),('192.168.0.1',1000),('www.google.com',80)]
776 775 _bad_values = [(0,0),('localhost',10.0),('localhost',-1)]
777 776
778 777 class ListTrait(HasTraits):
779 778
780 779 value = List(Int)
781 780
782 781 class TestList(TraitTestBase):
783 782
784 783 obj = ListTrait()
785 784
786 785 _default_value = []
787 786 _good_values = [[], [1], range(10)]
788 787 _bad_values = [10, [1,'a'], 'a', (1,2)]
789 788
790 789 class LenListTrait(HasTraits):
791 790
792 791 value = List(Int, [0], minlen=1, maxlen=2)
793 792
794 793 class TestLenList(TraitTestBase):
795 794
796 795 obj = LenListTrait()
797 796
798 797 _default_value = [0]
799 798 _good_values = [[1], range(2)]
800 799 _bad_values = [10, [1,'a'], 'a', (1,2), [], range(3)]
801 800
802 801 class TupleTrait(HasTraits):
803 802
804 803 value = Tuple(Int)
805 804
806 805 class TestTupleTrait(TraitTestBase):
807 806
808 807 obj = TupleTrait()
809 808
810 809 _default_value = None
811 810 _good_values = [(1,), None,(0,)]
812 811 _bad_values = [10, (1,2), [1],('a'), ()]
813 812
814 813 def test_invalid_args(self):
815 814 self.assertRaises(TypeError, Tuple, 5)
816 815 self.assertRaises(TypeError, Tuple, default_value='hello')
817 816 t = Tuple(Int, CBytes, default_value=(1,5))
818 817
819 818 class LooseTupleTrait(HasTraits):
820 819
821 820 value = Tuple((1,2,3))
822 821
823 822 class TestLooseTupleTrait(TraitTestBase):
824 823
825 824 obj = LooseTupleTrait()
826 825
827 826 _default_value = (1,2,3)
828 827 _good_values = [(1,), None, (0,), tuple(range(5)), tuple('hello'), ('a',5), ()]
829 828 _bad_values = [10, 'hello', [1], []]
830 829
831 830 def test_invalid_args(self):
832 831 self.assertRaises(TypeError, Tuple, 5)
833 832 self.assertRaises(TypeError, Tuple, default_value='hello')
834 833 t = Tuple(Int, CBytes, default_value=(1,5))
835 834
836 835
837 836 class MultiTupleTrait(HasTraits):
838 837
839 838 value = Tuple(Int, Bytes, default_value=[99,'bottles'])
840 839
841 840 class TestMultiTuple(TraitTestBase):
842 841
843 842 obj = MultiTupleTrait()
844 843
845 844 _default_value = (99,'bottles')
846 845 _good_values = [(1,'a'), (2,'b')]
847 846 _bad_values = ((),10, 'a', (1,'a',3), ('a',1))
@@ -1,1397 +1,1396 b''
1 #!/usr/bin/env python
2 1 # encoding: utf-8
3 2 """
4 3 A lightweight Traits like module.
5 4
6 5 This is designed to provide a lightweight, simple, pure Python version of
7 6 many of the capabilities of enthought.traits. This includes:
8 7
9 8 * Validation
10 9 * Type specification with defaults
11 10 * Static and dynamic notification
12 11 * Basic predefined types
13 12 * An API that is similar to enthought.traits
14 13
15 14 We don't support:
16 15
17 16 * Delegation
18 17 * Automatic GUI generation
19 18 * A full set of trait types. Most importantly, we don't provide container
20 19 traits (list, dict, tuple) that can trigger notifications if their
21 20 contents change.
22 21 * API compatibility with enthought.traits
23 22
24 23 There are also some important differences in our design:
25 24
26 25 * enthought.traits does not validate default values. We do.
27 26
28 27 We choose to create this module because we need these capabilities, but
29 28 we need them to be pure Python so they work in all Python implementations,
30 29 including Jython and IronPython.
31 30
32 31 Authors:
33 32
34 33 * Brian Granger
35 34 * Enthought, Inc. Some of the code in this file comes from enthought.traits
36 35 and is licensed under the BSD license. Also, many of the ideas also come
37 36 from enthought.traits even though our implementation is very different.
38 37 """
39 38
40 39 #-----------------------------------------------------------------------------
41 40 # Copyright (C) 2008-2009 The IPython Development Team
42 41 #
43 42 # Distributed under the terms of the BSD License. The full license is in
44 43 # the file COPYING, distributed as part of this software.
45 44 #-----------------------------------------------------------------------------
46 45
47 46 #-----------------------------------------------------------------------------
48 47 # Imports
49 48 #-----------------------------------------------------------------------------
50 49
51 50
52 51 import inspect
53 52 import re
54 53 import sys
55 54 import types
56 55 from types import (
57 56 InstanceType, ClassType, FunctionType,
58 57 ListType, TupleType
59 58 )
60 59 from .importstring import import_item
61 60
62 61 ClassTypes = (ClassType, type)
63 62
64 63 SequenceTypes = (ListType, TupleType, set, frozenset)
65 64
66 65 #-----------------------------------------------------------------------------
67 66 # Basic classes
68 67 #-----------------------------------------------------------------------------
69 68
70 69
71 70 class NoDefaultSpecified ( object ): pass
72 71 NoDefaultSpecified = NoDefaultSpecified()
73 72
74 73
75 74 class Undefined ( object ): pass
76 75 Undefined = Undefined()
77 76
78 77 class TraitError(Exception):
79 78 pass
80 79
81 80 #-----------------------------------------------------------------------------
82 81 # Utilities
83 82 #-----------------------------------------------------------------------------
84 83
85 84
86 85 def class_of ( object ):
87 86 """ Returns a string containing the class name of an object with the
88 87 correct indefinite article ('a' or 'an') preceding it (e.g., 'an Image',
89 88 'a PlotValue').
90 89 """
91 90 if isinstance( object, basestring ):
92 91 return add_article( object )
93 92
94 93 return add_article( object.__class__.__name__ )
95 94
96 95
97 96 def add_article ( name ):
98 97 """ Returns a string containing the correct indefinite article ('a' or 'an')
99 98 prefixed to the specified string.
100 99 """
101 100 if name[:1].lower() in 'aeiou':
102 101 return 'an ' + name
103 102
104 103 return 'a ' + name
105 104
106 105
107 106 def repr_type(obj):
108 107 """ Return a string representation of a value and its type for readable
109 108 error messages.
110 109 """
111 110 the_type = type(obj)
112 111 if the_type is InstanceType:
113 112 # Old-style class.
114 113 the_type = obj.__class__
115 114 msg = '%r %r' % (obj, the_type)
116 115 return msg
117 116
118 117
119 118 def parse_notifier_name(name):
120 119 """Convert the name argument to a list of names.
121 120
122 121 Examples
123 122 --------
124 123
125 124 >>> parse_notifier_name('a')
126 125 ['a']
127 126 >>> parse_notifier_name(['a','b'])
128 127 ['a', 'b']
129 128 >>> parse_notifier_name(None)
130 129 ['anytrait']
131 130 """
132 131 if isinstance(name, str):
133 132 return [name]
134 133 elif name is None:
135 134 return ['anytrait']
136 135 elif isinstance(name, (list, tuple)):
137 136 for n in name:
138 137 assert isinstance(n, str), "names must be strings"
139 138 return name
140 139
141 140
142 141 class _SimpleTest:
143 142 def __init__ ( self, value ): self.value = value
144 143 def __call__ ( self, test ):
145 144 return test == self.value
146 145 def __repr__(self):
147 146 return "<SimpleTest(%r)" % self.value
148 147 def __str__(self):
149 148 return self.__repr__()
150 149
151 150
152 151 def getmembers(object, predicate=None):
153 152 """A safe version of inspect.getmembers that handles missing attributes.
154 153
155 154 This is useful when there are descriptor based attributes that for
156 155 some reason raise AttributeError even though they exist. This happens
157 156 in zope.interface with the __provides__ attribute.
158 157 """
159 158 results = []
160 159 for key in dir(object):
161 160 try:
162 161 value = getattr(object, key)
163 162 except AttributeError:
164 163 pass
165 164 else:
166 165 if not predicate or predicate(value):
167 166 results.append((key, value))
168 167 results.sort()
169 168 return results
170 169
171 170
172 171 #-----------------------------------------------------------------------------
173 172 # Base TraitType for all traits
174 173 #-----------------------------------------------------------------------------
175 174
176 175
177 176 class TraitType(object):
178 177 """A base class for all trait descriptors.
179 178
180 179 Notes
181 180 -----
182 181 Our implementation of traits is based on Python's descriptor
183 182 protocol. This class is the base class for all such descriptors. The
184 183 only magic we use is a custom metaclass for the main :class:`HasTraits`
185 184 class that does the following:
186 185
187 186 1. Sets the :attr:`name` attribute of every :class:`TraitType`
188 187 instance in the class dict to the name of the attribute.
189 188 2. Sets the :attr:`this_class` attribute of every :class:`TraitType`
190 189 instance in the class dict to the *class* that declared the trait.
191 190 This is used by the :class:`This` trait to allow subclasses to
192 191 accept superclasses for :class:`This` values.
193 192 """
194 193
195 194
196 195 metadata = {}
197 196 default_value = Undefined
198 197 info_text = 'any value'
199 198
200 199 def __init__(self, default_value=NoDefaultSpecified, **metadata):
201 200 """Create a TraitType.
202 201 """
203 202 if default_value is not NoDefaultSpecified:
204 203 self.default_value = default_value
205 204
206 205 if len(metadata) > 0:
207 206 if len(self.metadata) > 0:
208 207 self._metadata = self.metadata.copy()
209 208 self._metadata.update(metadata)
210 209 else:
211 210 self._metadata = metadata
212 211 else:
213 212 self._metadata = self.metadata
214 213
215 214 self.init()
216 215
217 216 def init(self):
218 217 pass
219 218
220 219 def get_default_value(self):
221 220 """Create a new instance of the default value."""
222 221 return self.default_value
223 222
224 223 def instance_init(self, obj):
225 224 """This is called by :meth:`HasTraits.__new__` to finish init'ing.
226 225
227 226 Some stages of initialization must be delayed until the parent
228 227 :class:`HasTraits` instance has been created. This method is
229 228 called in :meth:`HasTraits.__new__` after the instance has been
230 229 created.
231 230
232 231 This method triggers the creation and validation of default values
233 232 and also things like the resolution of class names given as strings
234 233 to :class:`Type` and :class:`Instance`.
235 234
236 235 Parameters
237 236 ----------
238 237 obj : :class:`HasTraits` instance
239 238 The parent :class:`HasTraits` instance that has just been
240 239 created.
241 240 """
242 241 self.set_default_value(obj)
243 242
244 243 def set_default_value(self, obj):
245 244 """Set the default value on a per instance basis.
246 245
247 246 This method is called by :meth:`instance_init` to create and
248 247 validate the default value. The creation and validation of
249 248 default values must be delayed until the parent :class:`HasTraits`
250 249 class has been instantiated.
251 250 """
252 251 # Check for a deferred initializer defined in the same class as the
253 252 # trait declaration or above.
254 253 mro = type(obj).mro()
255 254 meth_name = '_%s_default' % self.name
256 255 for cls in mro[:mro.index(self.this_class)+1]:
257 256 if meth_name in cls.__dict__:
258 257 break
259 258 else:
260 259 # We didn't find one. Do static initialization.
261 260 dv = self.get_default_value()
262 261 newdv = self._validate(obj, dv)
263 262 obj._trait_values[self.name] = newdv
264 263 return
265 264 # Complete the dynamic initialization.
266 265 obj._trait_dyn_inits[self.name] = cls.__dict__[meth_name]
267 266
268 267 def __get__(self, obj, cls=None):
269 268 """Get the value of the trait by self.name for the instance.
270 269
271 270 Default values are instantiated when :meth:`HasTraits.__new__`
272 271 is called. Thus by the time this method gets called either the
273 272 default value or a user defined value (they called :meth:`__set__`)
274 273 is in the :class:`HasTraits` instance.
275 274 """
276 275 if obj is None:
277 276 return self
278 277 else:
279 278 try:
280 279 value = obj._trait_values[self.name]
281 280 except KeyError:
282 281 # Check for a dynamic initializer.
283 282 if self.name in obj._trait_dyn_inits:
284 283 value = obj._trait_dyn_inits[self.name](obj)
285 284 # FIXME: Do we really validate here?
286 285 value = self._validate(obj, value)
287 286 obj._trait_values[self.name] = value
288 287 return value
289 288 else:
290 289 raise TraitError('Unexpected error in TraitType: '
291 290 'both default value and dynamic initializer are '
292 291 'absent.')
293 292 except Exception:
294 293 # HasTraits should call set_default_value to populate
295 294 # this. So this should never be reached.
296 295 raise TraitError('Unexpected error in TraitType: '
297 296 'default value not set properly')
298 297 else:
299 298 return value
300 299
301 300 def __set__(self, obj, value):
302 301 new_value = self._validate(obj, value)
303 302 old_value = self.__get__(obj)
304 303 if old_value != new_value:
305 304 obj._trait_values[self.name] = new_value
306 305 obj._notify_trait(self.name, old_value, new_value)
307 306
308 307 def _validate(self, obj, value):
309 308 if hasattr(self, 'validate'):
310 309 return self.validate(obj, value)
311 310 elif hasattr(self, 'is_valid_for'):
312 311 valid = self.is_valid_for(value)
313 312 if valid:
314 313 return value
315 314 else:
316 315 raise TraitError('invalid value for type: %r' % value)
317 316 elif hasattr(self, 'value_for'):
318 317 return self.value_for(value)
319 318 else:
320 319 return value
321 320
322 321 def info(self):
323 322 return self.info_text
324 323
325 324 def error(self, obj, value):
326 325 if obj is not None:
327 326 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
328 327 % (self.name, class_of(obj),
329 328 self.info(), repr_type(value))
330 329 else:
331 330 e = "The '%s' trait must be %s, but a value of %r was specified." \
332 331 % (self.name, self.info(), repr_type(value))
333 332 raise TraitError(e)
334 333
335 334 def get_metadata(self, key):
336 335 return getattr(self, '_metadata', {}).get(key, None)
337 336
338 337 def set_metadata(self, key, value):
339 338 getattr(self, '_metadata', {})[key] = value
340 339
341 340
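A sketch of what the TraitType machinery above buys a subclass: define ``validate`` (dispatched to by ``_validate``), an ``info_text`` and a ``default_value``, and the :class:`MetaHasTraits` metaclass defined below fills in ``name`` and ``this_class``. ``PositiveInt`` and ``Counter`` are invented examples, not part of this changeset.

from IPython.utils.traitlets import HasTraits, TraitType, TraitError

class PositiveInt(TraitType):             # invented example trait
    default_value = 1
    info_text = 'a positive integer'

    def validate(self, obj, value):
        # _validate() dispatches here because we define validate()
        if isinstance(value, int) and value > 0:
            return value
        self.error(obj, value)            # uses self.name / info_text in the message

class Counter(HasTraits):
    n = PositiveInt()                     # metaclass sets n.name = 'n', n.this_class = Counter

c = Counter()
assert c.n == 1                           # the default is validated at instance creation
c.n = 10
try:
    c.n = -3                              # raises TraitError via self.error()
except TraitError:
    pass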
342 341 #-----------------------------------------------------------------------------
343 342 # The HasTraits implementation
344 343 #-----------------------------------------------------------------------------
345 344
346 345
347 346 class MetaHasTraits(type):
348 347 """A metaclass for HasTraits.
349 348
350 349 This metaclass makes sure that any TraitType class attributes are
351 350 instantiated and sets their name attribute.
352 351 """
353 352
354 353 def __new__(mcls, name, bases, classdict):
355 354 """Create the HasTraits class.
356 355
357 356 This instantiates all TraitTypes in the class dict and sets their
358 357 :attr:`name` attribute.
359 358 """
360 359 # print "MetaHasTraitlets (mcls, name): ", mcls, name
361 360 # print "MetaHasTraitlets (bases): ", bases
362 361 # print "MetaHasTraitlets (classdict): ", classdict
363 362 for k,v in classdict.iteritems():
364 363 if isinstance(v, TraitType):
365 364 v.name = k
366 365 elif inspect.isclass(v):
367 366 if issubclass(v, TraitType):
368 367 vinst = v()
369 368 vinst.name = k
370 369 classdict[k] = vinst
371 370 return super(MetaHasTraits, mcls).__new__(mcls, name, bases, classdict)
372 371
373 372 def __init__(cls, name, bases, classdict):
374 373 """Finish initializing the HasTraits class.
375 374
376 375 This sets the :attr:`this_class` attribute of each TraitType in the
377 376 class dict to the newly created class ``cls``.
378 377 """
379 378 for k, v in classdict.iteritems():
380 379 if isinstance(v, TraitType):
381 380 v.this_class = cls
382 381 super(MetaHasTraits, cls).__init__(name, bases, classdict)
383 382
384 383 class HasTraits(object):
385 384
386 385 __metaclass__ = MetaHasTraits
387 386
388 387 def __new__(cls, **kw):
389 388 # This is needed because in Python 2.6 object.__new__ only accepts
390 389 # the cls argument.
391 390 new_meth = super(HasTraits, cls).__new__
392 391 if new_meth is object.__new__:
393 392 inst = new_meth(cls)
394 393 else:
395 394 inst = new_meth(cls, **kw)
396 395 inst._trait_values = {}
397 396 inst._trait_notifiers = {}
398 397 inst._trait_dyn_inits = {}
399 398 # Here we tell all the TraitType instances to set their default
400 399 # values on the instance.
401 400 for key in dir(cls):
402 401 # Some descriptors raise AttributeError like zope.interface's
403 402 # __provides__ attributes even though they exist. This causes
404 403 # AttributeErrors even though they are listed in dir(cls).
405 404 try:
406 405 value = getattr(cls, key)
407 406 except AttributeError:
408 407 pass
409 408 else:
410 409 if isinstance(value, TraitType):
411 410 value.instance_init(inst)
412 411
413 412 return inst
414 413
415 414 def __init__(self, **kw):
416 415 # Allow trait values to be set using keyword arguments.
417 416 # We need to use setattr for this to trigger validation and
418 417 # notifications.
419 418 for key, value in kw.iteritems():
420 419 setattr(self, key, value)
421 420
422 421 def _notify_trait(self, name, old_value, new_value):
423 422
424 423 # First dynamic ones
425 424 callables = self._trait_notifiers.get(name,[])
426 425 more_callables = self._trait_notifiers.get('anytrait',[])
427 426 callables.extend(more_callables)
428 427
429 428 # Now static ones
430 429 try:
431 430 cb = getattr(self, '_%s_changed' % name)
432 431 except:
433 432 pass
434 433 else:
435 434 callables.append(cb)
436 435
437 436 # Call them all now
438 437 for c in callables:
439 438 # Traits catches and logs errors here. I allow them to raise
440 439 if callable(c):
441 440 argspec = inspect.getargspec(c)
442 441 nargs = len(argspec[0])
443 442 # Bound methods have an additional 'self' argument
444 443 # I don't know how to treat unbound methods, but they
445 444 # can't really be used for callbacks.
446 445 if isinstance(c, types.MethodType):
447 446 offset = -1
448 447 else:
449 448 offset = 0
450 449 if nargs + offset == 0:
451 450 c()
452 451 elif nargs + offset == 1:
453 452 c(name)
454 453 elif nargs + offset == 2:
455 454 c(name, new_value)
456 455 elif nargs + offset == 3:
457 456 c(name, old_value, new_value)
458 457 else:
459 458 raise TraitError('a trait changed callback '
460 459 'must have 0-3 arguments.')
461 460 else:
462 461 raise TraitError('a trait changed callback '
463 462 'must be callable.')
464 463
465 464
466 465 def _add_notifiers(self, handler, name):
467 466 if not self._trait_notifiers.has_key(name):
468 467 nlist = []
469 468 self._trait_notifiers[name] = nlist
470 469 else:
471 470 nlist = self._trait_notifiers[name]
472 471 if handler not in nlist:
473 472 nlist.append(handler)
474 473
475 474 def _remove_notifiers(self, handler, name):
476 475 if self._trait_notifiers.has_key(name):
477 476 nlist = self._trait_notifiers[name]
478 477 try:
479 478 index = nlist.index(handler)
480 479 except ValueError:
481 480 pass
482 481 else:
483 482 del nlist[index]
484 483
485 484 def on_trait_change(self, handler, name=None, remove=False):
486 485 """Setup a handler to be called when a trait changes.
487 486
488 487 This is used to setup dynamic notifications of trait changes.
489 488
490 489 Static handlers can be created by creating methods on a HasTraits
491 490 subclass with the naming convention '_[traitname]_changed'. Thus,
492 491 to create static handler for the trait 'a', create the method
493 492 _a_changed(self, name, old, new) (fewer arguments can be used, see
494 493 below).
495 494
496 495 Parameters
497 496 ----------
498 497 handler : callable
499 498 A callable that is called when a trait changes. Its
500 499 signature can be handler(), handler(name), handler(name, new)
501 500 or handler(name, old, new).
502 501 name : list, str, None
503 502 If None, the handler will apply to all traits. If a list
504 503 of str, handler will apply to all names in the list. If a
505 504 str, the handler will apply just to that name.
506 505 remove : bool
507 506 If False (the default), then install the handler. If True
508 507 then uninstall it.
509 508 """
510 509 if remove:
511 510 names = parse_notifier_name(name)
512 511 for n in names:
513 512 self._remove_notifiers(handler, n)
514 513 else:
515 514 names = parse_notifier_name(name)
516 515 for n in names:
517 516 self._add_notifiers(handler, n)
518 517
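A sketch of the dynamic notification API just described; handlers with 0-3 positional arguments are all accepted, and ``name=None`` listens to every trait. The names below are invented.

from IPython.utils.traitlets import HasTraits, Int

class Point(HasTraits):                   # invented example class
    x = Int(0)
    y = Int(0)

seen = []

def on_any(name, old, new):               # the 3-argument handler form
    seen.append((name, old, new))

p = Point()
p.on_trait_change(on_any)                 # name=None -> registered under 'anytrait'
p.x = 1
p.y = 2
assert seen == [('x', 0, 1), ('y', 0, 2)]
p.on_trait_change(on_any, remove=True)    # uninstall the handler again
p.x = 5
assert len(seen) == 2                     # no longer notified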
519 518 @classmethod
520 519 def class_trait_names(cls, **metadata):
521 520 """Get a list of all the names of this classes traits.
522 521
523 522 This method is just like the :meth:`trait_names` method, but is unbound.
524 523 """
525 524 return cls.class_traits(**metadata).keys()
526 525
527 526 @classmethod
528 527 def class_traits(cls, **metadata):
529 528 """Get a list of all the traits of this class.
530 529
531 530 This method is just like the :meth:`traits` method, but is unbound.
532 531
533 532 The TraitTypes returned don't know anything about the values
534 533 that the various HasTraits instances are holding.
535 534
536 535 This follows the same algorithm as traits does and does not allow
537 536 for any simple way of specifying merely that a metadata name
538 537 exists, whatever its value. This is because get_metadata returns
539 538 None if a metadata key doesn't exist.
540 539 """
541 540 traits = dict([memb for memb in getmembers(cls) if \
542 541 isinstance(memb[1], TraitType)])
543 542
544 543 if len(metadata) == 0:
545 544 return traits
546 545
547 546 for meta_name, meta_eval in metadata.items():
548 547 if type(meta_eval) is not FunctionType:
549 548 metadata[meta_name] = _SimpleTest(meta_eval)
550 549
551 550 result = {}
552 551 for name, trait in traits.items():
553 552 for meta_name, meta_eval in metadata.items():
554 553 if not meta_eval(trait.get_metadata(meta_name)):
555 554 break
556 555 else:
557 556 result[name] = trait
558 557
559 558 return result
560 559
561 560 def trait_names(self, **metadata):
562 561 """Get a list of all the names of this classes traits."""
563 562 return self.traits(**metadata).keys()
564 563
565 564 def traits(self, **metadata):
566 565 """Get a list of all the traits of this class.
567 566
568 567 The TraitTypes returned don't know anything about the values
569 568 that the various HasTraits instances are holding.
570 569
571 570 This follows the same algorithm as traits does and does not allow
572 571 for any simple way of specifying merely that a metadata name
573 572 exists, whatever its value. This is because get_metadata returns
574 573 None if a metadata key doesn't exist.
575 574 """
576 575 traits = dict([memb for memb in getmembers(self.__class__) if \
577 576 isinstance(memb[1], TraitType)])
578 577
579 578 if len(metadata) == 0:
580 579 return traits
581 580
582 581 for meta_name, meta_eval in metadata.items():
583 582 if type(meta_eval) is not FunctionType:
584 583 metadata[meta_name] = _SimpleTest(meta_eval)
585 584
586 585 result = {}
587 586 for name, trait in traits.items():
588 587 for meta_name, meta_eval in metadata.items():
589 588 if not meta_eval(trait.get_metadata(meta_name)):
590 589 break
591 590 else:
592 591 result[name] = trait
593 592
594 593 return result
595 594
596 595 def trait_metadata(self, traitname, key):
597 596 """Get metadata values for trait by key."""
598 597 try:
599 598 trait = getattr(self.__class__, traitname)
600 599 except AttributeError:
601 600 raise TraitError("Class %s does not have a trait named %s" %
602 601 (self.__class__.__name__, traitname))
603 602 else:
604 603 return trait.get_metadata(key)
605 604
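A sketch of the metadata filtering implemented by ``traits()``/``class_traits()`` above; ``config_key`` is an arbitrary metadata name chosen for illustration (it matches the one used in the tests earlier in this changeset), and ``App`` is an invented class.

from IPython.utils.traitlets import HasTraits, Int, Float

class App(HasTraits):                     # invented example class
    i = Int(0, config_key='I')
    f = Float(0.0, config_key='F')
    j = Int(0)                            # no metadata

a = App()
assert set(a.trait_names()) == set(['i', 'f', 'j'])
assert a.traits(config_key='I').keys() == ['i']        # exact metadata match via _SimpleTest
assert a.trait_metadata('f', 'config_key') == 'F'
# A callable is applied to each trait's metadata value instead of an equality test:
assert set(a.traits(config_key=lambda v: v is not None)) == set(['i', 'f'])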
606 605 #-----------------------------------------------------------------------------
607 606 # Actual TraitTypes implementations/subclasses
608 607 #-----------------------------------------------------------------------------
609 608
610 609 #-----------------------------------------------------------------------------
611 610 # TraitTypes subclasses for handling classes and instances of classes
612 611 #-----------------------------------------------------------------------------
613 612
614 613
615 614 class ClassBasedTraitType(TraitType):
616 615 """A trait with error reporting for Type, Instance and This."""
617 616
618 617 def error(self, obj, value):
619 618 kind = type(value)
620 619 if kind is InstanceType:
621 620 msg = 'class %s' % value.__class__.__name__
622 621 else:
623 622 msg = '%s (i.e. %s)' % ( str( kind )[1:-1], repr( value ) )
624 623
625 624 if obj is not None:
626 625 e = "The '%s' trait of %s instance must be %s, but a value of %s was specified." \
627 626 % (self.name, class_of(obj),
628 627 self.info(), msg)
629 628 else:
630 629 e = "The '%s' trait must be %s, but a value of %r was specified." \
631 630 % (self.name, self.info(), msg)
632 631
633 632 raise TraitError(e)
634 633
635 634
636 635 class Type(ClassBasedTraitType):
637 636 """A trait whose value must be a subclass of a specified class."""
638 637
639 638 def __init__ (self, default_value=None, klass=None, allow_none=True, **metadata ):
640 639 """Construct a Type trait
641 640
642 641 A Type trait specifies that its values must be subclasses of
643 642 a particular class.
644 643
645 644 If only ``default_value`` is given, it is used for the ``klass`` as
646 645 well.
647 646
648 647 Parameters
649 648 ----------
650 649 default_value : class, str or None
651 650 The default value must be a subclass of klass. If a str,
652 651 the str must be a fully specified class name, like 'foo.bar.Bah'.
653 652 The string is resolved into a real class when the parent
654 653 :class:`HasTraits` class is instantiated.
655 654 klass : class, str, None
656 655 Values of this trait must be a subclass of klass. The klass
657 656 may be specified in a string like: 'foo.bar.MyClass'.
658 657 The string is resolved into a real class when the parent
659 658 :class:`HasTraits` class is instantiated.
660 659 allow_none : boolean
661 660 Indicates whether None is allowed as an assignable value. Even if
662 661 ``False``, the default value may be ``None``.
663 662 """
664 663 if default_value is None:
665 664 if klass is None:
666 665 klass = object
667 666 elif klass is None:
668 667 klass = default_value
669 668
670 669 if not (inspect.isclass(klass) or isinstance(klass, basestring)):
671 670 raise TraitError("A Type trait must specify a class.")
672 671
673 672 self.klass = klass
674 673 self._allow_none = allow_none
675 674
676 675 super(Type, self).__init__(default_value, **metadata)
677 676
678 677 def validate(self, obj, value):
679 678 """Validates that the value is a valid object instance."""
680 679 try:
681 680 if issubclass(value, self.klass):
682 681 return value
683 682 except:
684 683 if (value is None) and (self._allow_none):
685 684 return value
686 685
687 686 self.error(obj, value)
688 687
689 688 def info(self):
690 689 """ Returns a description of the trait."""
691 690 if isinstance(self.klass, basestring):
692 691 klass = self.klass
693 692 else:
694 693 klass = self.klass.__name__
695 694 result = 'a subclass of ' + klass
696 695 if self._allow_none:
697 696 return result + ' or None'
698 697 return result
699 698
700 699 def instance_init(self, obj):
701 700 self._resolve_classes()
702 701 super(Type, self).instance_init(obj)
703 702
704 703 def _resolve_classes(self):
705 704 if isinstance(self.klass, basestring):
706 705 self.klass = import_item(self.klass)
707 706 if isinstance(self.default_value, basestring):
708 707 self.default_value = import_item(self.default_value)
709 708
710 709 def get_default_value(self):
711 710 return self.default_value
712 711
713 712
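A sketch of the Type trait above. String class names are resolved by ``_resolve_classes`` when the owning HasTraits class is instantiated; ``IPython.utils.ipstruct.Struct`` is the same class path exercised by the tests earlier in this changeset, while the other names are invented.

from IPython.utils.traitlets import HasTraits, Type, TraitError

class Base(object): pass                  # invented example classes
class Child(Base): pass

class Holder(HasTraits):
    klass = Type(Base)                    # values must be subclasses of Base
    struct = Type('IPython.utils.ipstruct.Struct')     # resolved lazily on Holder()

h = Holder()
assert h.klass is Base                    # default_value doubles as klass here
h.klass = Child                           # subclasses are accepted
try:
    h.klass = dict                        # not a subclass of Base
except TraitError:
    pass

from IPython.utils.ipstruct import Struct
assert h.struct is Struct                 # the string default was imported and validated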
714 713 class DefaultValueGenerator(object):
715 714 """A class for generating new default value instances."""
716 715
717 716 def __init__(self, *args, **kw):
718 717 self.args = args
719 718 self.kw = kw
720 719
721 720 def generate(self, klass):
722 721 return klass(*self.args, **self.kw)
723 722
724 723
725 724 class Instance(ClassBasedTraitType):
726 725 """A trait whose value must be an instance of a specified class.
727 726
728 727 The value can also be an instance of a subclass of the specified class.
729 728 """
730 729
731 730 def __init__(self, klass=None, args=None, kw=None,
732 731 allow_none=True, **metadata ):
733 732 """Construct an Instance trait.
734 733
735 734 This trait allows values that are instances of a particular
736 735 class or its subclasses. Our implementation is quite different
737 736 from that of enthought.traits as we don't allow instances to be used
738 737 for klass and we handle the ``args`` and ``kw`` arguments differently.
739 738
740 739 Parameters
741 740 ----------
742 741 klass : class, str
743 742 The class that forms the basis for the trait. Class names
744 743 can also be specified as strings, like 'foo.bar.Bar'.
745 744 args : tuple
746 745 Positional arguments for generating the default value.
747 746 kw : dict
748 747 Keyword arguments for generating the default value.
749 748 allow_none : bool
750 749 Indicates whether None is allowed as a value.
751 750
752 751 Default Value
753 752 -------------
754 753 If both ``args`` and ``kw`` are None, then the default value is None.
755 754 If ``args`` is a tuple and ``kw`` is a dict, then the default is
756 755 created as ``klass(*args, **kw)``. If exactly one of ``args`` or
757 756 ``kw`` is None, it is replaced by ``()`` or ``{}`` respectively.
758 757 """
759 758
760 759 self._allow_none = allow_none
761 760
762 761 if (klass is None) or (not (inspect.isclass(klass) or isinstance(klass, basestring))):
763 762 raise TraitError('The klass argument must be a class'
764 763 ' you gave: %r' % klass)
765 764 self.klass = klass
766 765
767 766 # self.klass is a class, so handle default_value
768 767 if args is None and kw is None:
769 768 default_value = None
770 769 else:
771 770 if args is None:
772 771 # kw is not None
773 772 args = ()
774 773 elif kw is None:
775 774 # args is not None
776 775 kw = {}
777 776
778 777 if not isinstance(kw, dict):
779 778 raise TraitError("The 'kw' argument must be a dict or None.")
780 779 if not isinstance(args, tuple):
781 780 raise TraitError("The 'args' argument must be a tuple or None.")
782 781
783 782 default_value = DefaultValueGenerator(*args, **kw)
784 783
785 784 super(Instance, self).__init__(default_value, **metadata)
786 785
787 786 def validate(self, obj, value):
788 787 if value is None:
789 788 if self._allow_none:
790 789 return value
791 790 self.error(obj, value)
792 791
793 792 if isinstance(value, self.klass):
794 793 return value
795 794 else:
796 795 self.error(obj, value)
797 796
798 797 def info(self):
799 798 if isinstance(self.klass, basestring):
800 799 klass = self.klass
801 800 else:
802 801 klass = self.klass.__name__
803 802 result = class_of(klass)
804 803 if self._allow_none:
805 804 return result + ' or None'
806 805
807 806 return result
808 807
809 808 def instance_init(self, obj):
810 809 self._resolve_classes()
811 810 super(Instance, self).instance_init(obj)
812 811
813 812 def _resolve_classes(self):
814 813 if isinstance(self.klass, basestring):
815 814 self.klass = import_item(self.klass)
816 815
817 816 def get_default_value(self):
818 817 """Instantiate a default value instance.
819 818
820 819 This is called when the containing HasTraits classes'
821 820 :meth:`__new__` method is called to ensure that a unique instance
822 821 is created for each HasTraits instance.
823 822 """
824 823 dv = self.default_value
825 824 if isinstance(dv, DefaultValueGenerator):
826 825 return dv.generate(self.klass)
827 826 else:
828 827 return dv
829 828
830 829
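A sketch of the Instance trait and the args/kw default-value generation described in its docstring above; ``Engine`` and ``Car`` are invented names.

from IPython.utils.traitlets import HasTraits, Instance, TraitError

class Engine(object):                     # invented example class
    def __init__(self, power=100):
        self.power = power

class Car(HasTraits):
    engine = Instance(Engine, kw=dict(power=50))   # default built per instance: Engine(power=50)
    spare = Instance(Engine)                       # no args/kw -> default is None

a, b = Car(), Car()
assert a.engine.power == 50
assert a.engine is not b.engine           # each HasTraits instance gets its own default
assert a.spare is None
a.spare = Engine()                        # instances (and subclass instances) are accepted
try:
    a.engine = Engine                     # a class, not an instance
except TraitError:
    pass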
831 830 class This(ClassBasedTraitType):
832 831 """A trait for instances of the class containing this trait.
833 832
834 833 Because of how and when class bodies are executed, the ``This``
835 834 trait can only have a default value of None. Because of this, and
836 835 because we always validate default values, ``allow_none`` is *always* true.
837 836 """
838 837
839 838 info_text = 'an instance of the same type as the receiver or None'
840 839
841 840 def __init__(self, **metadata):
842 841 super(This, self).__init__(None, **metadata)
843 842
844 843 def validate(self, obj, value):
845 844 # What if value is a superclass of obj.__class__? This is
846 845 # complicated if it was the superclass that defined the This
847 846 # trait.
848 847 if isinstance(value, self.this_class) or (value is None):
849 848 return value
850 849 else:
851 850 self.error(obj, value)
852 851
853 852
854 853 #-----------------------------------------------------------------------------
855 854 # Basic TraitTypes implementations/subclasses
856 855 #-----------------------------------------------------------------------------
857 856
858 857
859 858 class Any(TraitType):
860 859 default_value = None
861 860 info_text = 'any value'
862 861
863 862
864 863 class Int(TraitType):
865 864 """A integer trait."""
866 865
867 866 default_value = 0
868 867 info_text = 'an integer'
869 868
870 869 def validate(self, obj, value):
871 870 if isinstance(value, int):
872 871 return value
873 872 self.error(obj, value)
874 873
875 874 class CInt(Int):
876 875 """A casting version of the int trait."""
877 876
878 877 def validate(self, obj, value):
879 878 try:
880 879 return int(value)
881 880 except:
882 881 self.error(obj, value)
883 882
884 883
885 884 class Long(TraitType):
886 885 """A long integer trait."""
887 886
888 887 default_value = 0L
889 888 info_text = 'a long'
890 889
891 890 def validate(self, obj, value):
892 891 if isinstance(value, long):
893 892 return value
894 893 if isinstance(value, int):
895 894 return long(value)
896 895 self.error(obj, value)
897 896
898 897
899 898 class CLong(Long):
900 899 """A casting version of the long integer trait."""
901 900
902 901 def validate(self, obj, value):
903 902 try:
904 903 return long(value)
905 904 except:
906 905 self.error(obj, value)
907 906
908 907
909 908 class Float(TraitType):
910 909 """A float trait."""
911 910
912 911 default_value = 0.0
913 912 info_text = 'a float'
914 913
915 914 def validate(self, obj, value):
916 915 if isinstance(value, float):
917 916 return value
918 917 if isinstance(value, int):
919 918 return float(value)
920 919 self.error(obj, value)
921 920
922 921
923 922 class CFloat(Float):
924 923 """A casting version of the float trait."""
925 924
926 925 def validate(self, obj, value):
927 926 try:
928 927 return float(value)
929 928 except:
930 929 self.error(obj, value)
931 930
932 931 class Complex(TraitType):
933 932 """A trait for complex numbers."""
934 933
935 934 default_value = 0.0 + 0.0j
936 935 info_text = 'a complex number'
937 936
938 937 def validate(self, obj, value):
939 938 if isinstance(value, complex):
940 939 return value
941 940 if isinstance(value, (float, int)):
942 941 return complex(value)
943 942 self.error(obj, value)
944 943
945 944
946 945 class CComplex(Complex):
947 946 """A casting version of the complex number trait."""
948 947
949 948 def validate (self, obj, value):
950 949 try:
951 950 return complex(value)
952 951 except:
953 952 self.error(obj, value)
954 953
955 954 # We should always be explicit about whether we're using bytes or unicode, both
956 955 # for Python 3 conversion and for reliable unicode behaviour on Python 2. So
957 956 # we don't have a Str type.
958 957 class Bytes(TraitType):
959 958 """A trait for strings."""
960 959
961 960 default_value = ''
962 961 info_text = 'a string'
963 962
964 963 def validate(self, obj, value):
965 964 if isinstance(value, bytes):
966 965 return value
967 966 self.error(obj, value)
968 967
969 968
970 969 class CBytes(Bytes):
971 970 """A casting version of the string trait."""
972 971
973 972 def validate(self, obj, value):
974 973 try:
975 974 return bytes(value)
976 975 except:
977 976 self.error(obj, value)
978 977
979 978
980 979 class Unicode(TraitType):
981 980 """A trait for unicode strings."""
982 981
983 982 default_value = u''
984 983 info_text = 'a unicode string'
985 984
986 985 def validate(self, obj, value):
987 986 if isinstance(value, unicode):
988 987 return value
989 988 if isinstance(value, bytes):
990 989 return unicode(value)
991 990 self.error(obj, value)
992 991
993 992
994 993 class CUnicode(Unicode):
995 994 """A casting version of the unicode trait."""
996 995
997 996 def validate(self, obj, value):
998 997 try:
999 998 return unicode(value)
1000 999 except:
1001 1000 self.error(obj, value)
1002 1001
1003 1002
1004 1003 class ObjectName(TraitType):
1005 1004 """A string holding a valid object name in this version of Python.
1006 1005
1007 1006 This does not check that the name exists in any scope."""
1008 1007 info_text = "a valid object identifier in Python"
1009 1008
1010 1009 if sys.version_info[0] < 3:
1011 1010 # Python 2:
1012 1011 _name_re = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*$")
1013 1012 def isidentifier(self, s):
1014 1013 return bool(self._name_re.match(s))
1015 1014
1016 1015 def coerce_str(self, obj, value):
1017 1016 "In Python 2, coerce ascii-only unicode to str"
1018 1017 if isinstance(value, unicode):
1019 1018 try:
1020 1019 return str(value)
1021 1020 except UnicodeEncodeError:
1022 1021 self.error(obj, value)
1023 1022 return value
1024 1023
1025 1024 else:
1026 1025 # Python 3:
1027 1026 isidentifier = staticmethod(lambda s: s.isidentifier())
1028 1027 coerce_str = staticmethod(lambda _,s: s)
1029 1028
1030 1029 def validate(self, obj, value):
1031 1030 value = self.coerce_str(obj, value)
1032 1031
1033 1032 if isinstance(value, str) and self.isidentifier(value):
1034 1033 return value
1035 1034 self.error(obj, value)
1036 1035
1037 1036 class DottedObjectName(ObjectName):
1038 1037 """A string holding a valid dotted object name in Python, such as A.b3._c"""
1039 1038 def validate(self, obj, value):
1040 1039 value = self.coerce_str(obj, value)
1041 1040
1042 1041 if isinstance(value, str) and all(self.isidentifier(x) \
1043 1042 for x in value.split('.')):
1044 1043 return value
1045 1044 self.error(obj, value)
1046 1045
1047 1046
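A sketch of the ObjectName/DottedObjectName traits above; they check identifier syntax only and never look the name up in any scope. The names below are invented.

from IPython.utils.traitlets import HasTraits, ObjectName, DottedObjectName, TraitError

class Ref(HasTraits):                     # invented example class
    attr = ObjectName('abc')
    path = DottedObjectName('os.path.join')

r = Ref()
r.attr = 'a_1'
r.path = 'collections.defaultdict'        # each dotted component must be an identifier
try:
    r.attr = '9g'                         # identifiers cannot start with a digit
except TraitError:
    pass
try:
    r.path = '.abc'                       # empty component before the dot
except TraitError:
    pass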
1048 1047 class Bool(TraitType):
1049 1048 """A boolean (True, False) trait."""
1050 1049
1051 1050 default_value = False
1052 1051 info_text = 'a boolean'
1053 1052
1054 1053 def validate(self, obj, value):
1055 1054 if isinstance(value, bool):
1056 1055 return value
1057 1056 self.error(obj, value)
1058 1057
1059 1058
1060 1059 class CBool(Bool):
1061 1060 """A casting version of the boolean trait."""
1062 1061
1063 1062 def validate(self, obj, value):
1064 1063 try:
1065 1064 return bool(value)
1066 1065 except:
1067 1066 self.error(obj, value)
1068 1067
1069 1068
1070 1069 class Enum(TraitType):
1071 1070 """An enum that whose value must be in a given sequence."""
1072 1071
1073 1072 def __init__(self, values, default_value=None, allow_none=True, **metadata):
1074 1073 self.values = values
1075 1074 self._allow_none = allow_none
1076 1075 super(Enum, self).__init__(default_value, **metadata)
1077 1076
1078 1077 def validate(self, obj, value):
1079 1078 if value is None:
1080 1079 if self._allow_none:
1081 1080 return value
1082 1081
1083 1082 if value in self.values:
1084 1083 return value
1085 1084 self.error(obj, value)
1086 1085
1087 1086 def info(self):
1088 1087 """ Returns a description of the trait."""
1089 1088 result = 'any of ' + repr(self.values)
1090 1089 if self._allow_none:
1091 1090 return result + ' or None'
1092 1091 return result
1093 1092
1094 1093 class CaselessStrEnum(Enum):
1095 1094 """An enum of strings that are caseless in validate."""
1096 1095
1097 1096 def validate(self, obj, value):
1098 1097 if value is None:
1099 1098 if self._allow_none:
1100 1099 return value
1101 1100
1102 1101 if not isinstance(value, basestring):
1103 1102 self.error(obj, value)
1104 1103
1105 1104 for v in self.values:
1106 1105 if v.lower() == value.lower():
1107 1106 return v
1108 1107 self.error(obj, value)
1109 1108
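A sketch of the two Enum traits above; CaselessStrEnum matches case-insensitively but stores the canonical value from ``values``. The names are invented.

from IPython.utils.traitlets import HasTraits, Enum, CaselessStrEnum, TraitError

class Settings(HasTraits):                # invented example class
    mode = Enum((1, 2, 3), default_value=1)
    color = CaselessStrEnum(('red', 'green'), default_value='red')

s = Settings()
s.mode = 3
try:
    s.mode = 4                            # not in the allowed sequence
except TraitError:
    pass
s.color = 'GREEN'
assert s.color == 'green'                 # matched caselessly, stored canonically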
1110 1109 class Container(Instance):
1111 1110 """An instance of a container (list, set, etc.)
1112 1111
1113 1112 To be subclassed by overriding klass.
1114 1113 """
1115 1114 klass = None
1116 1115 _valid_defaults = SequenceTypes
1117 1116 _trait = None
1118 1117
1119 1118 def __init__(self, trait=None, default_value=None, allow_none=True,
1120 1119 **metadata):
1121 1120 """Create a container trait type from a list, set, or tuple.
1122 1121
1123 1122 The default value is created by doing ``List(default_value)``,
1124 1123 which creates a copy of the ``default_value``.
1125 1124
1126 1125 ``trait`` can be specified, which restricts the type of elements
1127 1126 in the container to that TraitType.
1128 1127
1129 1128 If only one arg is given and it is not a Trait, it is taken as
1130 1129 ``default_value``:
1131 1130
1132 1131 ``c = List([1,2,3])``
1133 1132
1134 1133 Parameters
1135 1134 ----------
1136 1135
1137 1136 trait : TraitType [ optional ]
1138 1137 the type for restricting the contents of the Container. If unspecified,
1139 1138 types are not checked.
1140 1139
1141 1140 default_value : SequenceType [ optional ]
1142 1141 The default value for the Trait. Must be list/tuple/set, and
1143 1142 will be cast to the container type.
1144 1143
1145 1144 allow_none : Bool [ default True ]
1146 1145 Whether to allow the value to be None
1147 1146
1148 1147 **metadata : any
1149 1148 further keys for extensions to the Trait (e.g. config)
1150 1149
1151 1150 """
1152 1151 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1153 1152
1154 1153 # allow List([values]):
1155 1154 if default_value is None and not istrait(trait):
1156 1155 default_value = trait
1157 1156 trait = None
1158 1157
1159 1158 if default_value is None:
1160 1159 args = ()
1161 1160 elif isinstance(default_value, self._valid_defaults):
1162 1161 args = (default_value,)
1163 1162 else:
1164 1163 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1165 1164
1166 1165 if istrait(trait):
1167 1166 self._trait = trait()
1168 1167 self._trait.name = 'element'
1169 1168 elif trait is not None:
1170 1169 raise TypeError("`trait` must be a Trait or None, got %s"%repr_type(trait))
1171 1170
1172 1171 super(Container,self).__init__(klass=self.klass, args=args,
1173 1172 allow_none=allow_none, **metadata)
1174 1173
1175 1174 def element_error(self, obj, element, validator):
1176 1175 e = "Element of the '%s' trait of %s instance must be %s, but a value of %s was specified." \
1177 1176 % (self.name, class_of(obj), validator.info(), repr_type(element))
1178 1177 raise TraitError(e)
1179 1178
1180 1179 def validate(self, obj, value):
1181 1180 value = super(Container, self).validate(obj, value)
1182 1181 if value is None:
1183 1182 return value
1184 1183
1185 1184 value = self.validate_elements(obj, value)
1186 1185
1187 1186 return value
1188 1187
1189 1188 def validate_elements(self, obj, value):
1190 1189 validated = []
1191 1190 if self._trait is None or isinstance(self._trait, Any):
1192 1191 return value
1193 1192 for v in value:
1194 1193 try:
1195 1194 v = self._trait.validate(obj, v)
1196 1195 except TraitError:
1197 1196 self.element_error(obj, v, self._trait)
1198 1197 else:
1199 1198 validated.append(v)
1200 1199 return self.klass(validated)
1201 1200
1202 1201
1203 1202 class List(Container):
1204 1203 """An instance of a Python list."""
1205 1204 klass = list
1206 1205
1207 1206 def __init__(self, trait=None, default_value=None, minlen=0, maxlen=sys.maxint,
1208 1207 allow_none=True, **metadata):
1209 1208 """Create a List trait type from a list, set, or tuple.
1210 1209
1211 1210 The default value is created by doing ``List(default_value)``,
1212 1211 which creates a copy of the ``default_value``.
1213 1212
1214 1213 ``trait`` can be specified, which restricts the type of elements
1215 1214 in the container to that TraitType.
1216 1215
1217 1216 If only one arg is given and it is not a Trait, it is taken as
1218 1217 ``default_value``:
1219 1218
1220 1219 ``c = List([1,2,3])``
1221 1220
1222 1221 Parameters
1223 1222 ----------
1224 1223
1225 1224 trait : TraitType [ optional ]
1226 1225 the type for restricting the contents of the Container. If unspecified,
1227 1226 types are not checked.
1228 1227
1229 1228 default_value : SequenceType [ optional ]
1230 1229 The default value for the Trait. Must be list/tuple/set, and
1231 1230 will be cast to the container type.
1232 1231
1233 1232 minlen : Int [ default 0 ]
1234 1233 The minimum length of the input list
1235 1234
1236 1235 maxlen : Int [ default sys.maxint ]
1237 1236 The maximum length of the input list
1238 1237
1239 1238 allow_none : Bool [ default True ]
1240 1239 Whether to allow the value to be None
1241 1240
1242 1241 **metadata : any
1243 1242 further keys for extensions to the Trait (e.g. config)
1244 1243
1245 1244 """
1246 1245 self._minlen = minlen
1247 1246 self._maxlen = maxlen
1248 1247 super(List, self).__init__(trait=trait, default_value=default_value,
1249 1248 allow_none=allow_none, **metadata)
1250 1249
1251 1250 def length_error(self, obj, value):
1252 1251 e = "The '%s' trait of %s instance must be of length %i <= L <= %i, but a value of %s was specified." \
1253 1252 % (self.name, class_of(obj), self._minlen, self._maxlen, value)
1254 1253 raise TraitError(e)
1255 1254
1256 1255 def validate_elements(self, obj, value):
1257 1256 length = len(value)
1258 1257 if length < self._minlen or length > self._maxlen:
1259 1258 self.length_error(obj, value)
1260 1259
1261 1260 return super(List, self).validate_elements(obj, value)
1262 1261
1263 1262
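A sketch of the List container trait above, combining an element trait with the minlen/maxlen bounds; the names are invented.

from IPython.utils.traitlets import HasTraits, List, Int, TraitError

class Batch(HasTraits):                   # invented example class
    ids = List(Int, [0], minlen=1, maxlen=3)   # ints only, 1 to 3 elements
    tags = List()                              # untyped, defaults to []

b = Batch()
assert b.ids == [0] and b.tags == []
b.ids = [1, 2, 3]
try:
    b.ids = [1, 2, 3, 4]                  # too long: maxlen=3 -> length_error
except TraitError:
    pass
try:
    b.ids = [1, 'a']                      # 'a' fails the Int element trait -> element_error
except TraitError:
    pass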
1264 1263 class Set(Container):
1265 1264 """An instance of a Python set."""
1266 1265 klass = set
1267 1266
1268 1267 class Tuple(Container):
1269 1268 """An instance of a Python tuple."""
1270 1269 klass = tuple
1271 1270
1272 1271 def __init__(self, *traits, **metadata):
1273 1272 """Tuple(*traits, default_value=None, allow_none=True, **medatata)
1274 1273
1275 1274 Create a tuple from a list, set, or tuple.
1276 1275
1277 1276 Create a fixed-type tuple with Traits:
1278 1277
1279 1278 ``t = Tuple(Int, Bytes, CBytes)``
1280 1279
1281 1280 would be length 3, with Int, Bytes, CBytes for each element.
1282 1281
1283 1282 If only one arg is given and it is not a Trait, it is taken as
1284 1283 default_value:
1285 1284
1286 1285 ``t = Tuple((1,2,3))``
1287 1286
1288 1287 Otherwise, ``default_value`` *must* be specified by keyword.
1289 1288
1290 1289 Parameters
1291 1290 ----------
1292 1291
1293 1292 *traits : TraitTypes [ optional ]
1294 1293 the types for restricting the contents of the Tuple. If unspecified,
1295 1294 types are not checked. If specified, then each positional argument
1296 1295 corresponds to an element of the tuple. Tuples defined with traits
1297 1296 are of fixed length.
1298 1297
1299 1298 default_value : SequenceType [ optional ]
1300 1299 The default value for the Tuple. Must be list/tuple/set, and
1301 1300 will be cast to a tuple. If `traits` are specified, the
1302 1301 `default_value` must conform to the shape and type they specify.
1303 1302
1304 1303 allow_none : Bool [ default True ]
1305 1304 Whether to allow the value to be None
1306 1305
1307 1306 **metadata : any
1308 1307 further keys for extensions to the Trait (e.g. config)
1309 1308
1310 1309 """
1311 1310 default_value = metadata.pop('default_value', None)
1312 1311 allow_none = metadata.pop('allow_none', True)
1313 1312
1314 1313 istrait = lambda t: isinstance(t, type) and issubclass(t, TraitType)
1315 1314
1316 1315 # allow Tuple((values,)):
1317 1316 if len(traits) == 1 and default_value is None and not istrait(traits[0]):
1318 1317 default_value = traits[0]
1319 1318 traits = ()
1320 1319
1321 1320 if default_value is None:
1322 1321 args = ()
1323 1322 elif isinstance(default_value, self._valid_defaults):
1324 1323 args = (default_value,)
1325 1324 else:
1326 1325 raise TypeError('default value of %s was %s' %(self.__class__.__name__, default_value))
1327 1326
1328 1327 self._traits = []
1329 1328 for trait in traits:
1330 1329 t = trait()
1331 1330 t.name = 'element'
1332 1331 self._traits.append(t)
1333 1332
1334 1333 if self._traits and default_value is None:
1335 1334 # don't allow default to be an empty container if length is specified
1336 1335 args = None
1337 1336 super(Container,self).__init__(klass=self.klass, args=args,
1338 1337 allow_none=allow_none, **metadata)
1339 1338
1340 1339 def validate_elements(self, obj, value):
1341 1340 if not self._traits:
1342 1341 # nothing to validate
1343 1342 return value
1344 1343 if len(value) != len(self._traits):
1345 1344 e = "The '%s' trait of %s instance requires %i elements, but a value of %s was specified." \
1346 1345 % (self.name, class_of(obj), len(self._traits), repr_type(value))
1347 1346 raise TraitError(e)
1348 1347
1349 1348 validated = []
1350 1349 for t,v in zip(self._traits, value):
1351 1350 try:
1352 1351 v = t.validate(obj, v)
1353 1352 except TraitError:
1354 1353 self.element_error(obj, v, t)
1355 1354 else:
1356 1355 validated.append(v)
1357 1356 return tuple(validated)
1358 1357
1359 1358
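As a quick illustration of the fixed-type form described in the Tuple docstring above (a sketch only: the import path follows the IPython.utils.traitlets imports used elsewhere in this changeset, and the Endpoint class and its values are hypothetical), element traits and length are checked on assignment:

    from IPython.utils.traitlets import HasTraits, Tuple, Int, CStr, TraitError

    class Endpoint(HasTraits):
        # fixed length 3: (port, host, label)
        location = Tuple(Int, CStr, CStr, default_value=(0, 'localhost', 'shell'))

    ep = Endpoint()
    ep.location = (5555, 'localhost', 'iopub')   # each element is validated
    try:
        ep.location = (5555, 'localhost')        # wrong length -> TraitError
    except TraitError:
        pass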
1360 1359 class Dict(Instance):
1361 1360 """An instance of a Python dict."""
1362 1361
1363 1362 def __init__(self, default_value=None, allow_none=True, **metadata):
1364 1363 """Create a dict trait type from a dict.
1365 1364
1366 1365 The default value is created by doing ``dict(default_value)``,
1367 1366 which creates a copy of the ``default_value``.
1368 1367 """
1369 1368 if default_value is None:
1370 1369 args = ((),)
1371 1370 elif isinstance(default_value, dict):
1372 1371 args = (default_value,)
1373 1372 elif isinstance(default_value, SequenceTypes):
1374 1373 args = (default_value,)
1375 1374 else:
1376 1375 raise TypeError('default value of Dict was %s' % default_value)
1377 1376
1378 1377 super(Dict,self).__init__(klass=dict, args=args,
1379 1378 allow_none=allow_none, **metadata)
1380 1379
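Because the Dict default is built with ``dict(default_value)``, each instance should get its own copy rather than sharing one mutable dict. A minimal sketch (same assumed import path; the Cache class is illustrative):

    from IPython.utils.traitlets import HasTraits, Dict

    class Cache(HasTraits):
        data = Dict({'hits': 0})

    a, b = Cache(), Cache()
    a.data['hits'] = 10           # mutate one instance's dict
    assert b.data == {'hits': 0}  # the other instance keeps its own copy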
1381 1380 class TCPAddress(TraitType):
1382 1381 """A trait for an (ip, port) tuple.
1383 1382
1384 1383 This allows both IPv4 addresses and hostnames.
1385 1384 """
1386 1385
1387 1386 default_value = ('127.0.0.1', 0)
1388 1387 info_text = 'an (ip, port) tuple'
1389 1388
1390 1389 def validate(self, obj, value):
1391 1390 if isinstance(value, tuple):
1392 1391 if len(value) == 2:
1393 1392 if isinstance(value[0], basestring) and isinstance(value[1], int):
1394 1393 port = value[1]
1395 1394 if port >= 0 and port <= 65535:
1396 1395 return value
1397 1396 self.error(obj, value)
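A small usage sketch for TCPAddress (import path assumed as above; the Server class and values are illustrative). Hostnames and IPv4 strings are accepted, and ports must fall in 0-65535:

    from IPython.utils.traitlets import HasTraits, TCPAddress, TraitError

    class Server(HasTraits):
        bind_addr = TCPAddress()               # defaults to ('127.0.0.1', 0)

    s = Server()
    s.bind_addr = ('localhost', 8888)          # hostname or IP string is fine
    try:
        s.bind_addr = ('localhost', 99999)     # port out of range -> TraitError
    except TraitError:
        pass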
1 NO CONTENT: modified file chmod 100644 => 100755
@@ -1,227 +1,226 b''
1 #!/usr/bin/env python
2 1 """An Application for launching a kernel
3 2
4 3 Authors
5 4 -------
6 5 * MinRK
7 6 """
8 7 #-----------------------------------------------------------------------------
9 8 # Copyright (C) 2011 The IPython Development Team
10 9 #
11 10 # Distributed under the terms of the BSD License. The full license is in
12 11 # the file COPYING.txt, distributed as part of this software.
13 12 #-----------------------------------------------------------------------------
14 13
15 14 #-----------------------------------------------------------------------------
16 15 # Imports
17 16 #-----------------------------------------------------------------------------
18 17
19 18 # Standard library imports.
20 19 import os
21 20 import sys
22 21
23 22 # System library imports.
24 23 import zmq
25 24
26 25 # IPython imports.
27 26 from IPython.core.ultratb import FormattedTB
28 27 from IPython.core.application import (
29 28 BaseIPythonApplication, base_flags, base_aliases
30 29 )
31 30 from IPython.utils import io
32 31 from IPython.utils.localinterfaces import LOCALHOST
33 32 from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Int, Bool,
34 33 DottedObjectName)
35 34 from IPython.utils.importstring import import_item
36 35 # local imports
37 36 from IPython.zmq.heartbeat import Heartbeat
38 37 from IPython.zmq.parentpoller import ParentPollerUnix, ParentPollerWindows
39 38 from IPython.zmq.session import Session
40 39
41 40
42 41 #-----------------------------------------------------------------------------
43 42 # Flags and Aliases
44 43 #-----------------------------------------------------------------------------
45 44
46 45 kernel_aliases = dict(base_aliases)
47 46 kernel_aliases.update({
48 47 'ip' : 'KernelApp.ip',
49 48 'hb' : 'KernelApp.hb_port',
50 49 'shell' : 'KernelApp.shell_port',
51 50 'iopub' : 'KernelApp.iopub_port',
52 51 'stdin' : 'KernelApp.stdin_port',
53 52 'parent': 'KernelApp.parent',
54 53 })
55 54 if sys.platform.startswith('win'):
56 55 kernel_aliases['interrupt'] = 'KernelApp.interrupt'
57 56
58 57 kernel_flags = dict(base_flags)
59 58 kernel_flags.update({
60 59 'no-stdout' : (
61 60 {'KernelApp' : {'no_stdout' : True}},
62 61 "redirect stdout to the null device"),
63 62 'no-stderr' : (
64 63 {'KernelApp' : {'no_stderr' : True}},
65 64 "redirect stderr to the null device"),
66 65 })
67 66
68 67
69 68 #-----------------------------------------------------------------------------
70 69 # Application class for starting a Kernel
71 70 #-----------------------------------------------------------------------------
72 71
73 72 class KernelApp(BaseIPythonApplication):
74 73 name='pykernel'
75 74 aliases = Dict(kernel_aliases)
76 75 flags = Dict(kernel_flags)
77 76 classes = [Session]
78 77 # the kernel class, as an importstring
79 78 kernel_class = DottedObjectName('IPython.zmq.pykernel.Kernel')
80 79 kernel = Any()
81 80 poller = Any() # don't restrict this even though current pollers are all Threads
82 81 heartbeat = Instance(Heartbeat)
83 82 session = Instance('IPython.zmq.session.Session')
84 83 ports = Dict()
85 84
86 85 # inherit config file name from parent:
87 86 parent_appname = Unicode(config=True)
88 87 def _parent_appname_changed(self, name, old, new):
89 88 if self.config_file_specified:
90 89 # it was manually specified, ignore
91 90 return
92 91 self.config_file_name = new.replace('-','_') + u'_config.py'
93 92 # don't let this count as specifying the config file
94 93 self.config_file_specified = False
95 94
96 95 # connection info:
97 96 ip = Unicode(LOCALHOST, config=True,
98 97 help="Set the IP or interface on which the kernel will listen.")
99 98 hb_port = Int(0, config=True, help="set the heartbeat port [default: random]")
100 99 shell_port = Int(0, config=True, help="set the shell (XREP) port [default: random]")
101 100 iopub_port = Int(0, config=True, help="set the iopub (PUB) port [default: random]")
102 101 stdin_port = Int(0, config=True, help="set the stdin (XREQ) port [default: random]")
103 102
104 103 # streams, etc.
105 104 no_stdout = Bool(False, config=True, help="redirect stdout to the null device")
106 105 no_stderr = Bool(False, config=True, help="redirect stderr to the null device")
107 106 outstream_class = DottedObjectName('IPython.zmq.iostream.OutStream',
108 107 config=True, help="The importstring for the OutStream factory")
109 108 displayhook_class = DottedObjectName('IPython.zmq.displayhook.ZMQDisplayHook',
110 109 config=True, help="The importstring for the DisplayHook factory")
111 110
112 111 # polling
113 112 parent = Int(0, config=True,
114 113 help="""kill this process if its parent dies. On Windows, the argument
115 114 specifies the HANDLE of the parent process; otherwise it is treated as a simple boolean flag.
116 115 """)
117 116 interrupt = Int(0, config=True,
118 117 help="""ONLY USED ON WINDOWS
119 118 Interrupt this process when the parent is signalled.
120 119 """)
121 120
122 121 def init_crash_handler(self):
123 122 # Install minimal exception handling
124 123 sys.excepthook = FormattedTB(mode='Verbose', color_scheme='NoColor',
125 124 ostream=sys.__stdout__)
126 125
127 126 def init_poller(self):
128 127 if sys.platform == 'win32':
129 128 if self.interrupt or self.parent:
130 129 self.poller = ParentPollerWindows(self.interrupt, self.parent)
131 130 elif self.parent:
132 131 self.poller = ParentPollerUnix()
133 132
134 133 def _bind_socket(self, s, port):
135 134 iface = 'tcp://%s' % self.ip
136 135 if port <= 0:
137 136 port = s.bind_to_random_port(iface)
138 137 else:
139 138 s.bind(iface + ':%i'%port)
140 139 return port
141 140
142 141 def init_sockets(self):
143 142 # Create a context, a session, and the kernel sockets.
144 143 self.log.info("Starting the kernel at pid: %i", os.getpid())
145 144 context = zmq.Context.instance()
146 145 # Uncomment this to try closing the context.
147 146 # atexit.register(context.term)
148 147
149 148 self.shell_socket = context.socket(zmq.XREP)
150 149 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
151 150 self.log.debug("shell XREP Channel on port: %i"%self.shell_port)
152 151
153 152 self.iopub_socket = context.socket(zmq.PUB)
154 153 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
155 154 self.log.debug("iopub PUB Channel on port: %i"%self.iopub_port)
156 155
157 156 self.stdin_socket = context.socket(zmq.XREQ)
158 157 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
159 158 self.log.debug("stdin XREQ Channel on port: %i"%self.stdin_port)
160 159
161 160 self.heartbeat = Heartbeat(context, (self.ip, self.hb_port))
162 161 self.hb_port = self.heartbeat.port
163 162 self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
164 163
165 164 # Helper to make it easier to connect to an existing kernel, until we have
166 165 # single-port connection negotiation fully implemented.
167 166 # set log-level to critical, to make sure it is output
168 167 self.log.critical("To connect another client to this kernel, use:")
169 168 self.log.critical("--existing --shell={0} --iopub={1} --stdin={2} --hb={3}".format(
170 169 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port))
171 170
172 171
173 172 self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
174 173 stdin=self.stdin_port, hb=self.hb_port)
175 174
176 175 def init_session(self):
177 176 """create our session object"""
178 177 self.session = Session(config=self.config, username=u'kernel')
179 178
180 179 def init_blackhole(self):
181 180 """redirects stdout/stderr to devnull if necessary"""
182 181 if self.no_stdout or self.no_stderr:
183 182 blackhole = file(os.devnull, 'w')
184 183 if self.no_stdout:
185 184 sys.stdout = sys.__stdout__ = blackhole
186 185 if self.no_stderr:
187 186 sys.stderr = sys.__stderr__ = blackhole
188 187
189 188 def init_io(self):
190 189 """Redirect input streams and set a display hook."""
191 190 if self.outstream_class:
192 191 outstream_factory = import_item(str(self.outstream_class))
193 192 sys.stdout = outstream_factory(self.session, self.iopub_socket, u'stdout')
194 193 sys.stderr = outstream_factory(self.session, self.iopub_socket, u'stderr')
195 194 if self.displayhook_class:
196 195 displayhook_factory = import_item(str(self.displayhook_class))
197 196 sys.displayhook = displayhook_factory(self.session, self.iopub_socket)
198 197
199 198 def init_kernel(self):
200 199 """Create the Kernel object itself"""
201 200 kernel_factory = import_item(str(self.kernel_class))
202 201 self.kernel = kernel_factory(config=self.config, session=self.session,
203 202 shell_socket=self.shell_socket,
204 203 iopub_socket=self.iopub_socket,
205 204 stdin_socket=self.stdin_socket,
206 205 log=self.log
207 206 )
208 207 self.kernel.record_ports(self.ports)
209 208
210 209 def initialize(self, argv=None):
211 210 super(KernelApp, self).initialize(argv)
212 211 self.init_blackhole()
213 212 self.init_session()
214 213 self.init_poller()
215 214 self.init_sockets()
216 215 self.init_io()
217 216 self.init_kernel()
218 217
219 218 def start(self):
220 219 self.heartbeat.start()
221 220 if self.poller is not None:
222 221 self.poller.start()
223 222 try:
224 223 self.kernel.start()
225 224 except KeyboardInterrupt:
226 225 pass
227 226
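For orientation, a hedged sketch of the launch sequence implied by initialize()/start() above. The argument names come from the kernel_aliases/kernel_flags dicts; the actual entry point that constructs and runs this application is outside this diff, so the argv values here are illustrative:

    from IPython.zmq.kernelapp import KernelApp

    app = KernelApp()
    # argv parsing maps '--ip' -> KernelApp.ip, '--hb' -> KernelApp.hb_port,
    # '--no-stdout' -> KernelApp.no_stdout, etc. (see the alias/flag dicts).
    app.initialize(['--ip=127.0.0.1', '--hb=0', '--no-stdout'])
    # initialize() has redirected streams, created the Session, set up the
    # parent poller, bound the shell/iopub/stdin sockets, and built the kernel.
    app.start()   # starts heartbeat and poller, then blocks in kernel.start()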
@@ -1,679 +1,678 b''
1 #!/usr/bin/env python
2 1 """Session object for building, serializing, sending, and receiving messages in
3 2 IPython. The Session object supports serialization, HMAC signatures, and
4 3 metadata on messages.
5 4
6 5 Also defined here are utilities for working with Sessions:
7 6 * A SessionFactory to be used as a base class for configurables that work with
8 7 Sessions.
9 8 * A Message object for convenience that allows attribute-access to the msg dict.
10 9
11 10 Authors:
12 11
13 12 * Min RK
14 13 * Brian Granger
15 14 * Fernando Perez
16 15 """
17 16 #-----------------------------------------------------------------------------
18 17 # Copyright (C) 2010-2011 The IPython Development Team
19 18 #
20 19 # Distributed under the terms of the BSD License. The full license is in
21 20 # the file COPYING, distributed as part of this software.
22 21 #-----------------------------------------------------------------------------
23 22
24 23 #-----------------------------------------------------------------------------
25 24 # Imports
26 25 #-----------------------------------------------------------------------------
27 26
28 27 import hmac
29 28 import logging
30 29 import os
31 30 import pprint
32 31 import uuid
33 32 from datetime import datetime
34 33
35 34 try:
36 35 import cPickle
37 36 pickle = cPickle
38 37 except:
39 38 cPickle = None
40 39 import pickle
41 40
42 41 import zmq
43 42 from zmq.utils import jsonapi
44 43 from zmq.eventloop.ioloop import IOLoop
45 44 from zmq.eventloop.zmqstream import ZMQStream
46 45
47 46 from IPython.config.configurable import Configurable, LoggingConfigurable
48 47 from IPython.utils.importstring import import_item
49 48 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 49 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 50 DottedObjectName)
52 51
53 52 #-----------------------------------------------------------------------------
54 53 # utility functions
55 54 #-----------------------------------------------------------------------------
56 55
57 56 def squash_unicode(obj):
58 57 """coerce unicode back to bytestrings."""
59 58 if isinstance(obj,dict):
60 59 for key in obj.keys():
61 60 obj[key] = squash_unicode(obj[key])
62 61 if isinstance(key, unicode):
63 62 obj[squash_unicode(key)] = obj.pop(key)
64 63 elif isinstance(obj, list):
65 64 for i,v in enumerate(obj):
66 65 obj[i] = squash_unicode(v)
67 66 elif isinstance(obj, unicode):
68 67 obj = obj.encode('utf8')
69 68 return obj
70 69
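A small sketch of squash_unicode in action (the import path assumes this module is IPython.zmq.session, matching the imports used by kernelapp above; the message dict is illustrative):

    from IPython.zmq.session import squash_unicode

    msg = {u'msg_type': u'execute_request', u'content': {u'code': u'1+1'}}
    squashed = squash_unicode(msg)
    # every unicode key and value is now a utf8 bytestring (str in Python 2)
    assert squashed == {'msg_type': 'execute_request',
                        'content': {'code': '1+1'}}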
71 70 #-----------------------------------------------------------------------------
72 71 # globals and defaults
73 72 #-----------------------------------------------------------------------------
74 73 key = 'on_unknown' if jsonapi.jsonmod.__name__ == 'jsonlib' else 'default'
75 74 json_packer = lambda obj: jsonapi.dumps(obj, **{key:date_default})
76 75 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
77 76
78 77 pickle_packer = lambda o: pickle.dumps(o,-1)
79 78 pickle_unpacker = pickle.loads
80 79
81 80 default_packer = json_packer
82 81 default_unpacker = json_unpacker
83 82
84 83
85 84 DELIM=b"<IDS|MSG>"
86 85
87 86 #-----------------------------------------------------------------------------
88 87 # Classes
89 88 #-----------------------------------------------------------------------------
90 89
91 90 class SessionFactory(LoggingConfigurable):
92 91 """The Base class for configurables that have a Session, Context, logger,
93 92 and IOLoop.
94 93 """
95 94
96 95 logname = Unicode('')
97 96 def _logname_changed(self, name, old, new):
98 97 self.log = logging.getLogger(new)
99 98
100 99 # not configurable:
101 100 context = Instance('zmq.Context')
102 101 def _context_default(self):
103 102 return zmq.Context.instance()
104 103
105 104 session = Instance('IPython.zmq.session.Session')
106 105
107 106 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
108 107 def _loop_default(self):
109 108 return IOLoop.instance()
110 109
111 110 def __init__(self, **kwargs):
112 111 super(SessionFactory, self).__init__(**kwargs)
113 112
114 113 if self.session is None:
115 114 # construct the session
116 115 self.session = Session(**kwargs)
117 116
118 117
119 118 class Message(object):
120 119 """A simple message object that maps dict keys to attributes.
121 120
122 121 A Message can be created from a dict and a dict from a Message instance
123 122 simply by calling dict(msg_obj)."""
124 123
125 124 def __init__(self, msg_dict):
126 125 dct = self.__dict__
127 126 for k, v in dict(msg_dict).iteritems():
128 127 if isinstance(v, dict):
129 128 v = Message(v)
130 129 dct[k] = v
131 130
132 131 # Having this iterator lets dict(msg_obj) work out of the box.
133 132 def __iter__(self):
134 133 return iter(self.__dict__.iteritems())
135 134
136 135 def __repr__(self):
137 136 return repr(self.__dict__)
138 137
139 138 def __str__(self):
140 139 return pprint.pformat(self.__dict__)
141 140
142 141 def __contains__(self, k):
143 142 return k in self.__dict__
144 143
145 144 def __getitem__(self, k):
146 145 return self.__dict__[k]
147 146
148 147
149 148 def msg_header(msg_id, msg_type, username, session):
150 149 date = datetime.now()
151 150 return locals()
152 151
153 152 def extract_header(msg_or_header):
154 153 """Given a message or header, return the header."""
155 154 if not msg_or_header:
156 155 return {}
157 156 try:
158 157 # See if msg_or_header is the entire message.
159 158 h = msg_or_header['header']
160 159 except KeyError:
161 160 try:
162 161 # See if msg_or_header is just the header
163 162 h = msg_or_header['msg_id']
164 163 except KeyError:
165 164 raise
166 165 else:
167 166 h = msg_or_header
168 167 if not isinstance(h, dict):
169 168 h = dict(h)
170 169 return h
171 170
172 171 class Session(Configurable):
173 172 """Object for handling serialization and sending of messages.
174 173
175 174 The Session object handles building messages and sending them
176 175 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
177 176 other over the network via Session objects, and only need to work with the
178 177 dict-based IPython message spec. The Session will handle
179 178 serialization/deserialization, security, and metadata.
180 179
181 180 Sessions support configurable serialization via packer/unpacker traits,
182 181 and signing with HMAC digests via the key/keyfile traits.
183 182
184 183 Parameters
185 184 ----------
186 185
187 186 debug : bool
188 187 whether to trigger extra debugging statements
189 188 packer/unpacker : str : 'json', 'pickle' or import_string
190 189 importstrings for methods to serialize message parts. If just
191 190 'json' or 'pickle', predefined JSON and pickle packers will be used.
192 191 Otherwise, the entire importstring must be used.
193 192
194 193 The functions must accept at least valid JSON input, and output *bytes*.
195 194
196 195 For example, to use msgpack:
197 196 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
198 197 pack/unpack : callables
199 198 You can also set the pack/unpack callables for serialization directly.
200 199 session : bytes
201 200 the ID of this Session object. The default is to generate a new UUID.
202 201 username : unicode
203 202 username added to message headers. The default is to ask the OS.
204 203 key : bytes
205 204 The key used to initialize an HMAC signature. If unset, messages
206 205 will not be signed or checked.
207 206 keyfile : filepath
208 207 The file containing a key. If this is set, `key` will be initialized
209 208 to the contents of the file.
210 209
211 210 """
212 211
213 212 debug=Bool(False, config=True, help="""Debug output in the Session""")
214 213
215 214 packer = DottedObjectName('json',config=True,
216 215 help="""The name of the packer for serializing messages.
217 216 Should be one of 'json', 'pickle', or an import name
218 217 for a custom callable serializer.""")
219 218 def _packer_changed(self, name, old, new):
220 219 if new.lower() == 'json':
221 220 self.pack = json_packer
222 221 self.unpack = json_unpacker
223 222 elif new.lower() == 'pickle':
224 223 self.pack = pickle_packer
225 224 self.unpack = pickle_unpacker
226 225 else:
227 226 self.pack = import_item(str(new))
228 227
229 228 unpacker = DottedObjectName('json', config=True,
230 229 help="""The name of the unpacker for deserializing messages.
231 230 Only used with custom functions for `packer`.""")
232 231 def _unpacker_changed(self, name, old, new):
233 232 if new.lower() == 'json':
234 233 self.pack = json_packer
235 234 self.unpack = json_unpacker
236 235 elif new.lower() == 'pickle':
237 236 self.pack = pickle_packer
238 237 self.unpack = pickle_unpacker
239 238 else:
240 239 self.unpack = import_item(str(new))
241 240
242 241 session = CBytes(b'', config=True,
243 242 help="""The UUID identifying this session.""")
244 243 def _session_default(self):
245 244 return bytes(uuid.uuid4())
246 245
247 246 username = Unicode(os.environ.get('USER','username'), config=True,
248 247 help="""Username for the Session. Default is your system username.""")
249 248
250 249 # message signature related traits:
251 250 key = CBytes(b'', config=True,
252 251 help="""execution key, for extra authentication.""")
253 252 def _key_changed(self, name, old, new):
254 253 if new:
255 254 self.auth = hmac.HMAC(new)
256 255 else:
257 256 self.auth = None
258 257 auth = Instance(hmac.HMAC)
259 258 digest_history = Set()
260 259
261 260 keyfile = Unicode('', config=True,
262 261 help="""path to file containing execution key.""")
263 262 def _keyfile_changed(self, name, old, new):
264 263 with open(new, 'rb') as f:
265 264 self.key = f.read().strip()
266 265
267 266 pack = Any(default_packer) # the actual packer function
268 267 def _pack_changed(self, name, old, new):
269 268 if not callable(new):
270 269 raise TypeError("packer must be callable, not %s"%type(new))
271 270
272 271 unpack = Any(default_unpacker) # the actual packer function
273 272 def _unpack_changed(self, name, old, new):
274 273 # unpacker is not checked - it is assumed to be
275 274 if not callable(new):
276 275 raise TypeError("unpacker must be callable, not %s"%type(new))
277 276
278 277 def __init__(self, **kwargs):
279 278 """create a Session object
280 279
281 280 Parameters
282 281 ----------
283 282
284 283 debug : bool
285 284 whether to trigger extra debugging statements
286 285 packer/unpacker : str : 'json', 'pickle' or import_string
287 286 importstrings for methods to serialize message parts. If just
288 287 'json' or 'pickle', predefined JSON and pickle packers will be used.
289 288 Otherwise, the entire importstring must be used.
290 289
291 290 The functions must accept at least valid JSON input, and output
292 291 *bytes*.
293 292
294 293 For example, to use msgpack:
295 294 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
296 295 pack/unpack : callables
297 296 You can also set the pack/unpack callables for serialization
298 297 directly.
299 298 session : bytes
300 299 the ID of this Session object. The default is to generate a new
301 300 UUID.
302 301 username : unicode
303 302 username added to message headers. The default is to ask the OS.
304 303 key : bytes
305 304 The key used to initialize an HMAC signature. If unset, messages
306 305 will not be signed or checked.
307 306 keyfile : filepath
308 307 The file containing a key. If this is set, `key` will be
309 308 initialized to the contents of the file.
310 309 """
311 310 super(Session, self).__init__(**kwargs)
312 311 self._check_packers()
313 312 self.none = self.pack({})
314 313
315 314 @property
316 315 def msg_id(self):
317 316 """always return a new uuid"""
318 317 return str(uuid.uuid4())
319 318
320 319 def _check_packers(self):
321 320 """check packers for binary data and datetime support."""
322 321 pack = self.pack
323 322 unpack = self.unpack
324 323
325 324 # check simple serialization
326 325 msg = dict(a=[1,'hi'])
327 326 try:
328 327 packed = pack(msg)
329 328 except Exception:
330 329 raise ValueError("packer could not serialize a simple message")
331 330
332 331 # ensure packed message is bytes
333 332 if not isinstance(packed, bytes):
334 333 raise ValueError("message packed to %r, but bytes are required"%type(packed))
335 334
336 335 # check that unpack is pack's inverse
337 336 try:
338 337 unpacked = unpack(packed)
339 338 except Exception:
340 339 raise ValueError("unpacker could not handle the packer's output")
341 340
342 341 # check datetime support
343 342 msg = dict(t=datetime.now())
344 343 try:
345 344 unpacked = unpack(pack(msg))
346 345 except Exception:
347 346 self.pack = lambda o: pack(squash_dates(o))
348 347 self.unpack = lambda s: extract_dates(unpack(s))
349 348
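A sketch of how the packer/unpacker traits above are selected (Session import path assumed; the msgpack line mirrors the docstring example and only works if msgpack is installed):

    from IPython.zmq.session import Session

    s_json   = Session()                  # default json_packer/json_unpacker
    s_pickle = Session(packer='pickle')   # _packer_changed installs pickle_*
    # custom serializers are given as import strings, e.g.:
    # s_msgpack = Session(packer='msgpack.packb', unpacker='msgpack.unpackb')

    payload = {'a': [1, 'hi']}
    assert s_pickle.unpack(s_pickle.pack(payload)) == payload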
350 349 def msg_header(self, msg_type):
351 350 return msg_header(self.msg_id, msg_type, self.username, self.session)
352 351
353 352 def msg(self, msg_type, content=None, parent=None, subheader=None):
354 353 """Return the nested message dict.
355 354
356 355 This format is different from what is sent over the wire. The
357 356 self.serialize method converts this nested message dict to the wire
358 357 format, which uses a message list.
359 358 """
360 359 msg = {}
361 360 msg['header'] = self.msg_header(msg_type)
362 361 msg['msg_id'] = msg['header']['msg_id']
363 362 msg['parent_header'] = {} if parent is None else extract_header(parent)
364 363 msg['msg_type'] = msg_type
365 364 msg['content'] = {} if content is None else content
366 365 sub = {} if subheader is None else subheader
367 366 msg['header'].update(sub)
368 367 return msg
369 368
370 369 def sign(self, msg_list):
371 370 """Sign a message with HMAC digest. If no auth, return b''.
372 371
373 372 Parameters
374 373 ----------
375 374 msg_list : list
376 375 The [p_header,p_parent,p_content] part of the message list.
377 376 """
378 377 if self.auth is None:
379 378 return b''
380 379 h = self.auth.copy()
381 380 for m in msg_list:
382 381 h.update(m)
383 382 return h.hexdigest()
384 383
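A sketch of sign() in isolation (frame contents and the key are purely illustrative): with a key set, the result is the hex HMAC over the packed header/parent/content frames; without one it is b'':

    from IPython.zmq.session import Session

    unsigned = Session()
    signed = Session(key=b'supersecret')

    frames = [b'{"msg_id": "abc"}', b'{}', b'{}']   # already-packed parts
    assert unsigned.sign(frames) == b''             # no key -> empty signature
    sig = signed.sign(frames)                       # hex digest string
    assert signed.sign(frames) == sig               # deterministic for same key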
385 384 def serialize(self, msg, ident=None):
386 385 """Serialize the message components to bytes.
387 386
388 387 Parameters
389 388 ----------
390 389 msg : dict or Message
391 390 The nested message dict as returned by the self.msg method.
392 391
393 392 Returns
394 393 -------
395 394 msg_list : list
396 395 The list of bytes objects to be sent with the format:
397 396 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
398 397 buffer1,buffer2,...]. In this list, the p_* entities are
399 398 the packed or serialized versions, so if JSON is used, these
400 399 are utf8 encoded JSON strings.
401 400 """
402 401 content = msg.get('content', {})
403 402 if content is None:
404 403 content = self.none
405 404 elif isinstance(content, dict):
406 405 content = self.pack(content)
407 406 elif isinstance(content, bytes):
408 407 # content is already packed, as in a relayed message
409 408 pass
410 409 elif isinstance(content, unicode):
411 410 # should be bytes, but JSON often spits out unicode
412 411 content = content.encode('utf8')
413 412 else:
414 413 raise TypeError("Content incorrect type: %s"%type(content))
415 414
416 415 real_message = [self.pack(msg['header']),
417 416 self.pack(msg['parent_header']),
418 417 content
419 418 ]
420 419
421 420 to_send = []
422 421
423 422 if isinstance(ident, list):
424 423 # accept list of idents
425 424 to_send.extend(ident)
426 425 elif ident is not None:
427 426 to_send.append(ident)
428 427 to_send.append(DELIM)
429 428
430 429 signature = self.sign(real_message)
431 430 to_send.append(signature)
432 431
433 432 to_send.extend(real_message)
434 433
435 434 return to_send
436 435
437 436 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
438 437 buffers=None, subheader=None, track=False):
439 438 """Build and send a message via stream or socket.
440 439
441 440 The message format used by this function internally is as follows:
442 441
443 442 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
444 443 buffer1,buffer2,...]
445 444
446 445 The self.serialize method converts the nested message dict into this
447 446 format.
448 447
449 448 Parameters
450 449 ----------
451 450
452 451 stream : zmq.Socket or ZMQStream
453 452 the socket-like object used to send the data
454 453 msg_or_type : str or Message/dict
455 454 Normally, msg_or_type will be a msg_type unless a message is being
456 455 sent more than once.
457 456
458 457 content : dict or None
459 458 the content of the message (ignored if msg_or_type is a message)
460 459 parent : Message or dict or None
461 460 the parent or parent header describing the parent of this message
462 461 ident : bytes or list of bytes
463 462 the zmq.IDENTITY routing path
464 463 subheader : dict or None
465 464 extra header keys for this message's header
466 465 buffers : list or None
467 466 the already-serialized buffers to be appended to the message
468 467 track : bool
469 468 whether to track. Only for use with Sockets,
470 469 because ZMQStream objects cannot track messages.
471 470
472 471 Returns
473 472 -------
474 473 msg : message dict
475 474 the constructed message
476 475 (msg,tracker) : (message dict, MessageTracker)
477 476 if track=True, then a 2-tuple will be returned,
478 477 the first element being the constructed
479 478 message, and the second being the MessageTracker
480 479
481 480 """
482 481
483 482 if not isinstance(stream, (zmq.Socket, ZMQStream)):
484 483 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
485 484 elif track and isinstance(stream, ZMQStream):
486 485 raise TypeError("ZMQStream cannot track messages")
487 486
488 487 if isinstance(msg_or_type, (Message, dict)):
489 488 # we got a Message, not a msg_type
490 489 # don't build a new Message
491 490 msg = msg_or_type
492 491 else:
493 492 msg = self.msg(msg_or_type, content, parent, subheader)
494 493
495 494 buffers = [] if buffers is None else buffers
496 495 to_send = self.serialize(msg, ident)
497 496 flag = 0
498 497 if buffers:
499 498 flag = zmq.SNDMORE
500 499 _track = False
501 500 else:
502 501 _track=track
503 502 if track:
504 503 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
505 504 else:
506 505 tracker = stream.send_multipart(to_send, flag, copy=False)
507 506 for b in buffers[:-1]:
508 507 stream.send(b, flag, copy=False)
509 508 if buffers:
510 509 if track:
511 510 tracker = stream.send(buffers[-1], copy=False, track=track)
512 511 else:
513 512 tracker = stream.send(buffers[-1], copy=False)
514 513
515 514 # omsg = Message(msg)
516 515 if self.debug:
517 516 pprint.pprint(msg)
518 517 pprint.pprint(to_send)
519 518 pprint.pprint(buffers)
520 519
521 520 msg['tracker'] = tracker
522 521
523 522 return msg
524 523
525 524 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
526 525 """Send a raw message via ident path.
527 526
528 527 This method is used to send an already serialized message.
529 528
530 529 Parameters
531 530 ----------
532 531 stream : ZMQStream or Socket
533 532 The ZMQ stream or socket to use for sending the message.
534 533 msg_list : list
535 534 The serialized list of messages to send. This only includes the
536 535 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
537 536 the message.
538 537 ident : ident or list
539 538 A single ident or a list of idents to use in sending.
540 539 """
541 540 to_send = []
542 541 if isinstance(ident, bytes):
543 542 ident = [ident]
544 543 if ident is not None:
545 544 to_send.extend(ident)
546 545
547 546 to_send.append(DELIM)
548 547 to_send.append(self.sign(msg_list))
549 548 to_send.extend(msg_list)
550 549 stream.send_multipart(to_send, flags, copy=copy) # send the assembled frames (idents, DELIM, signature, msg_list), not the bare msg_list
551 550
552 551 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
553 552 """Receive and unpack a message.
554 553
555 554 Parameters
556 555 ----------
557 556 socket : ZMQStream or Socket
558 557 The socket or stream to use in receiving.
559 558
560 559 Returns
561 560 -------
562 561 [idents], msg
563 562 [idents] is a list of idents and msg is a nested message dict of
564 563 same format as self.msg returns.
565 564 """
566 565 if isinstance(socket, ZMQStream):
567 566 socket = socket.socket
568 567 try:
569 568 msg_list = socket.recv_multipart(mode)
570 569 except zmq.ZMQError as e:
571 570 if e.errno == zmq.EAGAIN:
572 571 # We can convert EAGAIN to None as we know in this case
573 572 # recv_multipart won't return None.
574 573 return None,None
575 574 else:
576 575 raise
577 576 # split multipart message into identity list and message dict
578 577 # invalid large messages can cause very expensive string comparisons
579 578 idents, msg_list = self.feed_identities(msg_list, copy)
580 579 try:
581 580 return idents, self.unpack_message(msg_list, content=content, copy=copy)
582 581 except Exception as e:
583 582 print (idents, msg_list)
584 583 # TODO: handle it
585 584 raise e
586 585
587 586 def feed_identities(self, msg_list, copy=True):
588 587 """Split the identities from the rest of the message.
589 588
590 589 Feed until DELIM is reached, then return the prefix as idents and
591 590 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
592 591 but that would be silly.
593 592
594 593 Parameters
595 594 ----------
596 595 msg_list : a list of Message or bytes objects
597 596 The message to be split.
598 597 copy : bool
599 598 flag determining whether the arguments are bytes or Messages
600 599
601 600 Returns
602 601 -------
603 602 (idents,msg_list) : two lists
604 603 idents will always be a list of bytes - the identity prefix
605 604 msg_list will be a list of bytes or Messages, unchanged from input
606 605 msg_list should be unpackable via self.unpack_message at this point.
607 606 """
608 607 if copy:
609 608 idx = msg_list.index(DELIM)
610 609 return msg_list[:idx], msg_list[idx+1:]
611 610 else:
612 611 failed = True
613 612 for idx,m in enumerate(msg_list):
614 613 if m.bytes == DELIM:
615 614 failed = False
616 615 break
617 616 if failed:
618 617 raise ValueError("DELIM not in msg_list")
619 618 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
620 619 return [m.bytes for m in idents], msg_list
621 620
622 621 def unpack_message(self, msg_list, content=True, copy=True):
623 622 """Return a message object from the format
624 623 sent by self.send.
625 624
626 625 Parameters
627 626 ----------
628 627
629 628 content : bool (True)
630 629 whether to unpack the content dict (True),
631 630 or leave it serialized (False)
632 631
633 632 copy : bool (True)
634 633 whether to return the bytes (True),
635 634 or the non-copying Message object in each place (False)
636 635
637 636 """
638 637 minlen = 4
639 638 message = {}
640 639 if not copy:
641 640 for i in range(minlen):
642 641 msg_list[i] = msg_list[i].bytes
643 642 if self.auth is not None:
644 643 signature = msg_list[0]
645 644 if signature in self.digest_history:
646 645 raise ValueError("Duplicate Signature: %r"%signature)
647 646 self.digest_history.add(signature)
648 647 check = self.sign(msg_list[1:4])
649 648 if not signature == check:
650 649 raise ValueError("Invalid Signature: %r"%signature)
651 650 if not len(msg_list) >= minlen:
652 651 raise TypeError("malformed message, must have at least %i elements"%minlen)
653 652 message['header'] = self.unpack(msg_list[1])
654 653 message['msg_type'] = message['header']['msg_type']
655 654 message['parent_header'] = self.unpack(msg_list[2])
656 655 if content:
657 656 message['content'] = self.unpack(msg_list[3])
658 657 else:
659 658 message['content'] = msg_list[3]
660 659
661 660 message['buffers'] = msg_list[4:]
662 661 return message
663 662
664 663 def test_msg2obj():
665 664 am = dict(x=1)
666 665 ao = Message(am)
667 666 assert ao.x == am['x']
668 667
669 668 am['y'] = dict(z=1)
670 669 ao = Message(am)
671 670 assert ao.y.z == am['y']['z']
672 671
673 672 k1, k2 = 'y', 'z'
674 673 assert ao[k1][k2] == am[k1][k2]
675 674
676 675 am2 = dict(ao)
677 676 assert am['x'] == am2['x']
678 677 assert am['y']['z'] == am2['y']['z']
679 678
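Finally, an end-to-end sketch of the wire format handled by serialize/feed_identities/unpack_message above (the msg_type, key, and ident values are illustrative only):

    from IPython.zmq.session import Session

    session = Session(key=b'secret-key')
    msg = session.msg('execute_request', content={'code': '1+1'})

    frames = session.serialize(msg, ident=b'client-1')
    # frames == [ident, DELIM, HMAC, p_header, p_parent, p_content]

    idents, wire = session.feed_identities(frames)
    assert idents == [b'client-1']

    unpacked = session.unpack_message(wire, content=True)
    assert unpacked['msg_type'] == 'execute_request'
    assert unpacked['content'] == {'code': '1+1'}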