##// END OF EJS Templates
Merge pull request #1630 from minrk/mergekernel...
Brian E. Granger -
r6905:6cb2026f merge
parent child Browse files
Show More
@@ -0,0 +1,179 b''
1 """serialization utilities for apply messages
2
3 Authors:
4
5 * Min RK
6 """
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
13
14 #-----------------------------------------------------------------------------
15 # Imports
16 #-----------------------------------------------------------------------------
17
18 # Standard library imports
19 import logging
20 import os
21 import re
22 import socket
23 import sys
24
25 try:
26 import cPickle
27 pickle = cPickle
28 except:
29 cPickle = None
30 import pickle
31
32
33 # IPython imports
34 from IPython.utils import py3compat
35 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
36 from IPython.utils.newserialized import serialize, unserialize
37
38 if py3compat.PY3:
39 buffer = memoryview
40
41 #-----------------------------------------------------------------------------
42 # Serialization Functions
43 #-----------------------------------------------------------------------------
44
def serialize_object(obj, threshold=64e-6):
    """Serialize an object into a list of sendable buffers.

    Parameters
    ----------
    obj : object
        The object to be serialized
    threshold : float
        The threshold for not double-pickling the content.

    Returns
    -------
    ('pmd', [bufs]) :
        where pmd is the pickled metadata wrapper,
        bufs is a list of data buffers
    """
    databuffers = []
    if isinstance(obj, (list, tuple)):
        clist = canSequence(obj)
        # list comprehension instead of bare map(): on Python 3 map() is lazy,
        # and the loop below would exhaust the iterator before pickle.dumps
        # ever saw the contents.
        slist = [serialize(c) for c in clist]
        for s in slist:
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
        return pickle.dumps(slist, -1), databuffers
    elif isinstance(obj, dict):
        sobj = {}
        # sorted(obj) rather than obj.iterkeys(): deterministic buffer order
        # (matching unserialize_object) and valid on both Python 2 and 3.
        for k in sorted(obj):
            s = serialize(can(obj[k]))
            if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
                databuffers.append(s.getData())
                s.data = None
            sobj[k] = s
        return pickle.dumps(sobj, -1), databuffers
    else:
        s = serialize(can(obj))
        if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
            databuffers.append(s.getData())
            s.data = None
        return pickle.dumps(s, -1), databuffers
87
88
def unserialize_object(bufs):
    """Reconstruct an object serialized by serialize_object from data buffers.

    Parameters
    ----------
    bufs : list
        [pickled metadata, data buffer, ...] as produced by serialize_object.

    Returns
    -------
    (obj, remaining_bufs) : the reconstructed object and any unconsumed buffers.
    """
    bufs = list(bufs)
    sobj = pickle.loads(bufs.pop(0))
    if isinstance(sobj, (list, tuple)):
        # reattach extracted data buffers before unserializing
        for s in sobj:
            if s.data is None:
                s.data = bufs.pop(0)
        # materialize the map for Python 3, where map() is lazy
        return uncanSequence([unserialize(s) for s in sobj]), bufs
    elif isinstance(sobj, dict):
        newobj = {}
        # same sorted key order as serialize_object, so buffers line up with keys
        # (sorted(sobj) replaces py2-only sobj.iterkeys())
        for k in sorted(sobj):
            s = sobj[k]
            if s.data is None:
                s.data = bufs.pop(0)
            newobj[k] = uncan(unserialize(s))
        return newobj, bufs
    else:
        if sobj.data is None:
            sobj.data = bufs.pop(0)
        return uncan(unserialize(sobj)), bufs
110
def pack_apply_message(f, args, kwargs, threshold=64e-6):
    """pack up a function, args, and kwargs to be sent over the wire
    as a series of buffers. Any object whose data is larger than `threshold`
    will not have their data copied (currently only numpy arrays support zero-copy)"""
    # wire format: [pickled f, args metadata, kwargs metadata, data buffers...]
    pinned_f = pickle.dumps(can(f), -1)
    sargs, arg_bufs = serialize_object(args, threshold)
    skwargs, kwarg_bufs = serialize_object(kwargs, threshold)
    # all large data buffers trail the three metadata sections, args first
    return [pinned_f, sargs, skwargs] + arg_bufs + kwarg_bufs
125
def _restore_buffer(s, bufs, copy):
    """Reattach the next data buffer to one serialized element, in place.

    Helper for unpack_apply_message; consumes from the front of `bufs`
    only when the element's data was extracted at pack time.
    """
    if s.data is None:
        m = bufs.pop(0)
        if s.getTypeDescriptor() in ('buffer', 'ndarray'):
            # always use a buffer, until memoryviews get sorted out
            s.data = buffer(m)
        elif copy:
            s.data = m
        else:
            # zero-copy path: m is a zmq Message; keep its bytes view
            s.data = m.bytes


def unpack_apply_message(bufs, g=None, copy=True):
    """unpack f,args,kwargs from buffers packed by pack_apply_message()
    Returns: original f,args,kwargs"""
    bufs = list(bufs)  # allow us to pop
    assert len(bufs) >= 3, "not enough buffers!"
    if not copy:
        # the three metadata sections are zmq Messages; pickle needs raw bytes
        for i in range(3):
            bufs[i] = bufs[i].bytes
    cf = pickle.loads(bufs.pop(0))
    sargs = list(pickle.loads(bufs.pop(0)))
    skwargs = dict(pickle.loads(bufs.pop(0)))
    f = uncan(cf, g)
    for sa in sargs:
        _restore_buffer(sa, bufs, copy)
    args = uncanSequence(map(unserialize, sargs), g)
    kwargs = {}
    # sorted(skwargs) replaces py2-only skwargs.iterkeys(); must match the
    # sorted order used by serialize_object so buffers pair with keys
    for k in sorted(skwargs):
        sa = skwargs[k]
        _restore_buffer(sa, bufs, copy)
        kwargs[k] = uncan(unserialize(sa), g)

    return f, args, kwargs
179
@@ -1,508 +1,511 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 A base class for a configurable application.
3 A base class for a configurable application.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Min RK
8 * Min RK
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import logging
22 import logging
23 import os
23 import os
24 import re
24 import re
25 import sys
25 import sys
26 from copy import deepcopy
26 from copy import deepcopy
27 from collections import defaultdict
27 from collections import defaultdict
28
28
29 from IPython.external.decorator import decorator
29 from IPython.external.decorator import decorator
30
30
31 from IPython.config.configurable import SingletonConfigurable
31 from IPython.config.configurable import SingletonConfigurable
32 from IPython.config.loader import (
32 from IPython.config.loader import (
33 KVArgParseConfigLoader, PyFileConfigLoader, Config, ArgumentError, ConfigFileNotFound,
33 KVArgParseConfigLoader, PyFileConfigLoader, Config, ArgumentError, ConfigFileNotFound,
34 )
34 )
35
35
36 from IPython.utils.traitlets import (
36 from IPython.utils.traitlets import (
37 Unicode, List, Enum, Dict, Instance, TraitError
37 Unicode, List, Enum, Dict, Instance, TraitError
38 )
38 )
39 from IPython.utils.importstring import import_item
39 from IPython.utils.importstring import import_item
40 from IPython.utils.text import indent, wrap_paragraphs, dedent
40 from IPython.utils.text import indent, wrap_paragraphs, dedent
41
41
42 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
43 # function for re-wrapping a helpstring
43 # function for re-wrapping a helpstring
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Descriptions for the various sections
47 # Descriptions for the various sections
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50 # merge flags&aliases into options
50 # merge flags&aliases into options
51 option_description = """
51 option_description = """
52 Arguments that take values are actually convenience aliases to full
52 Arguments that take values are actually convenience aliases to full
53 Configurables, whose aliases are listed on the help line. For more information
53 Configurables, whose aliases are listed on the help line. For more information
54 on full configurables, see '--help-all'.
54 on full configurables, see '--help-all'.
55 """.strip() # trim newlines of front and back
55 """.strip() # trim newlines of front and back
56
56
57 keyvalue_description = """
57 keyvalue_description = """
58 Parameters are set from command-line arguments of the form:
58 Parameters are set from command-line arguments of the form:
59 `--Class.trait=value`.
59 `--Class.trait=value`.
60 This line is evaluated in Python, so simple expressions are allowed, e.g.::
60 This line is evaluated in Python, so simple expressions are allowed, e.g.::
61 `--C.a='range(3)'` For setting C.a=[0,1,2].
61 `--C.a='range(3)'` For setting C.a=[0,1,2].
62 """.strip() # trim newlines of front and back
62 """.strip() # trim newlines of front and back
63
63
64 subcommand_description = """
64 subcommand_description = """
65 Subcommands are launched as `{app} cmd [args]`. For information on using
65 Subcommands are launched as `{app} cmd [args]`. For information on using
66 subcommand 'cmd', do: `{app} cmd -h`.
66 subcommand 'cmd', do: `{app} cmd -h`.
67 """.strip().format(app=os.path.basename(sys.argv[0]))
67 """.strip().format(app=os.path.basename(sys.argv[0]))
68 # get running program name
68 # get running program name
69
69
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71 # Application class
71 # Application class
72 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
73
73
@decorator
def catch_config_error(method, app, *args, **kwargs):
    """Method decorator for catching invalid config (Trait/ArgumentErrors) during init.

    On a TraitError (generally caused by bad config), this will print the trait's
    message, and exit the app.

    For use on init methods, to prevent invoking excepthook on invalid input.
    """
    try:
        return method(app, *args, **kwargs)
    except (TraitError, ArgumentError) as err:
        # show the user the normal help output before reporting the bad config
        app.print_description()
        app.print_help()
        app.print_examples()
        app.log.fatal("Bad config encountered during initialization:")
        app.log.fatal(str(err))
        app.log.debug("Config at the time: %s", app.config)
        app.exit(1)
93
93
94
94
class ApplicationError(Exception):
    """Base exception for errors raised by an Application."""
    pass
97
97
98
98
99 class Application(SingletonConfigurable):
99 class Application(SingletonConfigurable):
100 """A singleton application with full configuration support."""
100 """A singleton application with full configuration support."""
101
101
102 # The name of the application, will usually match the name of the command
102 # The name of the application, will usually match the name of the command
103 # line application
103 # line application
104 name = Unicode(u'application')
104 name = Unicode(u'application')
105
105
106 # The description of the application that is printed at the beginning
106 # The description of the application that is printed at the beginning
107 # of the help.
107 # of the help.
108 description = Unicode(u'This is an application.')
108 description = Unicode(u'This is an application.')
109 # default section descriptions
109 # default section descriptions
110 option_description = Unicode(option_description)
110 option_description = Unicode(option_description)
111 keyvalue_description = Unicode(keyvalue_description)
111 keyvalue_description = Unicode(keyvalue_description)
112 subcommand_description = Unicode(subcommand_description)
112 subcommand_description = Unicode(subcommand_description)
113
113
114 # The usage and example string that goes at the end of the help string.
114 # The usage and example string that goes at the end of the help string.
115 examples = Unicode()
115 examples = Unicode()
116
116
117 # A sequence of Configurable subclasses whose config=True attributes will
117 # A sequence of Configurable subclasses whose config=True attributes will
118 # be exposed at the command line.
118 # be exposed at the command line.
119 classes = List([])
119 classes = List([])
120
120
121 # The version string of this application.
121 # The version string of this application.
122 version = Unicode(u'0.0')
122 version = Unicode(u'0.0')
123
123
124 # The log level for the application
124 # The log level for the application
125 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
125 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
126 default_value=logging.WARN,
126 default_value=logging.WARN,
127 config=True,
127 config=True,
128 help="Set the log level by value or name.")
128 help="Set the log level by value or name.")
129 def _log_level_changed(self, name, old, new):
129 def _log_level_changed(self, name, old, new):
130 """Adjust the log level when log_level is set."""
130 """Adjust the log level when log_level is set."""
131 if isinstance(new, basestring):
131 if isinstance(new, basestring):
132 new = getattr(logging, new)
132 new = getattr(logging, new)
133 self.log_level = new
133 self.log_level = new
134 self.log.setLevel(new)
134 self.log.setLevel(new)
135
136 log_format = Unicode("[%(name)s] %(message)s", config=True,
137 help="The Logging format template",
138 )
139 log = Instance(logging.Logger)
140 def _log_default(self):
141 """Start logging for this application.
142
143 The default is to log to stdout using a StreaHandler. The log level
144 starts at loggin.WARN, but this can be adjusted by setting the
145 ``log_level`` attribute.
146 """
147 log = logging.getLogger(self.__class__.__name__)
148 log.setLevel(self.log_level)
149 if sys.executable.endswith('pythonw.exe'):
150 # this should really go to a file, but file-logging is only
151 # hooked up in parallel applications
152 _log_handler = logging.StreamHandler(open(os.devnull, 'w'))
153 else:
154 _log_handler = logging.StreamHandler()
155 _log_formatter = logging.Formatter(self.log_format)
156 _log_handler.setFormatter(_log_formatter)
157 log.addHandler(_log_handler)
158 return log
135
159
136 # the alias map for configurables
160 # the alias map for configurables
137 aliases = Dict({'log-level' : 'Application.log_level'})
161 aliases = Dict({'log-level' : 'Application.log_level'})
138
162
139 # flags for loading Configurables or store_const style flags
163 # flags for loading Configurables or store_const style flags
140 # flags are loaded from this dict by '--key' flags
164 # flags are loaded from this dict by '--key' flags
141 # this must be a dict of two-tuples, the first element being the Config/dict
165 # this must be a dict of two-tuples, the first element being the Config/dict
142 # and the second being the help string for the flag
166 # and the second being the help string for the flag
143 flags = Dict()
167 flags = Dict()
144 def _flags_changed(self, name, old, new):
168 def _flags_changed(self, name, old, new):
145 """ensure flags dict is valid"""
169 """ensure flags dict is valid"""
146 for key,value in new.iteritems():
170 for key,value in new.iteritems():
147 assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
171 assert len(value) == 2, "Bad flag: %r:%s"%(key,value)
148 assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
172 assert isinstance(value[0], (dict, Config)), "Bad flag: %r:%s"%(key,value)
149 assert isinstance(value[1], basestring), "Bad flag: %r:%s"%(key,value)
173 assert isinstance(value[1], basestring), "Bad flag: %r:%s"%(key,value)
150
174
151
175
152 # subcommands for launching other applications
176 # subcommands for launching other applications
153 # if this is not empty, this will be a parent Application
177 # if this is not empty, this will be a parent Application
154 # this must be a dict of two-tuples,
178 # this must be a dict of two-tuples,
155 # the first element being the application class/import string
179 # the first element being the application class/import string
156 # and the second being the help string for the subcommand
180 # and the second being the help string for the subcommand
157 subcommands = Dict()
181 subcommands = Dict()
158 # parse_command_line will initialize a subapp, if requested
182 # parse_command_line will initialize a subapp, if requested
159 subapp = Instance('IPython.config.application.Application', allow_none=True)
183 subapp = Instance('IPython.config.application.Application', allow_none=True)
160
184
161 # extra command-line arguments that don't set config values
185 # extra command-line arguments that don't set config values
162 extra_args = List(Unicode)
186 extra_args = List(Unicode)
163
187
164
188
165 def __init__(self, **kwargs):
189 def __init__(self, **kwargs):
166 SingletonConfigurable.__init__(self, **kwargs)
190 SingletonConfigurable.__init__(self, **kwargs)
167 # Ensure my class is in self.classes, so my attributes appear in command line
191 # Ensure my class is in self.classes, so my attributes appear in command line
168 # options and config files.
192 # options and config files.
169 if self.__class__ not in self.classes:
193 if self.__class__ not in self.classes:
170 self.classes.insert(0, self.__class__)
194 self.classes.insert(0, self.__class__)
171
195
172 self.init_logging()
173
174 def _config_changed(self, name, old, new):
196 def _config_changed(self, name, old, new):
175 SingletonConfigurable._config_changed(self, name, old, new)
197 SingletonConfigurable._config_changed(self, name, old, new)
176 self.log.debug('Config changed:')
198 self.log.debug('Config changed:')
177 self.log.debug(repr(new))
199 self.log.debug(repr(new))
178
200
179 def init_logging(self):
180 """Start logging for this application.
181
182 The default is to log to stdout using a StreaHandler. The log level
183 starts at loggin.WARN, but this can be adjusted by setting the
184 ``log_level`` attribute.
185 """
186 self.log = logging.getLogger(self.__class__.__name__)
187 self.log.setLevel(self.log_level)
188 if sys.executable.endswith('pythonw.exe'):
189 # this should really go to a file, but file-logging is only
190 # hooked up in parallel applications
191 self._log_handler = logging.StreamHandler(open(os.devnull, 'w'))
192 else:
193 self._log_handler = logging.StreamHandler()
194 self._log_formatter = logging.Formatter("[%(name)s] %(message)s")
195 self._log_handler.setFormatter(self._log_formatter)
196 self.log.addHandler(self._log_handler)
197
198 @catch_config_error
201 @catch_config_error
199 def initialize(self, argv=None):
202 def initialize(self, argv=None):
200 """Do the basic steps to configure me.
203 """Do the basic steps to configure me.
201
204
202 Override in subclasses.
205 Override in subclasses.
203 """
206 """
204 self.parse_command_line(argv)
207 self.parse_command_line(argv)
205
208
206
209
207 def start(self):
210 def start(self):
208 """Start the app mainloop.
211 """Start the app mainloop.
209
212
210 Override in subclasses.
213 Override in subclasses.
211 """
214 """
212 if self.subapp is not None:
215 if self.subapp is not None:
213 return self.subapp.start()
216 return self.subapp.start()
214
217
215 def print_alias_help(self):
218 def print_alias_help(self):
216 """Print the alias part of the help."""
219 """Print the alias part of the help."""
217 if not self.aliases:
220 if not self.aliases:
218 return
221 return
219
222
220 lines = []
223 lines = []
221 classdict = {}
224 classdict = {}
222 for cls in self.classes:
225 for cls in self.classes:
223 # include all parents (up to, but excluding Configurable) in available names
226 # include all parents (up to, but excluding Configurable) in available names
224 for c in cls.mro()[:-3]:
227 for c in cls.mro()[:-3]:
225 classdict[c.__name__] = c
228 classdict[c.__name__] = c
226
229
227 for alias, longname in self.aliases.iteritems():
230 for alias, longname in self.aliases.iteritems():
228 classname, traitname = longname.split('.',1)
231 classname, traitname = longname.split('.',1)
229 cls = classdict[classname]
232 cls = classdict[classname]
230
233
231 trait = cls.class_traits(config=True)[traitname]
234 trait = cls.class_traits(config=True)[traitname]
232 help = cls.class_get_trait_help(trait).splitlines()
235 help = cls.class_get_trait_help(trait).splitlines()
233 # reformat first line
236 # reformat first line
234 help[0] = help[0].replace(longname, alias) + ' (%s)'%longname
237 help[0] = help[0].replace(longname, alias) + ' (%s)'%longname
235 if len(alias) == 1:
238 if len(alias) == 1:
236 help[0] = help[0].replace('--%s='%alias, '-%s '%alias)
239 help[0] = help[0].replace('--%s='%alias, '-%s '%alias)
237 lines.extend(help)
240 lines.extend(help)
238 # lines.append('')
241 # lines.append('')
239 print os.linesep.join(lines)
242 print os.linesep.join(lines)
240
243
241 def print_flag_help(self):
244 def print_flag_help(self):
242 """Print the flag part of the help."""
245 """Print the flag part of the help."""
243 if not self.flags:
246 if not self.flags:
244 return
247 return
245
248
246 lines = []
249 lines = []
247 for m, (cfg,help) in self.flags.iteritems():
250 for m, (cfg,help) in self.flags.iteritems():
248 prefix = '--' if len(m) > 1 else '-'
251 prefix = '--' if len(m) > 1 else '-'
249 lines.append(prefix+m)
252 lines.append(prefix+m)
250 lines.append(indent(dedent(help.strip())))
253 lines.append(indent(dedent(help.strip())))
251 # lines.append('')
254 # lines.append('')
252 print os.linesep.join(lines)
255 print os.linesep.join(lines)
253
256
254 def print_options(self):
257 def print_options(self):
255 if not self.flags and not self.aliases:
258 if not self.flags and not self.aliases:
256 return
259 return
257 lines = ['Options']
260 lines = ['Options']
258 lines.append('-'*len(lines[0]))
261 lines.append('-'*len(lines[0]))
259 lines.append('')
262 lines.append('')
260 for p in wrap_paragraphs(self.option_description):
263 for p in wrap_paragraphs(self.option_description):
261 lines.append(p)
264 lines.append(p)
262 lines.append('')
265 lines.append('')
263 print os.linesep.join(lines)
266 print os.linesep.join(lines)
264 self.print_flag_help()
267 self.print_flag_help()
265 self.print_alias_help()
268 self.print_alias_help()
266 print
269 print
267
270
268 def print_subcommands(self):
271 def print_subcommands(self):
269 """Print the subcommand part of the help."""
272 """Print the subcommand part of the help."""
270 if not self.subcommands:
273 if not self.subcommands:
271 return
274 return
272
275
273 lines = ["Subcommands"]
276 lines = ["Subcommands"]
274 lines.append('-'*len(lines[0]))
277 lines.append('-'*len(lines[0]))
275 lines.append('')
278 lines.append('')
276 for p in wrap_paragraphs(self.subcommand_description):
279 for p in wrap_paragraphs(self.subcommand_description):
277 lines.append(p)
280 lines.append(p)
278 lines.append('')
281 lines.append('')
279 for subc, (cls, help) in self.subcommands.iteritems():
282 for subc, (cls, help) in self.subcommands.iteritems():
280 lines.append(subc)
283 lines.append(subc)
281 if help:
284 if help:
282 lines.append(indent(dedent(help.strip())))
285 lines.append(indent(dedent(help.strip())))
283 lines.append('')
286 lines.append('')
284 print os.linesep.join(lines)
287 print os.linesep.join(lines)
285
288
286 def print_help(self, classes=False):
289 def print_help(self, classes=False):
287 """Print the help for each Configurable class in self.classes.
290 """Print the help for each Configurable class in self.classes.
288
291
289 If classes=False (the default), only flags and aliases are printed.
292 If classes=False (the default), only flags and aliases are printed.
290 """
293 """
291 self.print_subcommands()
294 self.print_subcommands()
292 self.print_options()
295 self.print_options()
293
296
294 if classes:
297 if classes:
295 if self.classes:
298 if self.classes:
296 print "Class parameters"
299 print "Class parameters"
297 print "----------------"
300 print "----------------"
298 print
301 print
299 for p in wrap_paragraphs(self.keyvalue_description):
302 for p in wrap_paragraphs(self.keyvalue_description):
300 print p
303 print p
301 print
304 print
302
305
303 for cls in self.classes:
306 for cls in self.classes:
304 cls.class_print_help()
307 cls.class_print_help()
305 print
308 print
306 else:
309 else:
307 print "To see all available configurables, use `--help-all`"
310 print "To see all available configurables, use `--help-all`"
308 print
311 print
309
312
310 def print_description(self):
313 def print_description(self):
311 """Print the application description."""
314 """Print the application description."""
312 for p in wrap_paragraphs(self.description):
315 for p in wrap_paragraphs(self.description):
313 print p
316 print p
314 print
317 print
315
318
316 def print_examples(self):
319 def print_examples(self):
317 """Print usage and examples.
320 """Print usage and examples.
318
321
319 This usage string goes at the end of the command line help string
322 This usage string goes at the end of the command line help string
320 and should contain examples of the application's usage.
323 and should contain examples of the application's usage.
321 """
324 """
322 if self.examples:
325 if self.examples:
323 print "Examples"
326 print "Examples"
324 print "--------"
327 print "--------"
325 print
328 print
326 print indent(dedent(self.examples.strip()))
329 print indent(dedent(self.examples.strip()))
327 print
330 print
328
331
329 def print_version(self):
332 def print_version(self):
330 """Print the version string."""
333 """Print the version string."""
331 print self.version
334 print self.version
332
335
333 def update_config(self, config):
336 def update_config(self, config):
334 """Fire the traits events when the config is updated."""
337 """Fire the traits events when the config is updated."""
335 # Save a copy of the current config.
338 # Save a copy of the current config.
336 newconfig = deepcopy(self.config)
339 newconfig = deepcopy(self.config)
337 # Merge the new config into the current one.
340 # Merge the new config into the current one.
338 newconfig._merge(config)
341 newconfig._merge(config)
339 # Save the combined config as self.config, which triggers the traits
342 # Save the combined config as self.config, which triggers the traits
340 # events.
343 # events.
341 self.config = newconfig
344 self.config = newconfig
342
345
343 @catch_config_error
346 @catch_config_error
344 def initialize_subcommand(self, subc, argv=None):
347 def initialize_subcommand(self, subc, argv=None):
345 """Initialize a subcommand with argv."""
348 """Initialize a subcommand with argv."""
346 subapp,help = self.subcommands.get(subc)
349 subapp,help = self.subcommands.get(subc)
347
350
348 if isinstance(subapp, basestring):
351 if isinstance(subapp, basestring):
349 subapp = import_item(subapp)
352 subapp = import_item(subapp)
350
353
351 # clear existing instances
354 # clear existing instances
352 self.__class__.clear_instance()
355 self.__class__.clear_instance()
353 # instantiate
356 # instantiate
354 self.subapp = subapp.instance()
357 self.subapp = subapp.instance()
355 # and initialize subapp
358 # and initialize subapp
356 self.subapp.initialize(argv)
359 self.subapp.initialize(argv)
357
360
358 def flatten_flags(self):
361 def flatten_flags(self):
359 """flatten flags and aliases, so cl-args override as expected.
362 """flatten flags and aliases, so cl-args override as expected.
360
363
361 This prevents issues such as an alias pointing to InteractiveShell,
364 This prevents issues such as an alias pointing to InteractiveShell,
362 but a config file setting the same trait in TerminalInteraciveShell
365 but a config file setting the same trait in TerminalInteraciveShell
363 getting inappropriate priority over the command-line arg.
366 getting inappropriate priority over the command-line arg.
364
367
365 Only aliases with exactly one descendent in the class list
368 Only aliases with exactly one descendent in the class list
366 will be promoted.
369 will be promoted.
367
370
368 """
371 """
369 # build a tree of classes in our list that inherit from a particular
372 # build a tree of classes in our list that inherit from a particular
370 # it will be a dict by parent classname of classes in our list
373 # it will be a dict by parent classname of classes in our list
371 # that are descendents
374 # that are descendents
372 mro_tree = defaultdict(list)
375 mro_tree = defaultdict(list)
373 for cls in self.classes:
376 for cls in self.classes:
374 clsname = cls.__name__
377 clsname = cls.__name__
375 for parent in cls.mro()[1:-3]:
378 for parent in cls.mro()[1:-3]:
376 # exclude cls itself and Configurable,HasTraits,object
379 # exclude cls itself and Configurable,HasTraits,object
377 mro_tree[parent.__name__].append(clsname)
380 mro_tree[parent.__name__].append(clsname)
378 # flatten aliases, which have the form:
381 # flatten aliases, which have the form:
379 # { 'alias' : 'Class.trait' }
382 # { 'alias' : 'Class.trait' }
380 aliases = {}
383 aliases = {}
381 for alias, cls_trait in self.aliases.iteritems():
384 for alias, cls_trait in self.aliases.iteritems():
382 cls,trait = cls_trait.split('.',1)
385 cls,trait = cls_trait.split('.',1)
383 children = mro_tree[cls]
386 children = mro_tree[cls]
384 if len(children) == 1:
387 if len(children) == 1:
385 # exactly one descendent, promote alias
388 # exactly one descendent, promote alias
386 cls = children[0]
389 cls = children[0]
387 aliases[alias] = '.'.join([cls,trait])
390 aliases[alias] = '.'.join([cls,trait])
388
391
389 # flatten flags, which are of the form:
392 # flatten flags, which are of the form:
390 # { 'key' : ({'Cls' : {'trait' : value}}, 'help')}
393 # { 'key' : ({'Cls' : {'trait' : value}}, 'help')}
391 flags = {}
394 flags = {}
392 for key, (flagdict, help) in self.flags.iteritems():
395 for key, (flagdict, help) in self.flags.iteritems():
393 newflag = {}
396 newflag = {}
394 for cls, subdict in flagdict.iteritems():
397 for cls, subdict in flagdict.iteritems():
395 children = mro_tree[cls]
398 children = mro_tree[cls]
396 # exactly one descendent, promote flag section
399 # exactly one descendent, promote flag section
397 if len(children) == 1:
400 if len(children) == 1:
398 cls = children[0]
401 cls = children[0]
399 newflag[cls] = subdict
402 newflag[cls] = subdict
400 flags[key] = (newflag, help)
403 flags[key] = (newflag, help)
401 return flags, aliases
404 return flags, aliases
402
405
403 @catch_config_error
406 @catch_config_error
404 def parse_command_line(self, argv=None):
407 def parse_command_line(self, argv=None):
405 """Parse the command line arguments."""
408 """Parse the command line arguments."""
406 argv = sys.argv[1:] if argv is None else argv
409 argv = sys.argv[1:] if argv is None else argv
407
410
408 if argv and argv[0] == 'help':
411 if argv and argv[0] == 'help':
409 # turn `ipython help notebook` into `ipython notebook -h`
412 # turn `ipython help notebook` into `ipython notebook -h`
410 argv = argv[1:] + ['-h']
413 argv = argv[1:] + ['-h']
411
414
412 if self.subcommands and len(argv) > 0:
415 if self.subcommands and len(argv) > 0:
413 # we have subcommands, and one may have been specified
416 # we have subcommands, and one may have been specified
414 subc, subargv = argv[0], argv[1:]
417 subc, subargv = argv[0], argv[1:]
415 if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
418 if re.match(r'^\w(\-?\w)*$', subc) and subc in self.subcommands:
416 # it's a subcommand, and *not* a flag or class parameter
419 # it's a subcommand, and *not* a flag or class parameter
417 return self.initialize_subcommand(subc, subargv)
420 return self.initialize_subcommand(subc, subargv)
418
421
419 if '-h' in argv or '--help' in argv or '--help-all' in argv:
422 if '-h' in argv or '--help' in argv or '--help-all' in argv:
420 self.print_description()
423 self.print_description()
421 self.print_help('--help-all' in argv)
424 self.print_help('--help-all' in argv)
422 self.print_examples()
425 self.print_examples()
423 self.exit(0)
426 self.exit(0)
424
427
425 if '--version' in argv or '-V' in argv:
428 if '--version' in argv or '-V' in argv:
426 self.print_version()
429 self.print_version()
427 self.exit(0)
430 self.exit(0)
428
431
429 # flatten flags&aliases, so cl-args get appropriate priority:
432 # flatten flags&aliases, so cl-args get appropriate priority:
430 flags,aliases = self.flatten_flags()
433 flags,aliases = self.flatten_flags()
431
434
432 loader = KVArgParseConfigLoader(argv=argv, aliases=aliases,
435 loader = KVArgParseConfigLoader(argv=argv, aliases=aliases,
433 flags=flags)
436 flags=flags)
434 config = loader.load_config()
437 config = loader.load_config()
435 self.update_config(config)
438 self.update_config(config)
436 # store unparsed args in extra_args
439 # store unparsed args in extra_args
437 self.extra_args = loader.extra_args
440 self.extra_args = loader.extra_args
438
441
439 @catch_config_error
442 @catch_config_error
440 def load_config_file(self, filename, path=None):
443 def load_config_file(self, filename, path=None):
441 """Load a .py based config file by filename and path."""
444 """Load a .py based config file by filename and path."""
442 loader = PyFileConfigLoader(filename, path=path)
445 loader = PyFileConfigLoader(filename, path=path)
443 try:
446 try:
444 config = loader.load_config()
447 config = loader.load_config()
445 except ConfigFileNotFound:
448 except ConfigFileNotFound:
446 # problem finding the file, raise
449 # problem finding the file, raise
447 raise
450 raise
448 except Exception:
451 except Exception:
449 # try to get the full filename, but it will be empty in the
452 # try to get the full filename, but it will be empty in the
450 # unlikely event that the error raised before filefind finished
453 # unlikely event that the error raised before filefind finished
451 filename = loader.full_filename or filename
454 filename = loader.full_filename or filename
452 # problem while running the file
455 # problem while running the file
453 self.log.error("Exception while loading config file %s",
456 self.log.error("Exception while loading config file %s",
454 filename, exc_info=True)
457 filename, exc_info=True)
455 else:
458 else:
456 self.log.debug("Loaded config file: %s", loader.full_filename)
459 self.log.debug("Loaded config file: %s", loader.full_filename)
457 self.update_config(config)
460 self.update_config(config)
458
461
459 def generate_config_file(self):
462 def generate_config_file(self):
460 """generate default config file from Configurables"""
463 """generate default config file from Configurables"""
461 lines = ["# Configuration file for %s."%self.name]
464 lines = ["# Configuration file for %s."%self.name]
462 lines.append('')
465 lines.append('')
463 lines.append('c = get_config()')
466 lines.append('c = get_config()')
464 lines.append('')
467 lines.append('')
465 for cls in self.classes:
468 for cls in self.classes:
466 lines.append(cls.class_config_section())
469 lines.append(cls.class_config_section())
467 return '\n'.join(lines)
470 return '\n'.join(lines)
468
471
469 def exit(self, exit_status=0):
472 def exit(self, exit_status=0):
470 self.log.debug("Exiting application: %s" % self.name)
473 self.log.debug("Exiting application: %s" % self.name)
471 sys.exit(exit_status)
474 sys.exit(exit_status)
472
475
473 #-----------------------------------------------------------------------------
476 #-----------------------------------------------------------------------------
474 # utility functions, for convenience
477 # utility functions, for convenience
475 #-----------------------------------------------------------------------------
478 #-----------------------------------------------------------------------------
476
479
477 def boolean_flag(name, configurable, set_help='', unset_help=''):
480 def boolean_flag(name, configurable, set_help='', unset_help=''):
478 """Helper for building basic --trait, --no-trait flags.
481 """Helper for building basic --trait, --no-trait flags.
479
482
480 Parameters
483 Parameters
481 ----------
484 ----------
482
485
483 name : str
486 name : str
484 The name of the flag.
487 The name of the flag.
485 configurable : str
488 configurable : str
486 The 'Class.trait' string of the trait to be set/unset with the flag
489 The 'Class.trait' string of the trait to be set/unset with the flag
487 set_help : unicode
490 set_help : unicode
488 help string for --name flag
491 help string for --name flag
489 unset_help : unicode
492 unset_help : unicode
490 help string for --no-name flag
493 help string for --no-name flag
491
494
492 Returns
495 Returns
493 -------
496 -------
494
497
495 cfg : dict
498 cfg : dict
496 A dict with two keys: 'name', and 'no-name', for setting and unsetting
499 A dict with two keys: 'name', and 'no-name', for setting and unsetting
497 the trait, respectively.
500 the trait, respectively.
498 """
501 """
499 # default helpstrings
502 # default helpstrings
500 set_help = set_help or "set %s=True"%configurable
503 set_help = set_help or "set %s=True"%configurable
501 unset_help = unset_help or "set %s=False"%configurable
504 unset_help = unset_help or "set %s=False"%configurable
502
505
503 cls,trait = configurable.split('.')
506 cls,trait = configurable.split('.')
504
507
505 setter = {cls : {trait : True}}
508 setter = {cls : {trait : True}}
506 unsetter = {cls : {trait : False}}
509 unsetter = {cls : {trait : False}}
507 return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
510 return {name : (setter, set_help), 'no-'+name : (unsetter, unset_help)}
508
511
@@ -1,976 +1,977 b''
1 """ History related magics and functionality """
1 """ History related magics and functionality """
2 #-----------------------------------------------------------------------------
2 #-----------------------------------------------------------------------------
3 # Copyright (C) 2010-2011 The IPython Development Team.
3 # Copyright (C) 2010-2011 The IPython Development Team.
4 #
4 #
5 # Distributed under the terms of the BSD License.
5 # Distributed under the terms of the BSD License.
6 #
6 #
7 # The full license is in the file COPYING.txt, distributed with this software.
7 # The full license is in the file COPYING.txt, distributed with this software.
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9
9
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 from __future__ import print_function
13 from __future__ import print_function
14
14
15 # Stdlib imports
15 # Stdlib imports
16 import atexit
16 import atexit
17 import datetime
17 import datetime
18 from io import open as io_open
18 from io import open as io_open
19 import os
19 import os
20 import re
20 import re
21 try:
21 try:
22 import sqlite3
22 import sqlite3
23 except ImportError:
23 except ImportError:
24 sqlite3 = None
24 sqlite3 = None
25 import threading
25 import threading
26
26
27 # Our own packages
27 # Our own packages
28 from IPython.core.error import StdinNotImplementedError
28 from IPython.core.error import StdinNotImplementedError
29 from IPython.config.configurable import Configurable
29 from IPython.config.configurable import Configurable
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.testing.skipdoctest import skip_doctest
31 from IPython.testing.skipdoctest import skip_doctest
32 from IPython.utils import io
32 from IPython.utils import io
33 from IPython.utils.path import locate_profile
33 from IPython.utils.path import locate_profile
34 from IPython.utils.traitlets import Bool, Dict, Instance, Integer, List, Unicode
34 from IPython.utils.traitlets import Bool, Dict, Instance, Integer, List, Unicode
35 from IPython.utils.warn import warn
35 from IPython.utils.warn import warn
36
36
37 #-----------------------------------------------------------------------------
37 #-----------------------------------------------------------------------------
38 # Classes and functions
38 # Classes and functions
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40
40
41 class DummyDB(object):
41 class DummyDB(object):
42 """Dummy DB that will act as a black hole for history.
42 """Dummy DB that will act as a black hole for history.
43
43
44 Only used in the absence of sqlite"""
44 Only used in the absence of sqlite"""
45 def execute(*args, **kwargs):
45 def execute(*args, **kwargs):
46 return []
46 return []
47
47
48 def commit(self, *args, **kwargs):
48 def commit(self, *args, **kwargs):
49 pass
49 pass
50
50
51 def __enter__(self, *args, **kwargs):
51 def __enter__(self, *args, **kwargs):
52 pass
52 pass
53
53
54 def __exit__(self, *args, **kwargs):
54 def __exit__(self, *args, **kwargs):
55 pass
55 pass
56
56
57 @decorator
57 @decorator
58 def needs_sqlite(f,*a,**kw):
58 def needs_sqlite(f,*a,**kw):
59 """return an empty list in the absence of sqlite"""
59 """return an empty list in the absence of sqlite"""
60 if sqlite3 is None:
60 if sqlite3 is None:
61 return []
61 return []
62 else:
62 else:
63 return f(*a,**kw)
63 return f(*a,**kw)
64
64
65 class HistoryAccessor(Configurable):
65 class HistoryAccessor(Configurable):
66 """Access the history database without adding to it.
66 """Access the history database without adding to it.
67
67
68 This is intended for use by standalone history tools. IPython shells use
68 This is intended for use by standalone history tools. IPython shells use
69 HistoryManager, below, which is a subclass of this."""
69 HistoryManager, below, which is a subclass of this."""
70
70
71 # String holding the path to the history file
71 # String holding the path to the history file
72 hist_file = Unicode(config=True,
72 hist_file = Unicode(config=True,
73 help="""Path to file to use for SQLite history database.
73 help="""Path to file to use for SQLite history database.
74
74
75 By default, IPython will put the history database in the IPython profile
75 By default, IPython will put the history database in the IPython profile
76 directory. If you would rather share one history among profiles,
76 directory. If you would rather share one history among profiles,
77 you ca set this value in each, so that they are consistent.
77 you ca set this value in each, so that they are consistent.
78
78
79 Due to an issue with fcntl, SQLite is known to misbehave on some NFS mounts.
79 Due to an issue with fcntl, SQLite is known to misbehave on some NFS mounts.
80 If you see IPython hanging, try setting this to something on a local disk,
80 If you see IPython hanging, try setting this to something on a local disk,
81 e.g::
81 e.g::
82
82
83 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
83 ipython --HistoryManager.hist_file=/tmp/ipython_hist.sqlite
84
84
85 """)
85 """)
86
86
87
87
88 # The SQLite database
88 # The SQLite database
89 if sqlite3:
89 if sqlite3:
90 db = Instance(sqlite3.Connection)
90 db = Instance(sqlite3.Connection)
91 else:
91 else:
92 db = Instance(DummyDB)
92 db = Instance(DummyDB)
93
93
94 def __init__(self, profile='default', hist_file=u'', config=None, **traits):
94 def __init__(self, profile='default', hist_file=u'', config=None, **traits):
95 """Create a new history accessor.
95 """Create a new history accessor.
96
96
97 Parameters
97 Parameters
98 ----------
98 ----------
99 profile : str
99 profile : str
100 The name of the profile from which to open history.
100 The name of the profile from which to open history.
101 hist_file : str
101 hist_file : str
102 Path to an SQLite history database stored by IPython. If specified,
102 Path to an SQLite history database stored by IPython. If specified,
103 hist_file overrides profile.
103 hist_file overrides profile.
104 config :
104 config :
105 Config object. hist_file can also be set through this.
105 Config object. hist_file can also be set through this.
106 """
106 """
107 # We need a pointer back to the shell for various tasks.
107 # We need a pointer back to the shell for various tasks.
108 super(HistoryAccessor, self).__init__(config=config, **traits)
108 super(HistoryAccessor, self).__init__(config=config, **traits)
109 # defer setting hist_file from kwarg until after init,
109 # defer setting hist_file from kwarg until after init,
110 # otherwise the default kwarg value would clobber any value
110 # otherwise the default kwarg value would clobber any value
111 # set by config
111 # set by config
112 if hist_file:
112 if hist_file:
113 self.hist_file = hist_file
113 self.hist_file = hist_file
114
114
115 if self.hist_file == u'':
115 if self.hist_file == u'':
116 # No one has set the hist_file, yet.
116 # No one has set the hist_file, yet.
117 self.hist_file = self._get_hist_file_name(profile)
117 self.hist_file = self._get_hist_file_name(profile)
118
118
119 if sqlite3 is None:
119 if sqlite3 is None:
120 warn("IPython History requires SQLite, your history will not be saved\n")
120 warn("IPython History requires SQLite, your history will not be saved\n")
121 self.db = DummyDB()
121 self.db = DummyDB()
122 return
122 return
123
123
124 try:
124 try:
125 self.init_db()
125 self.init_db()
126 except sqlite3.DatabaseError:
126 except sqlite3.DatabaseError:
127 if os.path.isfile(self.hist_file):
127 if os.path.isfile(self.hist_file):
128 # Try to move the file out of the way
128 # Try to move the file out of the way
129 base,ext = os.path.splitext(self.hist_file)
129 base,ext = os.path.splitext(self.hist_file)
130 newpath = base + '-corrupt' + ext
130 newpath = base + '-corrupt' + ext
131 os.rename(self.hist_file, newpath)
131 os.rename(self.hist_file, newpath)
132 print("ERROR! History file wasn't a valid SQLite database.",
132 print("ERROR! History file wasn't a valid SQLite database.",
133 "It was moved to %s" % newpath, "and a new file created.")
133 "It was moved to %s" % newpath, "and a new file created.")
134 self.init_db()
134 self.init_db()
135 else:
135 else:
136 # The hist_file is probably :memory: or something else.
136 # The hist_file is probably :memory: or something else.
137 raise
137 raise
138
138
139 def _get_hist_file_name(self, profile='default'):
139 def _get_hist_file_name(self, profile='default'):
140 """Find the history file for the given profile name.
140 """Find the history file for the given profile name.
141
141
142 This is overridden by the HistoryManager subclass, to use the shell's
142 This is overridden by the HistoryManager subclass, to use the shell's
143 active profile.
143 active profile.
144
144
145 Parameters
145 Parameters
146 ----------
146 ----------
147 profile : str
147 profile : str
148 The name of a profile which has a history file.
148 The name of a profile which has a history file.
149 """
149 """
150 return os.path.join(locate_profile(profile), 'history.sqlite')
150 return os.path.join(locate_profile(profile), 'history.sqlite')
151
151
152 def init_db(self):
152 def init_db(self):
153 """Connect to the database, and create tables if necessary."""
153 """Connect to the database, and create tables if necessary."""
154 # use detect_types so that timestamps return datetime objects
154 # use detect_types so that timestamps return datetime objects
155 self.db = sqlite3.connect(self.hist_file, detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
155 self.db = sqlite3.connect(self.hist_file, detect_types=sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES)
156 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
156 self.db.execute("""CREATE TABLE IF NOT EXISTS sessions (session integer
157 primary key autoincrement, start timestamp,
157 primary key autoincrement, start timestamp,
158 end timestamp, num_cmds integer, remark text)""")
158 end timestamp, num_cmds integer, remark text)""")
159 self.db.execute("""CREATE TABLE IF NOT EXISTS history
159 self.db.execute("""CREATE TABLE IF NOT EXISTS history
160 (session integer, line integer, source text, source_raw text,
160 (session integer, line integer, source text, source_raw text,
161 PRIMARY KEY (session, line))""")
161 PRIMARY KEY (session, line))""")
162 # Output history is optional, but ensure the table's there so it can be
162 # Output history is optional, but ensure the table's there so it can be
163 # enabled later.
163 # enabled later.
164 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
164 self.db.execute("""CREATE TABLE IF NOT EXISTS output_history
165 (session integer, line integer, output text,
165 (session integer, line integer, output text,
166 PRIMARY KEY (session, line))""")
166 PRIMARY KEY (session, line))""")
167 self.db.commit()
167 self.db.commit()
168
168
169 def writeout_cache(self):
169 def writeout_cache(self):
170 """Overridden by HistoryManager to dump the cache before certain
170 """Overridden by HistoryManager to dump the cache before certain
171 database lookups."""
171 database lookups."""
172 pass
172 pass
173
173
174 ## -------------------------------
174 ## -------------------------------
175 ## Methods for retrieving history:
175 ## Methods for retrieving history:
176 ## -------------------------------
176 ## -------------------------------
177 def _run_sql(self, sql, params, raw=True, output=False):
177 def _run_sql(self, sql, params, raw=True, output=False):
178 """Prepares and runs an SQL query for the history database.
178 """Prepares and runs an SQL query for the history database.
179
179
180 Parameters
180 Parameters
181 ----------
181 ----------
182 sql : str
182 sql : str
183 Any filtering expressions to go after SELECT ... FROM ...
183 Any filtering expressions to go after SELECT ... FROM ...
184 params : tuple
184 params : tuple
185 Parameters passed to the SQL query (to replace "?")
185 Parameters passed to the SQL query (to replace "?")
186 raw, output : bool
186 raw, output : bool
187 See :meth:`get_range`
187 See :meth:`get_range`
188
188
189 Returns
189 Returns
190 -------
190 -------
191 Tuples as :meth:`get_range`
191 Tuples as :meth:`get_range`
192 """
192 """
193 toget = 'source_raw' if raw else 'source'
193 toget = 'source_raw' if raw else 'source'
194 sqlfrom = "history"
194 sqlfrom = "history"
195 if output:
195 if output:
196 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
196 sqlfrom = "history LEFT JOIN output_history USING (session, line)"
197 toget = "history.%s, output_history.output" % toget
197 toget = "history.%s, output_history.output" % toget
198 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
198 cur = self.db.execute("SELECT session, line, %s FROM %s " %\
199 (toget, sqlfrom) + sql, params)
199 (toget, sqlfrom) + sql, params)
200 if output: # Regroup into 3-tuples, and parse JSON
200 if output: # Regroup into 3-tuples, and parse JSON
201 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
201 return ((ses, lin, (inp, out)) for ses, lin, inp, out in cur)
202 return cur
202 return cur
203
203
204 @needs_sqlite
204 @needs_sqlite
205 def get_session_info(self, session=0):
205 def get_session_info(self, session=0):
206 """get info about a session
206 """get info about a session
207
207
208 Parameters
208 Parameters
209 ----------
209 ----------
210
210
211 session : int
211 session : int
212 Session number to retrieve. The current session is 0, and negative
212 Session number to retrieve. The current session is 0, and negative
213 numbers count back from current session, so -1 is previous session.
213 numbers count back from current session, so -1 is previous session.
214
214
215 Returns
215 Returns
216 -------
216 -------
217
217
218 (session_id [int], start [datetime], end [datetime], num_cmds [int], remark [unicode])
218 (session_id [int], start [datetime], end [datetime], num_cmds [int], remark [unicode])
219
219
220 Sessions that are running or did not exit cleanly will have `end=None`
220 Sessions that are running or did not exit cleanly will have `end=None`
221 and `num_cmds=None`.
221 and `num_cmds=None`.
222
222
223 """
223 """
224
224
225 if session <= 0:
225 if session <= 0:
226 session += self.session_number
226 session += self.session_number
227
227
228 query = "SELECT * from sessions where session == ?"
228 query = "SELECT * from sessions where session == ?"
229 return self.db.execute(query, (session,)).fetchone()
229 return self.db.execute(query, (session,)).fetchone()
230
230
231 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
231 def get_tail(self, n=10, raw=True, output=False, include_latest=False):
232 """Get the last n lines from the history database.
232 """Get the last n lines from the history database.
233
233
234 Parameters
234 Parameters
235 ----------
235 ----------
236 n : int
236 n : int
237 The number of lines to get
237 The number of lines to get
238 raw, output : bool
238 raw, output : bool
239 See :meth:`get_range`
239 See :meth:`get_range`
240 include_latest : bool
240 include_latest : bool
241 If False (default), n+1 lines are fetched, and the latest one
241 If False (default), n+1 lines are fetched, and the latest one
242 is discarded. This is intended to be used where the function
242 is discarded. This is intended to be used where the function
243 is called by a user command, which it should not return.
243 is called by a user command, which it should not return.
244
244
245 Returns
245 Returns
246 -------
246 -------
247 Tuples as :meth:`get_range`
247 Tuples as :meth:`get_range`
248 """
248 """
249 self.writeout_cache()
249 self.writeout_cache()
250 if not include_latest:
250 if not include_latest:
251 n += 1
251 n += 1
252 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
252 cur = self._run_sql("ORDER BY session DESC, line DESC LIMIT ?",
253 (n,), raw=raw, output=output)
253 (n,), raw=raw, output=output)
254 if not include_latest:
254 if not include_latest:
255 return reversed(list(cur)[1:])
255 return reversed(list(cur)[1:])
256 return reversed(list(cur))
256 return reversed(list(cur))
257
257
258 def search(self, pattern="*", raw=True, search_raw=True,
258 def search(self, pattern="*", raw=True, search_raw=True,
259 output=False):
259 output=False):
260 """Search the database using unix glob-style matching (wildcards
260 """Search the database using unix glob-style matching (wildcards
261 * and ?).
261 * and ?).
262
262
263 Parameters
263 Parameters
264 ----------
264 ----------
265 pattern : str
265 pattern : str
266 The wildcarded pattern to match when searching
266 The wildcarded pattern to match when searching
267 search_raw : bool
267 search_raw : bool
268 If True, search the raw input, otherwise, the parsed input
268 If True, search the raw input, otherwise, the parsed input
269 raw, output : bool
269 raw, output : bool
270 See :meth:`get_range`
270 See :meth:`get_range`
271
271
272 Returns
272 Returns
273 -------
273 -------
274 Tuples as :meth:`get_range`
274 Tuples as :meth:`get_range`
275 """
275 """
276 tosearch = "source_raw" if search_raw else "source"
276 tosearch = "source_raw" if search_raw else "source"
277 if output:
277 if output:
278 tosearch = "history." + tosearch
278 tosearch = "history." + tosearch
279 self.writeout_cache()
279 self.writeout_cache()
280 return self._run_sql("WHERE %s GLOB ?" % tosearch, (pattern,),
280 return self._run_sql("WHERE %s GLOB ?" % tosearch, (pattern,),
281 raw=raw, output=output)
281 raw=raw, output=output)
282
282
283 def get_range(self, session, start=1, stop=None, raw=True,output=False):
283 def get_range(self, session, start=1, stop=None, raw=True,output=False):
284 """Retrieve input by session.
284 """Retrieve input by session.
285
285
286 Parameters
286 Parameters
287 ----------
287 ----------
288 session : int
288 session : int
289 Session number to retrieve.
289 Session number to retrieve.
290 start : int
290 start : int
291 First line to retrieve.
291 First line to retrieve.
292 stop : int
292 stop : int
293 End of line range (excluded from output itself). If None, retrieve
293 End of line range (excluded from output itself). If None, retrieve
294 to the end of the session.
294 to the end of the session.
295 raw : bool
295 raw : bool
296 If True, return untranslated input
296 If True, return untranslated input
297 output : bool
297 output : bool
298 If True, attempt to include output. This will be 'real' Python
298 If True, attempt to include output. This will be 'real' Python
299 objects for the current session, or text reprs from previous
299 objects for the current session, or text reprs from previous
300 sessions if db_log_output was enabled at the time. Where no output
300 sessions if db_log_output was enabled at the time. Where no output
301 is found, None is used.
301 is found, None is used.
302
302
303 Returns
303 Returns
304 -------
304 -------
305 An iterator over the desired lines. Each line is a 3-tuple, either
305 An iterator over the desired lines. Each line is a 3-tuple, either
306 (session, line, input) if output is False, or
306 (session, line, input) if output is False, or
307 (session, line, (input, output)) if output is True.
307 (session, line, (input, output)) if output is True.
308 """
308 """
309 if stop:
309 if stop:
310 lineclause = "line >= ? AND line < ?"
310 lineclause = "line >= ? AND line < ?"
311 params = (session, start, stop)
311 params = (session, start, stop)
312 else:
312 else:
313 lineclause = "line>=?"
313 lineclause = "line>=?"
314 params = (session, start)
314 params = (session, start)
315
315
316 return self._run_sql("WHERE session==? AND %s""" % lineclause,
316 return self._run_sql("WHERE session==? AND %s""" % lineclause,
317 params, raw=raw, output=output)
317 params, raw=raw, output=output)
318
318
319 def get_range_by_str(self, rangestr, raw=True, output=False):
319 def get_range_by_str(self, rangestr, raw=True, output=False):
320 """Get lines of history from a string of ranges, as used by magic
320 """Get lines of history from a string of ranges, as used by magic
321 commands %hist, %save, %macro, etc.
321 commands %hist, %save, %macro, etc.
322
322
323 Parameters
323 Parameters
324 ----------
324 ----------
325 rangestr : str
325 rangestr : str
326 A string specifying ranges, e.g. "5 ~2/1-4". See
326 A string specifying ranges, e.g. "5 ~2/1-4". See
327 :func:`magic_history` for full details.
327 :func:`magic_history` for full details.
328 raw, output : bool
328 raw, output : bool
329 As :meth:`get_range`
329 As :meth:`get_range`
330
330
331 Returns
331 Returns
332 -------
332 -------
333 Tuples as :meth:`get_range`
333 Tuples as :meth:`get_range`
334 """
334 """
335 for sess, s, e in extract_hist_ranges(rangestr):
335 for sess, s, e in extract_hist_ranges(rangestr):
336 for line in self.get_range(sess, s, e, raw=raw, output=output):
336 for line in self.get_range(sess, s, e, raw=raw, output=output):
337 yield line
337 yield line
338
338
339
339
340 class HistoryManager(HistoryAccessor):
340 class HistoryManager(HistoryAccessor):
341 """A class to organize all history-related functionality in one place.
341 """A class to organize all history-related functionality in one place.
342 """
342 """
343 # Public interface
343 # Public interface
344
344
345 # An instance of the IPython shell we are attached to
345 # An instance of the IPython shell we are attached to
346 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
346 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
347 # Lists to hold processed and raw history. These start with a blank entry
347 # Lists to hold processed and raw history. These start with a blank entry
348 # so that we can index them starting from 1
348 # so that we can index them starting from 1
349 input_hist_parsed = List([""])
349 input_hist_parsed = List([""])
350 input_hist_raw = List([""])
350 input_hist_raw = List([""])
351 # A list of directories visited during session
351 # A list of directories visited during session
352 dir_hist = List()
352 dir_hist = List()
353 def _dir_hist_default(self):
353 def _dir_hist_default(self):
354 try:
354 try:
355 return [os.getcwdu()]
355 return [os.getcwdu()]
356 except OSError:
356 except OSError:
357 return []
357 return []
358
358
359 # A dict of output history, keyed with ints from the shell's
359 # A dict of output history, keyed with ints from the shell's
360 # execution count.
360 # execution count.
361 output_hist = Dict()
361 output_hist = Dict()
362 # The text/plain repr of outputs.
362 # The text/plain repr of outputs.
363 output_hist_reprs = Dict()
363 output_hist_reprs = Dict()
364
364
365 # The number of the current session in the history database
365 # The number of the current session in the history database
366 session_number = Integer()
366 session_number = Integer()
367 # Should we log output to the database? (default no)
367 # Should we log output to the database? (default no)
368 db_log_output = Bool(False, config=True)
368 db_log_output = Bool(False, config=True)
369 # Write to database every x commands (higher values save disk access & power)
369 # Write to database every x commands (higher values save disk access & power)
370 # Values of 1 or less effectively disable caching.
370 # Values of 1 or less effectively disable caching.
371 db_cache_size = Integer(0, config=True)
371 db_cache_size = Integer(0, config=True)
372 # The input and output caches
372 # The input and output caches
373 db_input_cache = List()
373 db_input_cache = List()
374 db_output_cache = List()
374 db_output_cache = List()
375
375
376 # History saving in separate thread
376 # History saving in separate thread
377 save_thread = Instance('IPython.core.history.HistorySavingThread')
377 save_thread = Instance('IPython.core.history.HistorySavingThread')
378 try: # Event is a function returning an instance of _Event...
378 try: # Event is a function returning an instance of _Event...
379 save_flag = Instance(threading._Event)
379 save_flag = Instance(threading._Event)
380 except AttributeError: # ...until Python 3.3, when it's a class.
380 except AttributeError: # ...until Python 3.3, when it's a class.
381 save_flag = Instance(threading.Event)
381 save_flag = Instance(threading.Event)
382
382
383 # Private interface
383 # Private interface
384 # Variables used to store the three last inputs from the user. On each new
384 # Variables used to store the three last inputs from the user. On each new
385 # history update, we populate the user's namespace with these, shifted as
385 # history update, we populate the user's namespace with these, shifted as
386 # necessary.
386 # necessary.
387 _i00 = Unicode(u'')
387 _i00 = Unicode(u'')
388 _i = Unicode(u'')
388 _i = Unicode(u'')
389 _ii = Unicode(u'')
389 _ii = Unicode(u'')
390 _iii = Unicode(u'')
390 _iii = Unicode(u'')
391
391
392 # A regex matching all forms of the exit command, so that we don't store
392 # A regex matching all forms of the exit command, so that we don't store
393 # them in the history (it's annoying to rewind the first entry and land on
393 # them in the history (it's annoying to rewind the first entry and land on
394 # an exit call).
394 # an exit call).
395 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
395 _exit_re = re.compile(r"(exit|quit)(\s*\(.*\))?$")
396
396
397 def __init__(self, shell=None, config=None, **traits):
397 def __init__(self, shell=None, config=None, **traits):
398 """Create a new history manager associated with a shell instance.
398 """Create a new history manager associated with a shell instance.
399 """
399 """
400 # We need a pointer back to the shell for various tasks.
400 # We need a pointer back to the shell for various tasks.
401 super(HistoryManager, self).__init__(shell=shell, config=config,
401 super(HistoryManager, self).__init__(shell=shell, config=config,
402 **traits)
402 **traits)
403 self.save_flag = threading.Event()
403 self.save_flag = threading.Event()
404 self.db_input_cache_lock = threading.Lock()
404 self.db_input_cache_lock = threading.Lock()
405 self.db_output_cache_lock = threading.Lock()
405 self.db_output_cache_lock = threading.Lock()
406 self.save_thread = HistorySavingThread(self)
406 if self.hist_file != ':memory:':
407 self.save_thread.start()
407 self.save_thread = HistorySavingThread(self)
408 self.save_thread.start()
408
409
409 self.new_session()
410 self.new_session()
410
411
411 def _get_hist_file_name(self, profile=None):
412 def _get_hist_file_name(self, profile=None):
412 """Get default history file name based on the Shell's profile.
413 """Get default history file name based on the Shell's profile.
413
414
414 The profile parameter is ignored, but must exist for compatibility with
415 The profile parameter is ignored, but must exist for compatibility with
415 the parent class."""
416 the parent class."""
416 profile_dir = self.shell.profile_dir.location
417 profile_dir = self.shell.profile_dir.location
417 return os.path.join(profile_dir, 'history.sqlite')
418 return os.path.join(profile_dir, 'history.sqlite')
418
419
419 @needs_sqlite
420 @needs_sqlite
420 def new_session(self, conn=None):
421 def new_session(self, conn=None):
421 """Get a new session number."""
422 """Get a new session number."""
422 if conn is None:
423 if conn is None:
423 conn = self.db
424 conn = self.db
424
425
425 with conn:
426 with conn:
426 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
427 cur = conn.execute("""INSERT INTO sessions VALUES (NULL, ?, NULL,
427 NULL, "") """, (datetime.datetime.now(),))
428 NULL, "") """, (datetime.datetime.now(),))
428 self.session_number = cur.lastrowid
429 self.session_number = cur.lastrowid
429
430
430 def end_session(self):
431 def end_session(self):
431 """Close the database session, filling in the end time and line count."""
432 """Close the database session, filling in the end time and line count."""
432 self.writeout_cache()
433 self.writeout_cache()
433 with self.db:
434 with self.db:
434 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
435 self.db.execute("""UPDATE sessions SET end=?, num_cmds=? WHERE
435 session==?""", (datetime.datetime.now(),
436 session==?""", (datetime.datetime.now(),
436 len(self.input_hist_parsed)-1, self.session_number))
437 len(self.input_hist_parsed)-1, self.session_number))
437 self.session_number = 0
438 self.session_number = 0
438
439
439 def name_session(self, name):
440 def name_session(self, name):
440 """Give the current session a name in the history database."""
441 """Give the current session a name in the history database."""
441 with self.db:
442 with self.db:
442 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
443 self.db.execute("UPDATE sessions SET remark=? WHERE session==?",
443 (name, self.session_number))
444 (name, self.session_number))
444
445
445 def reset(self, new_session=True):
446 def reset(self, new_session=True):
446 """Clear the session history, releasing all object references, and
447 """Clear the session history, releasing all object references, and
447 optionally open a new session."""
448 optionally open a new session."""
448 self.output_hist.clear()
449 self.output_hist.clear()
449 # The directory history can't be completely empty
450 # The directory history can't be completely empty
450 self.dir_hist[:] = [os.getcwdu()]
451 self.dir_hist[:] = [os.getcwdu()]
451
452
452 if new_session:
453 if new_session:
453 if self.session_number:
454 if self.session_number:
454 self.end_session()
455 self.end_session()
455 self.input_hist_parsed[:] = [""]
456 self.input_hist_parsed[:] = [""]
456 self.input_hist_raw[:] = [""]
457 self.input_hist_raw[:] = [""]
457 self.new_session()
458 self.new_session()
458
459
459 # ------------------------------
460 # ------------------------------
460 # Methods for retrieving history
461 # Methods for retrieving history
461 # ------------------------------
462 # ------------------------------
462 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
463 def _get_range_session(self, start=1, stop=None, raw=True, output=False):
463 """Get input and output history from the current session. Called by
464 """Get input and output history from the current session. Called by
464 get_range, and takes similar parameters."""
465 get_range, and takes similar parameters."""
465 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
466 input_hist = self.input_hist_raw if raw else self.input_hist_parsed
466
467
467 n = len(input_hist)
468 n = len(input_hist)
468 if start < 0:
469 if start < 0:
469 start += n
470 start += n
470 if not stop or (stop > n):
471 if not stop or (stop > n):
471 stop = n
472 stop = n
472 elif stop < 0:
473 elif stop < 0:
473 stop += n
474 stop += n
474
475
475 for i in range(start, stop):
476 for i in range(start, stop):
476 if output:
477 if output:
477 line = (input_hist[i], self.output_hist_reprs.get(i))
478 line = (input_hist[i], self.output_hist_reprs.get(i))
478 else:
479 else:
479 line = input_hist[i]
480 line = input_hist[i]
480 yield (0, i, line)
481 yield (0, i, line)
481
482
482 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
483 def get_range(self, session=0, start=1, stop=None, raw=True,output=False):
483 """Retrieve input by session.
484 """Retrieve input by session.
484
485
485 Parameters
486 Parameters
486 ----------
487 ----------
487 session : int
488 session : int
488 Session number to retrieve. The current session is 0, and negative
489 Session number to retrieve. The current session is 0, and negative
489 numbers count back from current session, so -1 is previous session.
490 numbers count back from current session, so -1 is previous session.
490 start : int
491 start : int
491 First line to retrieve.
492 First line to retrieve.
492 stop : int
493 stop : int
493 End of line range (excluded from output itself). If None, retrieve
494 End of line range (excluded from output itself). If None, retrieve
494 to the end of the session.
495 to the end of the session.
495 raw : bool
496 raw : bool
496 If True, return untranslated input
497 If True, return untranslated input
497 output : bool
498 output : bool
498 If True, attempt to include output. This will be 'real' Python
499 If True, attempt to include output. This will be 'real' Python
499 objects for the current session, or text reprs from previous
500 objects for the current session, or text reprs from previous
500 sessions if db_log_output was enabled at the time. Where no output
501 sessions if db_log_output was enabled at the time. Where no output
501 is found, None is used.
502 is found, None is used.
502
503
503 Returns
504 Returns
504 -------
505 -------
505 An iterator over the desired lines. Each line is a 3-tuple, either
506 An iterator over the desired lines. Each line is a 3-tuple, either
506 (session, line, input) if output is False, or
507 (session, line, input) if output is False, or
507 (session, line, (input, output)) if output is True.
508 (session, line, (input, output)) if output is True.
508 """
509 """
509 if session <= 0:
510 if session <= 0:
510 session += self.session_number
511 session += self.session_number
511 if session==self.session_number: # Current session
512 if session==self.session_number: # Current session
512 return self._get_range_session(start, stop, raw, output)
513 return self._get_range_session(start, stop, raw, output)
513 return super(HistoryManager, self).get_range(session, start, stop, raw, output)
514 return super(HistoryManager, self).get_range(session, start, stop, raw, output)
514
515
515 ## ----------------------------
516 ## ----------------------------
516 ## Methods for storing history:
517 ## Methods for storing history:
517 ## ----------------------------
518 ## ----------------------------
518 def store_inputs(self, line_num, source, source_raw=None):
519 def store_inputs(self, line_num, source, source_raw=None):
519 """Store source and raw input in history and create input cache
520 """Store source and raw input in history and create input cache
520 variables _i*.
521 variables _i*.
521
522
522 Parameters
523 Parameters
523 ----------
524 ----------
524 line_num : int
525 line_num : int
525 The prompt number of this input.
526 The prompt number of this input.
526
527
527 source : str
528 source : str
528 Python input.
529 Python input.
529
530
530 source_raw : str, optional
531 source_raw : str, optional
531 If given, this is the raw input without any IPython transformations
532 If given, this is the raw input without any IPython transformations
532 applied to it. If not given, ``source`` is used.
533 applied to it. If not given, ``source`` is used.
533 """
534 """
534 if source_raw is None:
535 if source_raw is None:
535 source_raw = source
536 source_raw = source
536 source = source.rstrip('\n')
537 source = source.rstrip('\n')
537 source_raw = source_raw.rstrip('\n')
538 source_raw = source_raw.rstrip('\n')
538
539
539 # do not store exit/quit commands
540 # do not store exit/quit commands
540 if self._exit_re.match(source_raw.strip()):
541 if self._exit_re.match(source_raw.strip()):
541 return
542 return
542
543
543 self.input_hist_parsed.append(source)
544 self.input_hist_parsed.append(source)
544 self.input_hist_raw.append(source_raw)
545 self.input_hist_raw.append(source_raw)
545
546
546 with self.db_input_cache_lock:
547 with self.db_input_cache_lock:
547 self.db_input_cache.append((line_num, source, source_raw))
548 self.db_input_cache.append((line_num, source, source_raw))
548 # Trigger to flush cache and write to DB.
549 # Trigger to flush cache and write to DB.
549 if len(self.db_input_cache) >= self.db_cache_size:
550 if len(self.db_input_cache) >= self.db_cache_size:
550 self.save_flag.set()
551 self.save_flag.set()
551
552
552 # update the auto _i variables
553 # update the auto _i variables
553 self._iii = self._ii
554 self._iii = self._ii
554 self._ii = self._i
555 self._ii = self._i
555 self._i = self._i00
556 self._i = self._i00
556 self._i00 = source_raw
557 self._i00 = source_raw
557
558
558 # hackish access to user namespace to create _i1,_i2... dynamically
559 # hackish access to user namespace to create _i1,_i2... dynamically
559 new_i = '_i%s' % line_num
560 new_i = '_i%s' % line_num
560 to_main = {'_i': self._i,
561 to_main = {'_i': self._i,
561 '_ii': self._ii,
562 '_ii': self._ii,
562 '_iii': self._iii,
563 '_iii': self._iii,
563 new_i : self._i00 }
564 new_i : self._i00 }
564
565
565 self.shell.push(to_main, interactive=False)
566 self.shell.push(to_main, interactive=False)
566
567
567 def store_output(self, line_num):
568 def store_output(self, line_num):
568 """If database output logging is enabled, this saves all the
569 """If database output logging is enabled, this saves all the
569 outputs from the indicated prompt number to the database. It's
570 outputs from the indicated prompt number to the database. It's
570 called by run_cell after code has been executed.
571 called by run_cell after code has been executed.
571
572
572 Parameters
573 Parameters
573 ----------
574 ----------
574 line_num : int
575 line_num : int
575 The line number from which to save outputs
576 The line number from which to save outputs
576 """
577 """
577 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
578 if (not self.db_log_output) or (line_num not in self.output_hist_reprs):
578 return
579 return
579 output = self.output_hist_reprs[line_num]
580 output = self.output_hist_reprs[line_num]
580
581
581 with self.db_output_cache_lock:
582 with self.db_output_cache_lock:
582 self.db_output_cache.append((line_num, output))
583 self.db_output_cache.append((line_num, output))
583 if self.db_cache_size <= 1:
584 if self.db_cache_size <= 1:
584 self.save_flag.set()
585 self.save_flag.set()
585
586
586 def _writeout_input_cache(self, conn):
587 def _writeout_input_cache(self, conn):
587 with conn:
588 with conn:
588 for line in self.db_input_cache:
589 for line in self.db_input_cache:
589 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
590 conn.execute("INSERT INTO history VALUES (?, ?, ?, ?)",
590 (self.session_number,)+line)
591 (self.session_number,)+line)
591
592
592 def _writeout_output_cache(self, conn):
593 def _writeout_output_cache(self, conn):
593 with conn:
594 with conn:
594 for line in self.db_output_cache:
595 for line in self.db_output_cache:
595 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
596 conn.execute("INSERT INTO output_history VALUES (?, ?, ?)",
596 (self.session_number,)+line)
597 (self.session_number,)+line)
597
598
598 @needs_sqlite
599 @needs_sqlite
599 def writeout_cache(self, conn=None):
600 def writeout_cache(self, conn=None):
600 """Write any entries in the cache to the database."""
601 """Write any entries in the cache to the database."""
601 if conn is None:
602 if conn is None:
602 conn = self.db
603 conn = self.db
603
604
604 with self.db_input_cache_lock:
605 with self.db_input_cache_lock:
605 try:
606 try:
606 self._writeout_input_cache(conn)
607 self._writeout_input_cache(conn)
607 except sqlite3.IntegrityError:
608 except sqlite3.IntegrityError:
608 self.new_session(conn)
609 self.new_session(conn)
609 print("ERROR! Session/line number was not unique in",
610 print("ERROR! Session/line number was not unique in",
610 "database. History logging moved to new session",
611 "database. History logging moved to new session",
611 self.session_number)
612 self.session_number)
612 try: # Try writing to the new session. If this fails, don't recurse
613 try: # Try writing to the new session. If this fails, don't recurse
613 self._writeout_input_cache(conn)
614 self._writeout_input_cache(conn)
614 except sqlite3.IntegrityError:
615 except sqlite3.IntegrityError:
615 pass
616 pass
616 finally:
617 finally:
617 self.db_input_cache = []
618 self.db_input_cache = []
618
619
619 with self.db_output_cache_lock:
620 with self.db_output_cache_lock:
620 try:
621 try:
621 self._writeout_output_cache(conn)
622 self._writeout_output_cache(conn)
622 except sqlite3.IntegrityError:
623 except sqlite3.IntegrityError:
623 print("!! Session/line number for output was not unique",
624 print("!! Session/line number for output was not unique",
624 "in database. Output will not be stored.")
625 "in database. Output will not be stored.")
625 finally:
626 finally:
626 self.db_output_cache = []
627 self.db_output_cache = []
627
628
628
629
class HistorySavingThread(threading.Thread):
    """This thread takes care of writing history to the database, so that
    the UI isn't held up while that happens.

    It waits for the HistoryManager's save_flag to be set, then writes out
    the history cache. The main thread is responsible for setting the flag when
    the cache size reaches a defined threshold."""
    daemon = True
    stop_now = False

    def __init__(self, history_manager):
        super(HistorySavingThread, self).__init__()
        self.history_manager = history_manager
        # Make sure the thread is told to stop at interpreter exit.
        atexit.register(self.stop)

    @needs_sqlite
    def run(self):
        # We need a separate db connection per thread:
        try:
            self.db = sqlite3.connect(self.history_manager.hist_file)
            hm = self.history_manager
            while True:
                hm.save_flag.wait()
                if self.stop_now:
                    return
                hm.save_flag.clear()
                hm.writeout_cache(self.db)
        except Exception as e:
            print(("The history saving thread hit an unexpected error (%s)."
                   "History will not be written to the database.") % repr(e))

    def stop(self):
        """This can be called from the main thread to safely stop this thread.

        Note that it does not attempt to write out remaining history before
        exiting. That should be done by calling the HistoryManager's
        end_session method."""
        self.stop_now = True
        # Wake the run loop so it notices stop_now, then wait for it to end.
        self.history_manager.save_flag.set()
        self.join()
668
668
669
# To match, e.g. ~5/8-~2/3
range_re = re.compile(r"""
((?P<startsess>~?\d+)/)?
(?P<start>\d+)                    # Only the start line num is compulsory
((?P<sep>[\-:])
 ((?P<endsess>~?\d+)/)?
 (?P<end>\d+))?
$""", re.VERBOSE)

def extract_hist_ranges(ranges_str):
    """Turn a string of history ranges into 3-tuples of (session, start, stop).

    Stop is None for ranges that extend to the end of a session. A ``-``
    separator makes the end line inclusive (so it is returned as end+1),
    while ``:`` keeps it exclusive. Unparseable chunks are skipped.

    Examples
    --------
    list(extract_hist_ranges("~8/5-~7/4 2"))
    [(-8, 5, None), (-7, 1, 5), (0, 2, 3)]
    """
    for range_str in ranges_str.split():
        rmatch = range_re.match(range_str)
        if not rmatch:
            continue
        start = int(rmatch.group("start"))
        end = rmatch.group("end")
        end = int(end) if end else start+1  # If no end specified, get (a, a+1)
        if rmatch.group("sep") == "-":       # 1-3 == 1:4 --> [1, 2, 3]
            end += 1
        startsess = rmatch.group("startsess") or "0"
        endsess = rmatch.group("endsess") or startsess
        # '~N' counts back N sessions from the current one.
        startsess = int(startsess.replace("~", "-"))
        endsess = int(endsess.replace("~", "-"))
        assert endsess >= startsess

        if endsess == startsess:
            yield (startsess, start, end)
            continue
        # Multiple sessions in one range: first session from `start` to its
        # end, middle sessions whole, last session up to `end`.
        yield (startsess, start, None)
        for sess in range(startsess+1, endsess):
            yield (sess, 1, None)
        yield (endsess, 1, end)
710
710 def _format_lineno(session, line):
711 def _format_lineno(session, line):
711 """Helper function to format line numbers properly."""
712 """Helper function to format line numbers properly."""
712 if session == 0:
713 if session == 0:
713 return str(line)
714 return str(line)
714 return "%s#%s" % (session, line)
715 return "%s#%s" % (session, line)
715
716
@skip_doctest
def magic_history(self, parameter_s = ''):
    """Print input history (_i<n> variables), with most recent last.

    %history [-o -p -t -n] [-f filename] [range | -g pattern | -l number]

    By default, input history is printed without line numbers so it can be
    directly pasted into an editor. Use -n to show them.

    By default, all input history from the current session is displayed.
    Ranges of history can be indicated using the syntax:
    4      : Line 4, current session
    4-6    : Lines 4-6, current session
    243/1-5: Lines 1-5, session 243
    ~2/7   : Line 7, session 2 before current
    ~8/1-~6/5 : From the first line of 8 sessions ago, to the fifth line
                of 6 sessions ago.
    Multiple ranges can be entered, separated by spaces

    The same syntax is used by %macro, %save, %edit, %rerun

    Options:

      -n: print line numbers for each input.
      This feature is only available if numbered prompts are in use.

      -o: also print outputs for each input.

      -p: print classic '>>>' python prompts before each input.  This is useful
       for making documentation, and in conjunction with -o, for producing
       doctest-ready output.

      -r: (default) print the 'raw' history, i.e. the actual commands you typed.

      -t: print the 'translated' history, as IPython understands it.  IPython
      filters your input and converts it all into valid Python source before
      executing it (things like magics or aliases are turned into function
      calls, for example). With this option, you'll see the native history
      instead of the user-entered version: '%cd /' will be seen as
      'get_ipython().magic("%cd /")' instead of '%cd /'.

      -g: treat the arg as a pattern to grep for in (full) history.
      This includes the saved history (almost all commands ever written).
      Use '%hist -g' to show full saved history (may be very long).

      -l: get the last n lines from all sessions. Specify n as a single arg, or
      the default is the last 10 lines.

      -f FILENAME: instead of printing the output to the screen, redirect it to
       the given file.  The file is always overwritten, though *when it can*,
       IPython asks for confirmation first. In particular, running the command
       "history -f FILENAME" from the IPython Notebook interface will replace
       FILENAME even if it already exists *without* confirmation.

    Examples
    --------
    ::

      In [6]: %hist -n 4-6
      4:a = 12
      5:print a**2
      6:%hist -n 4-6

    """

    if not self.shell.displayhook.do_full_cache:
        print('This feature is only available if numbered prompts are in use.')
        return
    opts,args = self.parse_options(parameter_s,'noprtglf:',mode='string')

    # For brevity
    history_manager = self.shell.history_manager

    def _format_lineno(session, line):
        """Helper function to format line numbers properly."""
        if session in (0, history_manager.session_number):
            return str(line)
        return "%s/%s" % (session, line)

    # Check if output to specific file was requested.
    try:
        outfname = opts['f']
    except KeyError:
        outfile = io.stdout  # default
        # We don't want to close stdout at the end!
        close_at_end = False
    else:
        if os.path.exists(outfname):
            try:
                ans = io.ask_yes_no("File %r exists. Overwrite?" % outfname)
            except StdinNotImplementedError:
                # No stdin (e.g. the Notebook): overwrite without asking.
                ans = True
            if not ans:
                print('Aborting.')
                return
            print("Overwriting file.")
        outfile = io_open(outfname, 'w', encoding='utf-8')
        close_at_end = True

    print_nums = 'n' in opts
    get_output = 'o' in opts
    pyprompts = 'p' in opts
    # Raw history is the default
    raw = not('t' in opts)

    default_length = 40
    pattern = None

    if 'g' in opts:  # Glob search
        pattern = "*" + args + "*" if args else "*"
        hist = history_manager.search(pattern, raw=raw, output=get_output)
        print_nums = True
    elif 'l' in opts:  # Get 'tail'
        try:
            n = int(args)
        # BUGFIX: the original Python 2 form ``except ValueError, IndexError:``
        # caught only ValueError, binding it to the name IndexError; the tuple
        # form catches both as intended.
        except (ValueError, IndexError):
            n = 10
        hist = history_manager.get_tail(n, raw=raw, output=get_output)
    else:
        if args:  # Get history by ranges
            hist = history_manager.get_range_by_str(args, raw, get_output)
        else:  # Just get history for the current session
            hist = history_manager.get_range(raw=raw, output=get_output)

    # We could be displaying the entire history, so let's not try to pull it
    # into a list in memory. Anything that needs more space will just misalign.
    width = 4

    for session, lineno, inline in hist:
        # Print user history with tabs expanded to 4 spaces.  The GUI clients
        # use hard tabs for easier usability in auto-indented code, but we want
        # to produce PEP-8 compliant history for safe pasting into an editor.
        if get_output:
            inline, output = inline
        inline = inline.expandtabs(4).rstrip()

        multiline = "\n" in inline
        line_sep = '\n' if multiline else ' '
        if print_nums:
            print(u'%s:%s' % (_format_lineno(session, lineno).rjust(width),
                    line_sep), file=outfile, end=u'')
        if pyprompts:
            print(u">>> ", end=u"", file=outfile)
        if multiline:
            inline = "\n... ".join(inline.splitlines()) + "\n..."
        print(inline, file=outfile)
        if get_output and output:
            print(output, file=outfile)

    if close_at_end:
        outfile.close()
869 def magic_rep(self, arg):
870 def magic_rep(self, arg):
870 r"""Repeat a command, or get command to input line for editing.
871 r"""Repeat a command, or get command to input line for editing.
871
872
872 %recall and %rep are equivalent.
873 %recall and %rep are equivalent.
873
874
874 - %recall (no arguments):
875 - %recall (no arguments):
875
876
876 Place a string version of last computation result (stored in the special '_'
877 Place a string version of last computation result (stored in the special '_'
877 variable) to the next input prompt. Allows you to create elaborate command
878 variable) to the next input prompt. Allows you to create elaborate command
878 lines without using copy-paste::
879 lines without using copy-paste::
879
880
880 In[1]: l = ["hei", "vaan"]
881 In[1]: l = ["hei", "vaan"]
881 In[2]: "".join(l)
882 In[2]: "".join(l)
882 Out[2]: heivaan
883 Out[2]: heivaan
883 In[3]: %rep
884 In[3]: %rep
884 In[4]: heivaan_ <== cursor blinking
885 In[4]: heivaan_ <== cursor blinking
885
886
886 %recall 45
887 %recall 45
887
888
888 Place history line 45 on the next input prompt. Use %hist to find
889 Place history line 45 on the next input prompt. Use %hist to find
889 out the number.
890 out the number.
890
891
891 %recall 1-4
892 %recall 1-4
892
893
893 Combine the specified lines into one cell, and place it on the next
894 Combine the specified lines into one cell, and place it on the next
894 input prompt. See %history for the slice syntax.
895 input prompt. See %history for the slice syntax.
895
896
896 %recall foo+bar
897 %recall foo+bar
897
898
898 If foo+bar can be evaluated in the user namespace, the result is
899 If foo+bar can be evaluated in the user namespace, the result is
899 placed at the next input prompt. Otherwise, the history is searched
900 placed at the next input prompt. Otherwise, the history is searched
900 for lines which contain that substring, and the most recent one is
901 for lines which contain that substring, and the most recent one is
901 placed at the next input prompt.
902 placed at the next input prompt.
902 """
903 """
903 if not arg: # Last output
904 if not arg: # Last output
904 self.set_next_input(str(self.shell.user_ns["_"]))
905 self.set_next_input(str(self.shell.user_ns["_"]))
905 return
906 return
906 # Get history range
907 # Get history range
907 histlines = self.history_manager.get_range_by_str(arg)
908 histlines = self.history_manager.get_range_by_str(arg)
908 cmd = "\n".join(x[2] for x in histlines)
909 cmd = "\n".join(x[2] for x in histlines)
909 if cmd:
910 if cmd:
910 self.set_next_input(cmd.rstrip())
911 self.set_next_input(cmd.rstrip())
911 return
912 return
912
913
913 try: # Variable in user namespace
914 try: # Variable in user namespace
914 cmd = str(eval(arg, self.shell.user_ns))
915 cmd = str(eval(arg, self.shell.user_ns))
915 except Exception: # Search for term in history
916 except Exception: # Search for term in history
916 histlines = self.history_manager.search("*"+arg+"*")
917 histlines = self.history_manager.search("*"+arg+"*")
917 for h in reversed([x[2] for x in histlines]):
918 for h in reversed([x[2] for x in histlines]):
918 if 'rep' in h:
919 if 'rep' in h:
919 continue
920 continue
920 self.set_next_input(h.rstrip())
921 self.set_next_input(h.rstrip())
921 return
922 return
922 else:
923 else:
923 self.set_next_input(cmd.rstrip())
924 self.set_next_input(cmd.rstrip())
924 print("Couldn't evaluate or find in history:", arg)
925 print("Couldn't evaluate or find in history:", arg)
925
926
926 def magic_rerun(self, parameter_s=''):
927 def magic_rerun(self, parameter_s=''):
927 """Re-run previous input
928 """Re-run previous input
928
929
929 By default, you can specify ranges of input history to be repeated
930 By default, you can specify ranges of input history to be repeated
930 (as with %history). With no arguments, it will repeat the last line.
931 (as with %history). With no arguments, it will repeat the last line.
931
932
932 Options:
933 Options:
933
934
934 -l <n> : Repeat the last n lines of input, not including the
935 -l <n> : Repeat the last n lines of input, not including the
935 current command.
936 current command.
936
937
937 -g foo : Repeat the most recent line which contains foo
938 -g foo : Repeat the most recent line which contains foo
938 """
939 """
939 opts, args = self.parse_options(parameter_s, 'l:g:', mode='string')
940 opts, args = self.parse_options(parameter_s, 'l:g:', mode='string')
940 if "l" in opts: # Last n lines
941 if "l" in opts: # Last n lines
941 n = int(opts['l'])
942 n = int(opts['l'])
942 hist = self.history_manager.get_tail(n)
943 hist = self.history_manager.get_tail(n)
943 elif "g" in opts: # Search
944 elif "g" in opts: # Search
944 p = "*"+opts['g']+"*"
945 p = "*"+opts['g']+"*"
945 hist = list(self.history_manager.search(p))
946 hist = list(self.history_manager.search(p))
946 for l in reversed(hist):
947 for l in reversed(hist):
947 if "rerun" not in l[2]:
948 if "rerun" not in l[2]:
948 hist = [l] # The last match which isn't a %rerun
949 hist = [l] # The last match which isn't a %rerun
949 break
950 break
950 else:
951 else:
951 hist = [] # No matches except %rerun
952 hist = [] # No matches except %rerun
952 elif args: # Specify history ranges
953 elif args: # Specify history ranges
953 hist = self.history_manager.get_range_by_str(args)
954 hist = self.history_manager.get_range_by_str(args)
954 else: # Last line
955 else: # Last line
955 hist = self.history_manager.get_tail(1)
956 hist = self.history_manager.get_tail(1)
956 hist = [x[2] for x in hist]
957 hist = [x[2] for x in hist]
957 if not hist:
958 if not hist:
958 print("No lines in history match specification")
959 print("No lines in history match specification")
959 return
960 return
960 histlines = "\n".join(hist)
961 histlines = "\n".join(hist)
961 print("=== Executing: ===")
962 print("=== Executing: ===")
962 print(histlines)
963 print(histlines)
963 print("=== Output: ===")
964 print("=== Output: ===")
964 self.run_cell("\n".join(hist), store_history=False)
965 self.run_cell("\n".join(hist), store_history=False)
965
966
966
967
967 def init_ipython(ip):
968 def init_ipython(ip):
968 ip.define_magic("rep", magic_rep)
969 ip.define_magic("rep", magic_rep)
969 ip.define_magic("recall", magic_rep)
970 ip.define_magic("recall", magic_rep)
970 ip.define_magic("rerun", magic_rerun)
971 ip.define_magic("rerun", magic_rerun)
971 ip.define_magic("hist",magic_history) # Alternative name
972 ip.define_magic("hist",magic_history) # Alternative name
972 ip.define_magic("history",magic_history)
973 ip.define_magic("history",magic_history)
973
974
974 # XXX - ipy_completers are in quarantine, need to be updated to new apis
975 # XXX - ipy_completers are in quarantine, need to be updated to new apis
975 #import ipy_completers
976 #import ipy_completers
976 #ipy_completers.quick_completer('%hist' ,'-g -t -r -n')
977 #ipy_completers.quick_completer('%hist' ,'-g -t -r -n')
@@ -1,2853 +1,2865 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Main IPython class."""
2 """Main IPython class."""
3
3
4 #-----------------------------------------------------------------------------
4 #-----------------------------------------------------------------------------
5 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de>
5 # Copyright (C) 2001 Janko Hauser <jhauser@zscout.de>
6 # Copyright (C) 2001-2007 Fernando Perez. <fperez@colorado.edu>
6 # Copyright (C) 2001-2007 Fernando Perez. <fperez@colorado.edu>
7 # Copyright (C) 2008-2011 The IPython Development Team
7 # Copyright (C) 2008-2011 The IPython Development Team
8 #
8 #
9 # Distributed under the terms of the BSD License. The full license is in
9 # Distributed under the terms of the BSD License. The full license is in
10 # the file COPYING, distributed as part of this software.
10 # the file COPYING, distributed as part of this software.
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 from __future__ import with_statement
17 from __future__ import with_statement
18 from __future__ import absolute_import
18 from __future__ import absolute_import
19
19
20 import __builtin__ as builtin_mod
20 import __builtin__ as builtin_mod
21 import __future__
21 import __future__
22 import abc
22 import abc
23 import ast
23 import ast
24 import atexit
24 import atexit
25 import os
25 import os
26 import re
26 import re
27 import runpy
27 import runpy
28 import sys
28 import sys
29 import tempfile
29 import tempfile
30 import types
30 import types
31 import urllib
31 import urllib
32 from io import open as io_open
32 from io import open as io_open
33
33
34 from IPython.config.configurable import SingletonConfigurable
34 from IPython.config.configurable import SingletonConfigurable
35 from IPython.core import debugger, oinspect
35 from IPython.core import debugger, oinspect
36 from IPython.core import page
36 from IPython.core import page
37 from IPython.core import prefilter
37 from IPython.core import prefilter
38 from IPython.core import shadowns
38 from IPython.core import shadowns
39 from IPython.core import ultratb
39 from IPython.core import ultratb
40 from IPython.core.alias import AliasManager, AliasError
40 from IPython.core.alias import AliasManager, AliasError
41 from IPython.core.autocall import ExitAutocall
41 from IPython.core.autocall import ExitAutocall
42 from IPython.core.builtin_trap import BuiltinTrap
42 from IPython.core.builtin_trap import BuiltinTrap
43 from IPython.core.compilerop import CachingCompiler
43 from IPython.core.compilerop import CachingCompiler
44 from IPython.core.display_trap import DisplayTrap
44 from IPython.core.display_trap import DisplayTrap
45 from IPython.core.displayhook import DisplayHook
45 from IPython.core.displayhook import DisplayHook
46 from IPython.core.displaypub import DisplayPublisher
46 from IPython.core.displaypub import DisplayPublisher
47 from IPython.core.error import UsageError
47 from IPython.core.error import UsageError
48 from IPython.core.extensions import ExtensionManager
48 from IPython.core.extensions import ExtensionManager
49 from IPython.core.fakemodule import FakeModule, init_fakemod_dict
49 from IPython.core.fakemodule import FakeModule, init_fakemod_dict
50 from IPython.core.formatters import DisplayFormatter
50 from IPython.core.formatters import DisplayFormatter
51 from IPython.core.history import HistoryManager
51 from IPython.core.history import HistoryManager
52 from IPython.core.inputsplitter import IPythonInputSplitter
52 from IPython.core.inputsplitter import IPythonInputSplitter
53 from IPython.core.logger import Logger
53 from IPython.core.logger import Logger
54 from IPython.core.macro import Macro
54 from IPython.core.macro import Macro
55 from IPython.core.magic import Magic
55 from IPython.core.magic import Magic
56 from IPython.core.payload import PayloadManager
56 from IPython.core.payload import PayloadManager
57 from IPython.core.plugin import PluginManager
57 from IPython.core.plugin import PluginManager
58 from IPython.core.prefilter import PrefilterManager, ESC_MAGIC
58 from IPython.core.prefilter import PrefilterManager, ESC_MAGIC
59 from IPython.core.profiledir import ProfileDir
59 from IPython.core.profiledir import ProfileDir
60 from IPython.core.pylabtools import pylab_activate
60 from IPython.core.pylabtools import pylab_activate
61 from IPython.core.prompts import PromptManager
61 from IPython.core.prompts import PromptManager
62 from IPython.utils import PyColorize
62 from IPython.utils import PyColorize
63 from IPython.utils import io
63 from IPython.utils import io
64 from IPython.utils import py3compat
64 from IPython.utils import py3compat
65 from IPython.utils import openpy
65 from IPython.utils import openpy
66 from IPython.utils.doctestreload import doctest_reload
66 from IPython.utils.doctestreload import doctest_reload
67 from IPython.utils.io import ask_yes_no
67 from IPython.utils.io import ask_yes_no
68 from IPython.utils.ipstruct import Struct
68 from IPython.utils.ipstruct import Struct
69 from IPython.utils.path import get_home_dir, get_ipython_dir, get_py_filename, unquote_filename
69 from IPython.utils.path import get_home_dir, get_ipython_dir, get_py_filename, unquote_filename
70 from IPython.utils.pickleshare import PickleShareDB
70 from IPython.utils.pickleshare import PickleShareDB
71 from IPython.utils.process import system, getoutput
71 from IPython.utils.process import system, getoutput
72 from IPython.utils.strdispatch import StrDispatch
72 from IPython.utils.strdispatch import StrDispatch
73 from IPython.utils.syspathcontext import prepended_to_syspath
73 from IPython.utils.syspathcontext import prepended_to_syspath
74 from IPython.utils.text import (format_screen, LSString, SList,
74 from IPython.utils.text import (format_screen, LSString, SList,
75 DollarFormatter)
75 DollarFormatter)
76 from IPython.utils.traitlets import (Integer, CBool, CaselessStrEnum, Enum,
76 from IPython.utils.traitlets import (Integer, CBool, CaselessStrEnum, Enum,
77 List, Unicode, Instance, Type)
77 List, Unicode, Instance, Type)
78 from IPython.utils.warn import warn, error
78 from IPython.utils.warn import warn, error
79 import IPython.core.hooks
79 import IPython.core.hooks
80
80
81 #-----------------------------------------------------------------------------
81 #-----------------------------------------------------------------------------
82 # Globals
82 # Globals
83 #-----------------------------------------------------------------------------
83 #-----------------------------------------------------------------------------
84
84
85 # compiled regexps for autoindent management
85 # compiled regexps for autoindent management
86 dedent_re = re.compile(r'^\s+raise|^\s+return|^\s+pass')
86 dedent_re = re.compile(r'^\s+raise|^\s+return|^\s+pass')
87
87
88 #-----------------------------------------------------------------------------
88 #-----------------------------------------------------------------------------
89 # Utilities
89 # Utilities
90 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
91
91
92 def softspace(file, newvalue):
92 def softspace(file, newvalue):
93 """Copied from code.py, to remove the dependency"""
93 """Copied from code.py, to remove the dependency"""
94
94
95 oldvalue = 0
95 oldvalue = 0
96 try:
96 try:
97 oldvalue = file.softspace
97 oldvalue = file.softspace
98 except AttributeError:
98 except AttributeError:
99 pass
99 pass
100 try:
100 try:
101 file.softspace = newvalue
101 file.softspace = newvalue
102 except (AttributeError, TypeError):
102 except (AttributeError, TypeError):
103 # "attribute-less object" or "read-only attributes"
103 # "attribute-less object" or "read-only attributes"
104 pass
104 pass
105 return oldvalue
105 return oldvalue
106
106
107
107
108 def no_op(*a, **kw): pass
108 def no_op(*a, **kw): pass
109
109
110 class NoOpContext(object):
110 class NoOpContext(object):
111 def __enter__(self): pass
111 def __enter__(self): pass
112 def __exit__(self, type, value, traceback): pass
112 def __exit__(self, type, value, traceback): pass
113 no_op_context = NoOpContext()
113 no_op_context = NoOpContext()
114
114
115 class SpaceInInput(Exception): pass
115 class SpaceInInput(Exception): pass
116
116
117 class Bunch: pass
117 class Bunch: pass
118
118
119
119
120 def get_default_colors():
120 def get_default_colors():
121 if sys.platform=='darwin':
121 if sys.platform=='darwin':
122 return "LightBG"
122 return "LightBG"
123 elif os.name=='nt':
123 elif os.name=='nt':
124 return 'Linux'
124 return 'Linux'
125 else:
125 else:
126 return 'Linux'
126 return 'Linux'
127
127
128
128
129 class SeparateUnicode(Unicode):
129 class SeparateUnicode(Unicode):
130 """A Unicode subclass to validate separate_in, separate_out, etc.
130 """A Unicode subclass to validate separate_in, separate_out, etc.
131
131
132 This is a Unicode based trait that converts '0'->'' and '\\n'->'\n'.
132 This is a Unicode based trait that converts '0'->'' and '\\n'->'\n'.
133 """
133 """
134
134
135 def validate(self, obj, value):
135 def validate(self, obj, value):
136 if value == '0': value = ''
136 if value == '0': value = ''
137 value = value.replace('\\n','\n')
137 value = value.replace('\\n','\n')
138 return super(SeparateUnicode, self).validate(obj, value)
138 return super(SeparateUnicode, self).validate(obj, value)
139
139
140
140
141 class ReadlineNoRecord(object):
141 class ReadlineNoRecord(object):
142 """Context manager to execute some code, then reload readline history
142 """Context manager to execute some code, then reload readline history
143 so that interactive input to the code doesn't appear when pressing up."""
143 so that interactive input to the code doesn't appear when pressing up."""
144 def __init__(self, shell):
144 def __init__(self, shell):
145 self.shell = shell
145 self.shell = shell
146 self._nested_level = 0
146 self._nested_level = 0
147
147
148 def __enter__(self):
148 def __enter__(self):
149 if self._nested_level == 0:
149 if self._nested_level == 0:
150 try:
150 try:
151 self.orig_length = self.current_length()
151 self.orig_length = self.current_length()
152 self.readline_tail = self.get_readline_tail()
152 self.readline_tail = self.get_readline_tail()
153 except (AttributeError, IndexError): # Can fail with pyreadline
153 except (AttributeError, IndexError): # Can fail with pyreadline
154 self.orig_length, self.readline_tail = 999999, []
154 self.orig_length, self.readline_tail = 999999, []
155 self._nested_level += 1
155 self._nested_level += 1
156
156
157 def __exit__(self, type, value, traceback):
157 def __exit__(self, type, value, traceback):
158 self._nested_level -= 1
158 self._nested_level -= 1
159 if self._nested_level == 0:
159 if self._nested_level == 0:
160 # Try clipping the end if it's got longer
160 # Try clipping the end if it's got longer
161 try:
161 try:
162 e = self.current_length() - self.orig_length
162 e = self.current_length() - self.orig_length
163 if e > 0:
163 if e > 0:
164 for _ in range(e):
164 for _ in range(e):
165 self.shell.readline.remove_history_item(self.orig_length)
165 self.shell.readline.remove_history_item(self.orig_length)
166
166
167 # If it still doesn't match, just reload readline history.
167 # If it still doesn't match, just reload readline history.
168 if self.current_length() != self.orig_length \
168 if self.current_length() != self.orig_length \
169 or self.get_readline_tail() != self.readline_tail:
169 or self.get_readline_tail() != self.readline_tail:
170 self.shell.refill_readline_hist()
170 self.shell.refill_readline_hist()
171 except (AttributeError, IndexError):
171 except (AttributeError, IndexError):
172 pass
172 pass
173 # Returning False will cause exceptions to propagate
173 # Returning False will cause exceptions to propagate
174 return False
174 return False
175
175
176 def current_length(self):
176 def current_length(self):
177 return self.shell.readline.get_current_history_length()
177 return self.shell.readline.get_current_history_length()
178
178
179 def get_readline_tail(self, n=10):
179 def get_readline_tail(self, n=10):
180 """Get the last n items in readline history."""
180 """Get the last n items in readline history."""
181 end = self.shell.readline.get_current_history_length() + 1
181 end = self.shell.readline.get_current_history_length() + 1
182 start = max(end-n, 1)
182 start = max(end-n, 1)
183 ghi = self.shell.readline.get_history_item
183 ghi = self.shell.readline.get_history_item
184 return [ghi(x) for x in range(start, end)]
184 return [ghi(x) for x in range(start, end)]
185
185
186 #-----------------------------------------------------------------------------
186 #-----------------------------------------------------------------------------
187 # Main IPython class
187 # Main IPython class
188 #-----------------------------------------------------------------------------
188 #-----------------------------------------------------------------------------
189
189
190 class InteractiveShell(SingletonConfigurable, Magic):
190 class InteractiveShell(SingletonConfigurable, Magic):
191 """An enhanced, interactive shell for Python."""
191 """An enhanced, interactive shell for Python."""
192
192
193 _instance = None
193 _instance = None
194
194
195 autocall = Enum((0,1,2), default_value=0, config=True, help=
195 autocall = Enum((0,1,2), default_value=0, config=True, help=
196 """
196 """
197 Make IPython automatically call any callable object even if you didn't
197 Make IPython automatically call any callable object even if you didn't
198 type explicit parentheses. For example, 'str 43' becomes 'str(43)'
198 type explicit parentheses. For example, 'str 43' becomes 'str(43)'
199 automatically. The value can be '0' to disable the feature, '1' for
199 automatically. The value can be '0' to disable the feature, '1' for
200 'smart' autocall, where it is not applied if there are no more
200 'smart' autocall, where it is not applied if there are no more
201 arguments on the line, and '2' for 'full' autocall, where all callable
201 arguments on the line, and '2' for 'full' autocall, where all callable
202 objects are automatically called (even if no arguments are present).
202 objects are automatically called (even if no arguments are present).
203 """
203 """
204 )
204 )
205 # TODO: remove all autoindent logic and put into frontends.
205 # TODO: remove all autoindent logic and put into frontends.
206 # We can't do this yet because even runlines uses the autoindent.
206 # We can't do this yet because even runlines uses the autoindent.
207 autoindent = CBool(True, config=True, help=
207 autoindent = CBool(True, config=True, help=
208 """
208 """
209 Autoindent IPython code entered interactively.
209 Autoindent IPython code entered interactively.
210 """
210 """
211 )
211 )
212 automagic = CBool(True, config=True, help=
212 automagic = CBool(True, config=True, help=
213 """
213 """
214 Enable magic commands to be called without the leading %.
214 Enable magic commands to be called without the leading %.
215 """
215 """
216 )
216 )
217 cache_size = Integer(1000, config=True, help=
217 cache_size = Integer(1000, config=True, help=
218 """
218 """
219 Set the size of the output cache. The default is 1000, you can
219 Set the size of the output cache. The default is 1000, you can
220 change it permanently in your config file. Setting it to 0 completely
220 change it permanently in your config file. Setting it to 0 completely
221 disables the caching system, and the minimum value accepted is 20 (if
221 disables the caching system, and the minimum value accepted is 20 (if
222 you provide a value less than 20, it is reset to 0 and a warning is
222 you provide a value less than 20, it is reset to 0 and a warning is
223 issued). This limit is defined because otherwise you'll spend more
223 issued). This limit is defined because otherwise you'll spend more
224 time re-flushing a too small cache than working
224 time re-flushing a too small cache than working
225 """
225 """
226 )
226 )
227 color_info = CBool(True, config=True, help=
227 color_info = CBool(True, config=True, help=
228 """
228 """
229 Use colors for displaying information about objects. Because this
229 Use colors for displaying information about objects. Because this
230 information is passed through a pager (like 'less'), and some pagers
230 information is passed through a pager (like 'less'), and some pagers
231 get confused with color codes, this capability can be turned off.
231 get confused with color codes, this capability can be turned off.
232 """
232 """
233 )
233 )
234 colors = CaselessStrEnum(('NoColor','LightBG','Linux'),
234 colors = CaselessStrEnum(('NoColor','LightBG','Linux'),
235 default_value=get_default_colors(), config=True,
235 default_value=get_default_colors(), config=True,
236 help="Set the color scheme (NoColor, Linux, or LightBG)."
236 help="Set the color scheme (NoColor, Linux, or LightBG)."
237 )
237 )
238 colors_force = CBool(False, help=
238 colors_force = CBool(False, help=
239 """
239 """
240 Force use of ANSI color codes, regardless of OS and readline
240 Force use of ANSI color codes, regardless of OS and readline
241 availability.
241 availability.
242 """
242 """
243 # FIXME: This is essentially a hack to allow ZMQShell to show colors
243 # FIXME: This is essentially a hack to allow ZMQShell to show colors
244 # without readline on Win32. When the ZMQ formatting system is
244 # without readline on Win32. When the ZMQ formatting system is
245 # refactored, this should be removed.
245 # refactored, this should be removed.
246 )
246 )
247 debug = CBool(False, config=True)
247 debug = CBool(False, config=True)
248 deep_reload = CBool(False, config=True, help=
248 deep_reload = CBool(False, config=True, help=
249 """
249 """
250 Enable deep (recursive) reloading by default. IPython can use the
250 Enable deep (recursive) reloading by default. IPython can use the
251 deep_reload module which reloads changes in modules recursively (it
251 deep_reload module which reloads changes in modules recursively (it
252 replaces the reload() function, so you don't need to change anything to
252 replaces the reload() function, so you don't need to change anything to
253 use it). deep_reload() forces a full reload of modules whose code may
253 use it). deep_reload() forces a full reload of modules whose code may
254 have changed, which the default reload() function does not. When
254 have changed, which the default reload() function does not. When
255 deep_reload is off, IPython will use the normal reload(), but
255 deep_reload is off, IPython will use the normal reload(), but
256 deep_reload will still be available as dreload().
256 deep_reload will still be available as dreload().
257 """
257 """
258 )
258 )
259 disable_failing_post_execute = CBool(False, config=True,
259 disable_failing_post_execute = CBool(False, config=True,
260 help="Don't call post-execute functions that have failed in the past."""
260 help="Don't call post-execute functions that have failed in the past."""
261 )
261 )
262 display_formatter = Instance(DisplayFormatter)
262 display_formatter = Instance(DisplayFormatter)
263 displayhook_class = Type(DisplayHook)
263 displayhook_class = Type(DisplayHook)
264 display_pub_class = Type(DisplayPublisher)
264 display_pub_class = Type(DisplayPublisher)
265
265
266 exit_now = CBool(False)
266 exit_now = CBool(False)
267 exiter = Instance(ExitAutocall)
267 exiter = Instance(ExitAutocall)
268 def _exiter_default(self):
268 def _exiter_default(self):
269 return ExitAutocall(self)
269 return ExitAutocall(self)
270 # Monotonically increasing execution counter
270 # Monotonically increasing execution counter
271 execution_count = Integer(1)
271 execution_count = Integer(1)
272 filename = Unicode("<ipython console>")
272 filename = Unicode("<ipython console>")
273 ipython_dir= Unicode('', config=True) # Set to get_ipython_dir() in __init__
273 ipython_dir= Unicode('', config=True) # Set to get_ipython_dir() in __init__
274
274
275 # Input splitter, to split entire cells of input into either individual
275 # Input splitter, to split entire cells of input into either individual
276 # interactive statements or whole blocks.
276 # interactive statements or whole blocks.
277 input_splitter = Instance('IPython.core.inputsplitter.IPythonInputSplitter',
277 input_splitter = Instance('IPython.core.inputsplitter.IPythonInputSplitter',
278 (), {})
278 (), {})
279 logstart = CBool(False, config=True, help=
279 logstart = CBool(False, config=True, help=
280 """
280 """
281 Start logging to the default log file.
281 Start logging to the default log file.
282 """
282 """
283 )
283 )
284 logfile = Unicode('', config=True, help=
284 logfile = Unicode('', config=True, help=
285 """
285 """
286 The name of the logfile to use.
286 The name of the logfile to use.
287 """
287 """
288 )
288 )
289 logappend = Unicode('', config=True, help=
289 logappend = Unicode('', config=True, help=
290 """
290 """
291 Start logging to the given file in append mode.
291 Start logging to the given file in append mode.
292 """
292 """
293 )
293 )
294 object_info_string_level = Enum((0,1,2), default_value=0,
294 object_info_string_level = Enum((0,1,2), default_value=0,
295 config=True)
295 config=True)
296 pdb = CBool(False, config=True, help=
296 pdb = CBool(False, config=True, help=
297 """
297 """
298 Automatically call the pdb debugger after every exception.
298 Automatically call the pdb debugger after every exception.
299 """
299 """
300 )
300 )
301 multiline_history = CBool(sys.platform != 'win32', config=True,
301 multiline_history = CBool(sys.platform != 'win32', config=True,
302 help="Save multi-line entries as one entry in readline history"
302 help="Save multi-line entries as one entry in readline history"
303 )
303 )
304
304
305 # deprecated prompt traits:
305 # deprecated prompt traits:
306
306
307 prompt_in1 = Unicode('In [\\#]: ', config=True,
307 prompt_in1 = Unicode('In [\\#]: ', config=True,
308 help="Deprecated, use PromptManager.in_template")
308 help="Deprecated, use PromptManager.in_template")
309 prompt_in2 = Unicode(' .\\D.: ', config=True,
309 prompt_in2 = Unicode(' .\\D.: ', config=True,
310 help="Deprecated, use PromptManager.in2_template")
310 help="Deprecated, use PromptManager.in2_template")
311 prompt_out = Unicode('Out[\\#]: ', config=True,
311 prompt_out = Unicode('Out[\\#]: ', config=True,
312 help="Deprecated, use PromptManager.out_template")
312 help="Deprecated, use PromptManager.out_template")
313 prompts_pad_left = CBool(True, config=True,
313 prompts_pad_left = CBool(True, config=True,
314 help="Deprecated, use PromptManager.justify")
314 help="Deprecated, use PromptManager.justify")
315
315
316 def _prompt_trait_changed(self, name, old, new):
316 def _prompt_trait_changed(self, name, old, new):
317 table = {
317 table = {
318 'prompt_in1' : 'in_template',
318 'prompt_in1' : 'in_template',
319 'prompt_in2' : 'in2_template',
319 'prompt_in2' : 'in2_template',
320 'prompt_out' : 'out_template',
320 'prompt_out' : 'out_template',
321 'prompts_pad_left' : 'justify',
321 'prompts_pad_left' : 'justify',
322 }
322 }
323 warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}\n".format(
323 warn("InteractiveShell.{name} is deprecated, use PromptManager.{newname}\n".format(
324 name=name, newname=table[name])
324 name=name, newname=table[name])
325 )
325 )
326 # protect against weird cases where self.config may not exist:
326 # protect against weird cases where self.config may not exist:
327 if self.config is not None:
327 if self.config is not None:
328 # propagate to corresponding PromptManager trait
328 # propagate to corresponding PromptManager trait
329 setattr(self.config.PromptManager, table[name], new)
329 setattr(self.config.PromptManager, table[name], new)
330
330
331 _prompt_in1_changed = _prompt_trait_changed
331 _prompt_in1_changed = _prompt_trait_changed
332 _prompt_in2_changed = _prompt_trait_changed
332 _prompt_in2_changed = _prompt_trait_changed
333 _prompt_out_changed = _prompt_trait_changed
333 _prompt_out_changed = _prompt_trait_changed
334 _prompt_pad_left_changed = _prompt_trait_changed
334 _prompt_pad_left_changed = _prompt_trait_changed
335
335
    # Whether to echo the transformed input (e.g. after autocall) before
    # executing it.
    show_rewritten_input = CBool(True, config=True,
        help="Show rewritten input, e.g. for autocall."
    )

    # If True, run with reduced shell chatter.  NOTE(review): the consumers
    # of this flag are outside this chunk — verify before relying on it.
    quiet = CBool(False, config=True)

    # Maximum number of readline history entries to keep.
    history_length = Integer(10000, config=True)
343
343
    # The readline stuff will eventually be moved to the terminal subclass
    # but for now, we can't do that as readline is welded in everywhere.
    readline_use = CBool(True, config=True)
    # Characters removed from readline's word-delimiter set so that e.g.
    # paths and ~ expand properly during tab-completion.
    readline_remove_delims = Unicode('-/~', config=True)
    # don't use \M- bindings by default, because they
    # conflict with 8-bit encodings. See gh-58,gh-88
    readline_parse_and_bind = List([
        'tab: complete',
        '"\C-l": clear-screen',
        'set show-all-if-ambiguous on',
        '"\C-o": tab-insert',
        '"\C-r": reverse-search-history',
        '"\C-s": forward-search-history',
        '"\C-p": history-search-backward',
        '"\C-n": history-search-forward',
        '"\e[A": history-search-backward',
        '"\e[B": history-search-forward',
        '"\C-k": kill-line',
        '"\C-u": unix-line-discard',
    ], allow_none=False, config=True)

    # TODO: this part of prompt management should be moved to the frontends.
    # Use custom TraitTypes that convert '0'->'' and '\\n'->'\n'
    separate_in = SeparateUnicode('\n', config=True)
    separate_out = SeparateUnicode('', config=True)
    separate_out2 = SeparateUnicode('', config=True)
    wildcards_case_sensitive = CBool(True, config=True)
    # Exception-reporting mode for tracebacks.
    xmode = CaselessStrEnum(('Context','Plain', 'Verbose'),
                            default_value='Context', config=True)
373
373
    # Subcomponents of InteractiveShell, resolved lazily by the traits
    # machinery from their dotted import paths.
    alias_manager = Instance('IPython.core.alias.AliasManager')
    prefilter_manager = Instance('IPython.core.prefilter.PrefilterManager')
    builtin_trap = Instance('IPython.core.builtin_trap.BuiltinTrap')
    display_trap = Instance('IPython.core.display_trap.DisplayTrap')
    extension_manager = Instance('IPython.core.extensions.ExtensionManager')
    plugin_manager = Instance('IPython.core.plugin.PluginManager')
    payload_manager = Instance('IPython.core.payload.PayloadManager')
    history_manager = Instance('IPython.core.history.HistoryManager')

    # Directory holding the active profile (set in init_profile_dir).
    profile_dir = Instance('IPython.core.application.ProfileDir')
385 @property
385 @property
386 def profile(self):
386 def profile(self):
387 if self.profile_dir is not None:
387 if self.profile_dir is not None:
388 name = os.path.basename(self.profile_dir.location)
388 name = os.path.basename(self.profile_dir.location)
389 return name.replace('profile_','')
389 return name.replace('profile_','')
390
390
391
391
    # Private interface

    # Dict tracking post-execution functions that have been registered;
    # populated as an empty dict in init_instance_attrs.
    _post_execute = Instance(dict)
394
394
    def __init__(self, config=None, ipython_dir=None, profile_dir=None,
                 user_module=None, user_ns=None,
                 custom_exceptions=((), None)):
        """Construct the shell and run the full init_* startup sequence.

        Parameters
        ----------
        config : Config, optional
            Configuration object; traits declared with config=True are
            updated from it by the superclass constructor.
        ipython_dir : str, optional
            Explicit IPython directory; falls back to get_ipython_dir().
        profile_dir : ProfileDir, optional
            Explicit profile directory; falls back to the 'default' profile.
        user_module : module, optional
            Module whose namespace is used for user code execution.
        user_ns : dict, optional
            Namespace for user variables.
        custom_exceptions : (tuple, handler), optional
            Passed on to init_traceback_handlers().

        NOTE: the order of the init_* calls below encodes real dependencies
        (see the inline comments); do not reorder casually.
        """

        # This is where traits with a config_key argument are updated
        # from the values on config.
        super(InteractiveShell, self).__init__(config=config)
        self.configurables = [self]

        # These are relatively independent and stateless
        self.init_ipython_dir(ipython_dir)
        self.init_profile_dir(profile_dir)
        self.init_instance_attrs()
        self.init_environment()

        # Check if we're in a virtualenv, and set up sys.path.
        self.init_virtualenv()

        # Create namespaces (user_ns, user_global_ns, etc.)
        self.init_create_namespaces(user_module, user_ns)
        # This has to be done after init_create_namespaces because it uses
        # something in self.user_ns, but before init_sys_modules, which
        # is the first thing to modify sys.
        # TODO: When we override sys.stdout and sys.stderr before this class
        # is created, we are saving the overridden ones here. Not sure if this
        # is what we want to do.
        self.save_sys_module_state()
        self.init_sys_modules()

        # While we're trying to have each part of the code directly access what
        # it needs without keeping redundant references to objects, we have too
        # much legacy code that expects ip.db to exist.
        self.db = PickleShareDB(os.path.join(self.profile_dir.location, 'db'))

        self.init_history()
        self.init_encoding()
        self.init_prefilter()

        # Mix in the magic machinery (legacy multiple-inheritance style).
        Magic.__init__(self, self)

        self.init_syntax_highlighting()
        self.init_hooks()
        self.init_pushd_popd_magic()
        # self.init_traceback_handlers used to be here, but we moved it below
        # because it and init_io have to come after init_readline.
        self.init_user_ns()
        self.init_logger()
        self.init_alias()
        self.init_builtins()

        # pre_config_initialization

        # The next section should contain everything that was in ipmaker.
        self.init_logstart()

        # The following was in post_config_initialization
        self.init_inspector()
        # init_readline() must come before init_io(), because init_io uses
        # readline related things.
        self.init_readline()
        # We save this here in case user code replaces raw_input, but it needs
        # to be after init_readline(), because PyPy's readline works by replacing
        # raw_input.
        if py3compat.PY3:
            self.raw_input_original = input
        else:
            self.raw_input_original = raw_input
        # init_completer must come after init_readline, because it needs to
        # know whether readline is present or not system-wide to configure the
        # completers, since the completion machinery can now operate
        # independently of readline (e.g. over the network)
        self.init_completer()
        # TODO: init_io() needs to happen before init_traceback handlers
        # because the traceback handlers hardcode the stdout/stderr streams.
        # This logic is in debugger.Pdb and should eventually be changed.
        self.init_io()
        self.init_traceback_handlers(custom_exceptions)
        self.init_prompts()
        self.init_display_formatter()
        self.init_display_pub()
        self.init_displayhook()
        self.init_reload_doctest()
        self.init_magics()
        self.init_pdb()
        self.init_extension_manager()
        self.init_plugin_manager()
        self.init_payload()
        self.hooks.late_startup_hook()
        atexit.register(self.atexit_operations)
484
484
485 def get_ipython(self):
485 def get_ipython(self):
486 """Return the currently running IPython instance."""
486 """Return the currently running IPython instance."""
487 return self
487 return self
488
488
489 #-------------------------------------------------------------------------
489 #-------------------------------------------------------------------------
490 # Trait changed handlers
490 # Trait changed handlers
491 #-------------------------------------------------------------------------
491 #-------------------------------------------------------------------------
492
492
493 def _ipython_dir_changed(self, name, new):
493 def _ipython_dir_changed(self, name, new):
494 if not os.path.isdir(new):
494 if not os.path.isdir(new):
495 os.makedirs(new, mode = 0777)
495 os.makedirs(new, mode = 0777)
496
496
497 def set_autoindent(self,value=None):
497 def set_autoindent(self,value=None):
498 """Set the autoindent flag, checking for readline support.
498 """Set the autoindent flag, checking for readline support.
499
499
500 If called with no arguments, it acts as a toggle."""
500 If called with no arguments, it acts as a toggle."""
501
501
502 if value != 0 and not self.has_readline:
502 if value != 0 and not self.has_readline:
503 if os.name == 'posix':
503 if os.name == 'posix':
504 warn("The auto-indent feature requires the readline library")
504 warn("The auto-indent feature requires the readline library")
505 self.autoindent = 0
505 self.autoindent = 0
506 return
506 return
507 if value is None:
507 if value is None:
508 self.autoindent = not self.autoindent
508 self.autoindent = not self.autoindent
509 else:
509 else:
510 self.autoindent = value
510 self.autoindent = value
511
511
512 #-------------------------------------------------------------------------
512 #-------------------------------------------------------------------------
513 # init_* methods called by __init__
513 # init_* methods called by __init__
514 #-------------------------------------------------------------------------
514 #-------------------------------------------------------------------------
515
515
516 def init_ipython_dir(self, ipython_dir):
516 def init_ipython_dir(self, ipython_dir):
517 if ipython_dir is not None:
517 if ipython_dir is not None:
518 self.ipython_dir = ipython_dir
518 self.ipython_dir = ipython_dir
519 return
519 return
520
520
521 self.ipython_dir = get_ipython_dir()
521 self.ipython_dir = get_ipython_dir()
522
522
523 def init_profile_dir(self, profile_dir):
523 def init_profile_dir(self, profile_dir):
524 if profile_dir is not None:
524 if profile_dir is not None:
525 self.profile_dir = profile_dir
525 self.profile_dir = profile_dir
526 return
526 return
527 self.profile_dir =\
527 self.profile_dir =\
528 ProfileDir.create_profile_dir_by_name(self.ipython_dir, 'default')
528 ProfileDir.create_profile_dir_by_name(self.ipython_dir, 'default')
529
529
    def init_instance_attrs(self):
        """Initialize plain (non-trait) per-instance attributes."""
        # True while a multi-line block of input is still being collected.
        self.more = False

        # command compiler
        self.compile = CachingCompiler()

        # Make an empty namespace, which extension writers can rely on both
        # existing and NEVER being used by ipython itself.  This gives them a
        # convenient location for storing additional information and state
        # their extensions may require, without fear of collisions with other
        # ipython names that may develop later.
        self.meta = Struct()

        # Temporary files used for various purposes.  Deleted at exit.
        self.tempfiles = []

        # Keep track of readline usage (later set by init_readline)
        self.has_readline = False

        # keep track of where we started running (mainly for crash post-mortem)
        # This is not being used anywhere currently.
        # NOTE(review): os.getcwdu is Python-2-only; confirm py3compat
        # provides a replacement before running under Python 3.
        self.starting_dir = os.getcwdu()

        # Indentation management
        self.indent_current_nsp = 0

        # Dict to track post-execution functions that have been registered
        self._post_execute = {}
558
558
    def init_environment(self):
        """Any changes we need to make to the user's environment.

        The base implementation intentionally does nothing; this exists as
        an override hook.
        """
        pass
562
562
563 def init_encoding(self):
563 def init_encoding(self):
564 # Get system encoding at startup time. Certain terminals (like Emacs
564 # Get system encoding at startup time. Certain terminals (like Emacs
565 # under Win32 have it set to None, and we need to have a known valid
565 # under Win32 have it set to None, and we need to have a known valid
566 # encoding to use in the raw_input() method
566 # encoding to use in the raw_input() method
567 try:
567 try:
568 self.stdin_encoding = sys.stdin.encoding or 'ascii'
568 self.stdin_encoding = sys.stdin.encoding or 'ascii'
569 except AttributeError:
569 except AttributeError:
570 self.stdin_encoding = 'ascii'
570 self.stdin_encoding = 'ascii'
571
571
572 def init_syntax_highlighting(self):
572 def init_syntax_highlighting(self):
573 # Python source parser/formatter for syntax highlighting
573 # Python source parser/formatter for syntax highlighting
574 pyformat = PyColorize.Parser().format
574 pyformat = PyColorize.Parser().format
575 self.pycolorize = lambda src: pyformat(src,'str',self.colors)
575 self.pycolorize = lambda src: pyformat(src,'str',self.colors)
576
576
577 def init_pushd_popd_magic(self):
577 def init_pushd_popd_magic(self):
578 # for pushd/popd management
578 # for pushd/popd management
579 self.home_dir = get_home_dir()
579 self.home_dir = get_home_dir()
580
580
581 self.dir_stack = []
581 self.dir_stack = []
582
582
    def init_logger(self):
        """Create the Logger used for session logging (rotating file mode)."""
        self.logger = Logger(self.home_dir, logfname='ipython_log.py',
                             logmode='rotate')
586
586
    def init_logstart(self):
        """Initialize logging in case it was requested at the command line.

        Precedence: logappend beats logfile beats plain logstart.
        """
        if self.logappend:
            self.magic_logstart(self.logappend + ' append')
        elif self.logfile:
            self.magic_logstart(self.logfile)
        elif self.logstart:
            self.magic_logstart()
596
596
    def init_builtins(self):
        """Publish IPython presence flags into the builtin namespace."""
        # A single, static flag that we set to True.  Its presence indicates
        # that an IPython shell has been created, and we make no attempts at
        # removing on exit or representing the existence of more than one
        # IPython at a time.
        builtin_mod.__dict__['__IPYTHON__'] = True

        # In 0.11 we introduced '__IPYTHON__active' as an integer we'd try to
        # manage on enter/exit, but with all our shells it's virtually
        # impossible to get all the cases right.  We're leaving the name in for
        # those who adapted their codes to check for this flag, but will
        # eventually remove it after a few more releases.
        builtin_mod.__dict__['__IPYTHON__active'] = \
            'Deprecated, check for __IPYTHON__'

        self.builtin_trap = BuiltinTrap(shell=self)
613
613
    def init_inspector(self):
        """Create the object inspector used for introspection output."""
        # Object inspector
        self.inspector = oinspect.Inspector(oinspect.InspectColors,
                                            PyColorize.ANSICodeColors,
                                            'NoColor',
                                            self.object_info_string_level)
620
620
    def init_io(self):
        """Point IPython's io.stdout/io.stderr at the appropriate streams."""
        # This will just use sys.stdout and sys.stderr. If you want to
        # override sys.stdout and sys.stderr themselves, you need to do that
        # *before* instantiating this class, because io holds onto
        # references to the underlying streams.
        if sys.platform == 'win32' and self.has_readline:
            # On Windows with readline, route output through readline's
            # output file — presumably so output cooperates with the
            # pyreadline console; confirm before changing.
            io.stdout = io.stderr = io.IOStream(self.readline._outputfile)
        else:
            io.stdout = io.IOStream(sys.stdout)
            io.stderr = io.IOStream(sys.stderr)
631
631
    def init_prompts(self):
        """Create the PromptManager and set marker sys.ps1/ps2/ps3 values."""
        self.prompt_manager = PromptManager(shell=self, config=self.config)
        self.configurables.append(self.prompt_manager)
        # Set system prompts, so that scripts can decide if they are running
        # interactively.
        sys.ps1 = 'In : '
        sys.ps2 = '...: '
        sys.ps3 = 'Out: '
640
640
    def init_display_formatter(self):
        """Create the DisplayFormatter used to format objects for display."""
        self.display_formatter = DisplayFormatter(config=self.config)
        self.configurables.append(self.display_formatter)
644
644
    def init_display_pub(self):
        """Instantiate the display publisher (class set by display_pub_class)."""
        self.display_pub = self.display_pub_class(config=self.config)
        self.configurables.append(self.display_pub)
648
648
    def init_displayhook(self):
        """Create the displayhook and the DisplayTrap that installs it."""
        # Initialize displayhook, set in/out prompts and printing system
        self.displayhook = self.displayhook_class(
            config=self.config,
            shell=self,
            cache_size=self.cache_size,
        )
        self.configurables.append(self.displayhook)
        # This is a context manager that installs/removes the displayhook at
        # the appropriate time.
        self.display_trap = DisplayTrap(hook=self.displayhook)
660
660
    def init_reload_doctest(self):
        """Reset doctest so it works with IPython's displayhook."""
        # Do a proper resetting of doctest, including the necessary displayhook
        # monkeypatching
        try:
            doctest_reload()
        except ImportError:
            warn("doctest module does not exist.")
668
668
669 def init_virtualenv(self):
669 def init_virtualenv(self):
670 """Add a virtualenv to sys.path so the user can import modules from it.
670 """Add a virtualenv to sys.path so the user can import modules from it.
671 This isn't perfect: it doesn't use the Python interpreter with which the
671 This isn't perfect: it doesn't use the Python interpreter with which the
672 virtualenv was built, and it ignores the --no-site-packages option. A
672 virtualenv was built, and it ignores the --no-site-packages option. A
673 warning will appear suggesting the user installs IPython in the
673 warning will appear suggesting the user installs IPython in the
674 virtualenv, but for many cases, it probably works well enough.
674 virtualenv, but for many cases, it probably works well enough.
675
675
676 Adapted from code snippets online.
676 Adapted from code snippets online.
677
677
678 http://blog.ufsoft.org/2009/1/29/ipython-and-virtualenv
678 http://blog.ufsoft.org/2009/1/29/ipython-and-virtualenv
679 """
679 """
680 if 'VIRTUAL_ENV' not in os.environ:
680 if 'VIRTUAL_ENV' not in os.environ:
681 # Not in a virtualenv
681 # Not in a virtualenv
682 return
682 return
683
683
684 if sys.executable.startswith(os.environ['VIRTUAL_ENV']):
684 if sys.executable.startswith(os.environ['VIRTUAL_ENV']):
685 # Running properly in the virtualenv, don't need to do anything
685 # Running properly in the virtualenv, don't need to do anything
686 return
686 return
687
687
688 warn("Attempting to work in a virtualenv. If you encounter problems, please "
688 warn("Attempting to work in a virtualenv. If you encounter problems, please "
689 "install IPython inside the virtualenv.\n")
689 "install IPython inside the virtualenv.\n")
690 if sys.platform == "win32":
690 if sys.platform == "win32":
691 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages')
691 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'Lib', 'site-packages')
692 else:
692 else:
693 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'lib',
693 virtual_env = os.path.join(os.environ['VIRTUAL_ENV'], 'lib',
694 'python%d.%d' % sys.version_info[:2], 'site-packages')
694 'python%d.%d' % sys.version_info[:2], 'site-packages')
695
695
696 import site
696 import site
697 sys.path.insert(0, virtual_env)
697 sys.path.insert(0, virtual_env)
698 site.addsitedir(virtual_env)
698 site.addsitedir(virtual_env)
699
699
700 #-------------------------------------------------------------------------
700 #-------------------------------------------------------------------------
701 # Things related to injections into the sys module
701 # Things related to injections into the sys module
702 #-------------------------------------------------------------------------
702 #-------------------------------------------------------------------------
703
703
704 def save_sys_module_state(self):
704 def save_sys_module_state(self):
705 """Save the state of hooks in the sys module.
705 """Save the state of hooks in the sys module.
706
706
707 This has to be called after self.user_module is created.
707 This has to be called after self.user_module is created.
708 """
708 """
709 self._orig_sys_module_state = {}
709 self._orig_sys_module_state = {}
710 self._orig_sys_module_state['stdin'] = sys.stdin
710 self._orig_sys_module_state['stdin'] = sys.stdin
711 self._orig_sys_module_state['stdout'] = sys.stdout
711 self._orig_sys_module_state['stdout'] = sys.stdout
712 self._orig_sys_module_state['stderr'] = sys.stderr
712 self._orig_sys_module_state['stderr'] = sys.stderr
713 self._orig_sys_module_state['excepthook'] = sys.excepthook
713 self._orig_sys_module_state['excepthook'] = sys.excepthook
714 self._orig_sys_modules_main_name = self.user_module.__name__
714 self._orig_sys_modules_main_name = self.user_module.__name__
715 self._orig_sys_modules_main_mod = sys.modules.get(self.user_module.__name__)
715
716
716 def restore_sys_module_state(self):
717 def restore_sys_module_state(self):
717 """Restore the state of the sys module."""
718 """Restore the state of the sys module."""
718 try:
719 try:
719 for k, v in self._orig_sys_module_state.iteritems():
720 for k, v in self._orig_sys_module_state.iteritems():
720 setattr(sys, k, v)
721 setattr(sys, k, v)
721 except AttributeError:
722 except AttributeError:
722 pass
723 pass
723 # Reset what what done in self.init_sys_modules
724 # Reset what what done in self.init_sys_modules
724 sys.modules[self.user_module.__name__] = self._orig_sys_modules_main_name
725 if self._orig_sys_modules_main_mod is not None:
726 sys.modules[self._orig_sys_modules_main_name] = self._orig_sys_modules_main_mod
725
727
726 #-------------------------------------------------------------------------
728 #-------------------------------------------------------------------------
727 # Things related to hooks
729 # Things related to hooks
728 #-------------------------------------------------------------------------
730 #-------------------------------------------------------------------------
729
731
730 def init_hooks(self):
732 def init_hooks(self):
731 # hooks holds pointers used for user-side customizations
733 # hooks holds pointers used for user-side customizations
732 self.hooks = Struct()
734 self.hooks = Struct()
733
735
734 self.strdispatchers = {}
736 self.strdispatchers = {}
735
737
736 # Set all default hooks, defined in the IPython.hooks module.
738 # Set all default hooks, defined in the IPython.hooks module.
737 hooks = IPython.core.hooks
739 hooks = IPython.core.hooks
738 for hook_name in hooks.__all__:
740 for hook_name in hooks.__all__:
739 # default hooks have priority 100, i.e. low; user hooks should have
741 # default hooks have priority 100, i.e. low; user hooks should have
740 # 0-100 priority
742 # 0-100 priority
741 self.set_hook(hook_name,getattr(hooks,hook_name), 100)
743 self.set_hook(hook_name,getattr(hooks,hook_name), 100)
742
744
743 def set_hook(self,name,hook, priority = 50, str_key = None, re_key = None):
745 def set_hook(self,name,hook, priority = 50, str_key = None, re_key = None):
744 """set_hook(name,hook) -> sets an internal IPython hook.
746 """set_hook(name,hook) -> sets an internal IPython hook.
745
747
746 IPython exposes some of its internal API as user-modifiable hooks. By
748 IPython exposes some of its internal API as user-modifiable hooks. By
747 adding your function to one of these hooks, you can modify IPython's
749 adding your function to one of these hooks, you can modify IPython's
748 behavior to call at runtime your own routines."""
750 behavior to call at runtime your own routines."""
749
751
750 # At some point in the future, this should validate the hook before it
752 # At some point in the future, this should validate the hook before it
751 # accepts it. Probably at least check that the hook takes the number
753 # accepts it. Probably at least check that the hook takes the number
752 # of args it's supposed to.
754 # of args it's supposed to.
753
755
754 f = types.MethodType(hook,self)
756 f = types.MethodType(hook,self)
755
757
756 # check if the hook is for strdispatcher first
758 # check if the hook is for strdispatcher first
757 if str_key is not None:
759 if str_key is not None:
758 sdp = self.strdispatchers.get(name, StrDispatch())
760 sdp = self.strdispatchers.get(name, StrDispatch())
759 sdp.add_s(str_key, f, priority )
761 sdp.add_s(str_key, f, priority )
760 self.strdispatchers[name] = sdp
762 self.strdispatchers[name] = sdp
761 return
763 return
762 if re_key is not None:
764 if re_key is not None:
763 sdp = self.strdispatchers.get(name, StrDispatch())
765 sdp = self.strdispatchers.get(name, StrDispatch())
764 sdp.add_re(re.compile(re_key), f, priority )
766 sdp.add_re(re.compile(re_key), f, priority )
765 self.strdispatchers[name] = sdp
767 self.strdispatchers[name] = sdp
766 return
768 return
767
769
768 dp = getattr(self.hooks, name, None)
770 dp = getattr(self.hooks, name, None)
769 if name not in IPython.core.hooks.__all__:
771 if name not in IPython.core.hooks.__all__:
770 print "Warning! Hook '%s' is not one of %s" % \
772 print "Warning! Hook '%s' is not one of %s" % \
771 (name, IPython.core.hooks.__all__ )
773 (name, IPython.core.hooks.__all__ )
772 if not dp:
774 if not dp:
773 dp = IPython.core.hooks.CommandChainDispatcher()
775 dp = IPython.core.hooks.CommandChainDispatcher()
774
776
775 try:
777 try:
776 dp.add(f,priority)
778 dp.add(f,priority)
777 except AttributeError:
779 except AttributeError:
778 # it was not commandchain, plain old func - replace
780 # it was not commandchain, plain old func - replace
779 dp = f
781 dp = f
780
782
781 setattr(self.hooks,name, dp)
783 setattr(self.hooks,name, dp)
782
784
783 def register_post_execute(self, func):
785 def register_post_execute(self, func):
784 """Register a function for calling after code execution.
786 """Register a function for calling after code execution.
785 """
787 """
786 if not callable(func):
788 if not callable(func):
787 raise ValueError('argument %s must be callable' % func)
789 raise ValueError('argument %s must be callable' % func)
788 self._post_execute[func] = True
790 self._post_execute[func] = True
789
791
790 #-------------------------------------------------------------------------
792 #-------------------------------------------------------------------------
791 # Things related to the "main" module
793 # Things related to the "main" module
792 #-------------------------------------------------------------------------
794 #-------------------------------------------------------------------------
793
795
794 def new_main_mod(self,ns=None):
796 def new_main_mod(self,ns=None):
795 """Return a new 'main' module object for user code execution.
797 """Return a new 'main' module object for user code execution.
796 """
798 """
797 main_mod = self._user_main_module
799 main_mod = self._user_main_module
798 init_fakemod_dict(main_mod,ns)
800 init_fakemod_dict(main_mod,ns)
799 return main_mod
801 return main_mod
800
802
801 def cache_main_mod(self,ns,fname):
803 def cache_main_mod(self,ns,fname):
802 """Cache a main module's namespace.
804 """Cache a main module's namespace.
803
805
804 When scripts are executed via %run, we must keep a reference to the
806 When scripts are executed via %run, we must keep a reference to the
805 namespace of their __main__ module (a FakeModule instance) around so
807 namespace of their __main__ module (a FakeModule instance) around so
806 that Python doesn't clear it, rendering objects defined therein
808 that Python doesn't clear it, rendering objects defined therein
807 useless.
809 useless.
808
810
809 This method keeps said reference in a private dict, keyed by the
811 This method keeps said reference in a private dict, keyed by the
810 absolute path of the module object (which corresponds to the script
812 absolute path of the module object (which corresponds to the script
811 path). This way, for multiple executions of the same script we only
813 path). This way, for multiple executions of the same script we only
812 keep one copy of the namespace (the last one), thus preventing memory
814 keep one copy of the namespace (the last one), thus preventing memory
813 leaks from old references while allowing the objects from the last
815 leaks from old references while allowing the objects from the last
814 execution to be accessible.
816 execution to be accessible.
815
817
816 Note: we can not allow the actual FakeModule instances to be deleted,
818 Note: we can not allow the actual FakeModule instances to be deleted,
817 because of how Python tears down modules (it hard-sets all their
819 because of how Python tears down modules (it hard-sets all their
818 references to None without regard for reference counts). This method
820 references to None without regard for reference counts). This method
819 must therefore make a *copy* of the given namespace, to allow the
821 must therefore make a *copy* of the given namespace, to allow the
820 original module's __dict__ to be cleared and reused.
822 original module's __dict__ to be cleared and reused.
821
823
822
824
823 Parameters
825 Parameters
824 ----------
826 ----------
825 ns : a namespace (a dict, typically)
827 ns : a namespace (a dict, typically)
826
828
827 fname : str
829 fname : str
828 Filename associated with the namespace.
830 Filename associated with the namespace.
829
831
830 Examples
832 Examples
831 --------
833 --------
832
834
833 In [10]: import IPython
835 In [10]: import IPython
834
836
835 In [11]: _ip.cache_main_mod(IPython.__dict__,IPython.__file__)
837 In [11]: _ip.cache_main_mod(IPython.__dict__,IPython.__file__)
836
838
837 In [12]: IPython.__file__ in _ip._main_ns_cache
839 In [12]: IPython.__file__ in _ip._main_ns_cache
838 Out[12]: True
840 Out[12]: True
839 """
841 """
840 self._main_ns_cache[os.path.abspath(fname)] = ns.copy()
842 self._main_ns_cache[os.path.abspath(fname)] = ns.copy()
841
843
842 def clear_main_mod_cache(self):
844 def clear_main_mod_cache(self):
843 """Clear the cache of main modules.
845 """Clear the cache of main modules.
844
846
845 Mainly for use by utilities like %reset.
847 Mainly for use by utilities like %reset.
846
848
847 Examples
849 Examples
848 --------
850 --------
849
851
850 In [15]: import IPython
852 In [15]: import IPython
851
853
852 In [16]: _ip.cache_main_mod(IPython.__dict__,IPython.__file__)
854 In [16]: _ip.cache_main_mod(IPython.__dict__,IPython.__file__)
853
855
854 In [17]: len(_ip._main_ns_cache) > 0
856 In [17]: len(_ip._main_ns_cache) > 0
855 Out[17]: True
857 Out[17]: True
856
858
857 In [18]: _ip.clear_main_mod_cache()
859 In [18]: _ip.clear_main_mod_cache()
858
860
859 In [19]: len(_ip._main_ns_cache) == 0
861 In [19]: len(_ip._main_ns_cache) == 0
860 Out[19]: True
862 Out[19]: True
861 """
863 """
862 self._main_ns_cache.clear()
864 self._main_ns_cache.clear()
863
865
864 #-------------------------------------------------------------------------
866 #-------------------------------------------------------------------------
865 # Things related to debugging
867 # Things related to debugging
866 #-------------------------------------------------------------------------
868 #-------------------------------------------------------------------------
867
869
868 def init_pdb(self):
870 def init_pdb(self):
869 # Set calling of pdb on exceptions
871 # Set calling of pdb on exceptions
870 # self.call_pdb is a property
872 # self.call_pdb is a property
871 self.call_pdb = self.pdb
873 self.call_pdb = self.pdb
872
874
873 def _get_call_pdb(self):
875 def _get_call_pdb(self):
874 return self._call_pdb
876 return self._call_pdb
875
877
876 def _set_call_pdb(self,val):
878 def _set_call_pdb(self,val):
877
879
878 if val not in (0,1,False,True):
880 if val not in (0,1,False,True):
879 raise ValueError,'new call_pdb value must be boolean'
881 raise ValueError,'new call_pdb value must be boolean'
880
882
881 # store value in instance
883 # store value in instance
882 self._call_pdb = val
884 self._call_pdb = val
883
885
884 # notify the actual exception handlers
886 # notify the actual exception handlers
885 self.InteractiveTB.call_pdb = val
887 self.InteractiveTB.call_pdb = val
886
888
887 call_pdb = property(_get_call_pdb,_set_call_pdb,None,
889 call_pdb = property(_get_call_pdb,_set_call_pdb,None,
888 'Control auto-activation of pdb at exceptions')
890 'Control auto-activation of pdb at exceptions')
889
891
890 def debugger(self,force=False):
892 def debugger(self,force=False):
891 """Call the pydb/pdb debugger.
893 """Call the pydb/pdb debugger.
892
894
893 Keywords:
895 Keywords:
894
896
895 - force(False): by default, this routine checks the instance call_pdb
897 - force(False): by default, this routine checks the instance call_pdb
896 flag and does not actually invoke the debugger if the flag is false.
898 flag and does not actually invoke the debugger if the flag is false.
897 The 'force' option forces the debugger to activate even if the flag
899 The 'force' option forces the debugger to activate even if the flag
898 is false.
900 is false.
899 """
901 """
900
902
901 if not (force or self.call_pdb):
903 if not (force or self.call_pdb):
902 return
904 return
903
905
904 if not hasattr(sys,'last_traceback'):
906 if not hasattr(sys,'last_traceback'):
905 error('No traceback has been produced, nothing to debug.')
907 error('No traceback has been produced, nothing to debug.')
906 return
908 return
907
909
908 # use pydb if available
910 # use pydb if available
909 if debugger.has_pydb:
911 if debugger.has_pydb:
910 from pydb import pm
912 from pydb import pm
911 else:
913 else:
912 # fallback to our internal debugger
914 # fallback to our internal debugger
913 pm = lambda : self.InteractiveTB.debugger(force=True)
915 pm = lambda : self.InteractiveTB.debugger(force=True)
914
916
915 with self.readline_no_record:
917 with self.readline_no_record:
916 pm()
918 pm()
917
919
918 #-------------------------------------------------------------------------
920 #-------------------------------------------------------------------------
919 # Things related to IPython's various namespaces
921 # Things related to IPython's various namespaces
920 #-------------------------------------------------------------------------
922 #-------------------------------------------------------------------------
921 default_user_namespaces = True
923 default_user_namespaces = True
922
924
923 def init_create_namespaces(self, user_module=None, user_ns=None):
925 def init_create_namespaces(self, user_module=None, user_ns=None):
924 # Create the namespace where the user will operate. user_ns is
926 # Create the namespace where the user will operate. user_ns is
925 # normally the only one used, and it is passed to the exec calls as
927 # normally the only one used, and it is passed to the exec calls as
926 # the locals argument. But we do carry a user_global_ns namespace
928 # the locals argument. But we do carry a user_global_ns namespace
927 # given as the exec 'globals' argument, This is useful in embedding
929 # given as the exec 'globals' argument, This is useful in embedding
928 # situations where the ipython shell opens in a context where the
930 # situations where the ipython shell opens in a context where the
929 # distinction between locals and globals is meaningful. For
931 # distinction between locals and globals is meaningful. For
930 # non-embedded contexts, it is just the same object as the user_ns dict.
932 # non-embedded contexts, it is just the same object as the user_ns dict.
931
933
932 # FIXME. For some strange reason, __builtins__ is showing up at user
934 # FIXME. For some strange reason, __builtins__ is showing up at user
933 # level as a dict instead of a module. This is a manual fix, but I
935 # level as a dict instead of a module. This is a manual fix, but I
934 # should really track down where the problem is coming from. Alex
936 # should really track down where the problem is coming from. Alex
935 # Schmolck reported this problem first.
937 # Schmolck reported this problem first.
936
938
937 # A useful post by Alex Martelli on this topic:
939 # A useful post by Alex Martelli on this topic:
938 # Re: inconsistent value from __builtins__
940 # Re: inconsistent value from __builtins__
939 # Von: Alex Martelli <aleaxit@yahoo.com>
941 # Von: Alex Martelli <aleaxit@yahoo.com>
940 # Datum: Freitag 01 Oktober 2004 04:45:34 nachmittags/abends
942 # Datum: Freitag 01 Oktober 2004 04:45:34 nachmittags/abends
941 # Gruppen: comp.lang.python
943 # Gruppen: comp.lang.python
942
944
943 # Michael Hohn <hohn@hooknose.lbl.gov> wrote:
945 # Michael Hohn <hohn@hooknose.lbl.gov> wrote:
944 # > >>> print type(builtin_check.get_global_binding('__builtins__'))
946 # > >>> print type(builtin_check.get_global_binding('__builtins__'))
945 # > <type 'dict'>
947 # > <type 'dict'>
946 # > >>> print type(__builtins__)
948 # > >>> print type(__builtins__)
947 # > <type 'module'>
949 # > <type 'module'>
948 # > Is this difference in return value intentional?
950 # > Is this difference in return value intentional?
949
951
950 # Well, it's documented that '__builtins__' can be either a dictionary
952 # Well, it's documented that '__builtins__' can be either a dictionary
951 # or a module, and it's been that way for a long time. Whether it's
953 # or a module, and it's been that way for a long time. Whether it's
952 # intentional (or sensible), I don't know. In any case, the idea is
954 # intentional (or sensible), I don't know. In any case, the idea is
953 # that if you need to access the built-in namespace directly, you
955 # that if you need to access the built-in namespace directly, you
954 # should start with "import __builtin__" (note, no 's') which will
956 # should start with "import __builtin__" (note, no 's') which will
955 # definitely give you a module. Yeah, it's somewhat confusing:-(.
957 # definitely give you a module. Yeah, it's somewhat confusing:-(.
956
958
957 # These routines return a properly built module and dict as needed by
959 # These routines return a properly built module and dict as needed by
958 # the rest of the code, and can also be used by extension writers to
960 # the rest of the code, and can also be used by extension writers to
959 # generate properly initialized namespaces.
961 # generate properly initialized namespaces.
960 if (user_ns is not None) or (user_module is not None):
962 if (user_ns is not None) or (user_module is not None):
961 self.default_user_namespaces = False
963 self.default_user_namespaces = False
962 self.user_module, self.user_ns = self.prepare_user_module(user_module, user_ns)
964 self.user_module, self.user_ns = self.prepare_user_module(user_module, user_ns)
963
965
964 # A record of hidden variables we have added to the user namespace, so
966 # A record of hidden variables we have added to the user namespace, so
965 # we can list later only variables defined in actual interactive use.
967 # we can list later only variables defined in actual interactive use.
966 self.user_ns_hidden = set()
968 self.user_ns_hidden = set()
967
969
968 # Now that FakeModule produces a real module, we've run into a nasty
970 # Now that FakeModule produces a real module, we've run into a nasty
969 # problem: after script execution (via %run), the module where the user
971 # problem: after script execution (via %run), the module where the user
970 # code ran is deleted. Now that this object is a true module (needed
972 # code ran is deleted. Now that this object is a true module (needed
971 # so docetst and other tools work correctly), the Python module
973 # so docetst and other tools work correctly), the Python module
972 # teardown mechanism runs over it, and sets to None every variable
974 # teardown mechanism runs over it, and sets to None every variable
973 # present in that module. Top-level references to objects from the
975 # present in that module. Top-level references to objects from the
974 # script survive, because the user_ns is updated with them. However,
976 # script survive, because the user_ns is updated with them. However,
975 # calling functions defined in the script that use other things from
977 # calling functions defined in the script that use other things from
976 # the script will fail, because the function's closure had references
978 # the script will fail, because the function's closure had references
977 # to the original objects, which are now all None. So we must protect
979 # to the original objects, which are now all None. So we must protect
978 # these modules from deletion by keeping a cache.
980 # these modules from deletion by keeping a cache.
979 #
981 #
980 # To avoid keeping stale modules around (we only need the one from the
982 # To avoid keeping stale modules around (we only need the one from the
981 # last run), we use a dict keyed with the full path to the script, so
983 # last run), we use a dict keyed with the full path to the script, so
982 # only the last version of the module is held in the cache. Note,
984 # only the last version of the module is held in the cache. Note,
983 # however, that we must cache the module *namespace contents* (their
985 # however, that we must cache the module *namespace contents* (their
984 # __dict__). Because if we try to cache the actual modules, old ones
986 # __dict__). Because if we try to cache the actual modules, old ones
985 # (uncached) could be destroyed while still holding references (such as
987 # (uncached) could be destroyed while still holding references (such as
986 # those held by GUI objects that tend to be long-lived)>
988 # those held by GUI objects that tend to be long-lived)>
987 #
989 #
988 # The %reset command will flush this cache. See the cache_main_mod()
990 # The %reset command will flush this cache. See the cache_main_mod()
989 # and clear_main_mod_cache() methods for details on use.
991 # and clear_main_mod_cache() methods for details on use.
990
992
991 # This is the cache used for 'main' namespaces
993 # This is the cache used for 'main' namespaces
992 self._main_ns_cache = {}
994 self._main_ns_cache = {}
993 # And this is the single instance of FakeModule whose __dict__ we keep
995 # And this is the single instance of FakeModule whose __dict__ we keep
994 # copying and clearing for reuse on each %run
996 # copying and clearing for reuse on each %run
995 self._user_main_module = FakeModule()
997 self._user_main_module = FakeModule()
996
998
997 # A table holding all the namespaces IPython deals with, so that
999 # A table holding all the namespaces IPython deals with, so that
998 # introspection facilities can search easily.
1000 # introspection facilities can search easily.
999 self.ns_table = {'user_global':self.user_module.__dict__,
1001 self.ns_table = {'user_global':self.user_module.__dict__,
1000 'user_local':self.user_ns,
1002 'user_local':self.user_ns,
1001 'builtin':builtin_mod.__dict__
1003 'builtin':builtin_mod.__dict__
1002 }
1004 }
1003
1005
1004 @property
1006 @property
1005 def user_global_ns(self):
1007 def user_global_ns(self):
1006 return self.user_module.__dict__
1008 return self.user_module.__dict__
1007
1009
1008 def prepare_user_module(self, user_module=None, user_ns=None):
1010 def prepare_user_module(self, user_module=None, user_ns=None):
1009 """Prepare the module and namespace in which user code will be run.
1011 """Prepare the module and namespace in which user code will be run.
1010
1012
1011 When IPython is started normally, both parameters are None: a new module
1013 When IPython is started normally, both parameters are None: a new module
1012 is created automatically, and its __dict__ used as the namespace.
1014 is created automatically, and its __dict__ used as the namespace.
1013
1015
1014 If only user_module is provided, its __dict__ is used as the namespace.
1016 If only user_module is provided, its __dict__ is used as the namespace.
1015 If only user_ns is provided, a dummy module is created, and user_ns
1017 If only user_ns is provided, a dummy module is created, and user_ns
1016 becomes the global namespace. If both are provided (as they may be
1018 becomes the global namespace. If both are provided (as they may be
1017 when embedding), user_ns is the local namespace, and user_module
1019 when embedding), user_ns is the local namespace, and user_module
1018 provides the global namespace.
1020 provides the global namespace.
1019
1021
1020 Parameters
1022 Parameters
1021 ----------
1023 ----------
1022 user_module : module, optional
1024 user_module : module, optional
1023 The current user module in which IPython is being run. If None,
1025 The current user module in which IPython is being run. If None,
1024 a clean module will be created.
1026 a clean module will be created.
1025 user_ns : dict, optional
1027 user_ns : dict, optional
1026 A namespace in which to run interactive commands.
1028 A namespace in which to run interactive commands.
1027
1029
1028 Returns
1030 Returns
1029 -------
1031 -------
1030 A tuple of user_module and user_ns, each properly initialised.
1032 A tuple of user_module and user_ns, each properly initialised.
1031 """
1033 """
1032 if user_module is None and user_ns is not None:
1034 if user_module is None and user_ns is not None:
1033 user_ns.setdefault("__name__", "__main__")
1035 user_ns.setdefault("__name__", "__main__")
1034 class DummyMod(object):
1036 class DummyMod(object):
1035 "A dummy module used for IPython's interactive namespace."
1037 "A dummy module used for IPython's interactive namespace."
1036 pass
1038 pass
1037 user_module = DummyMod()
1039 user_module = DummyMod()
1038 user_module.__dict__ = user_ns
1040 user_module.__dict__ = user_ns
1039
1041
1040 if user_module is None:
1042 if user_module is None:
1041 user_module = types.ModuleType("__main__",
1043 user_module = types.ModuleType("__main__",
1042 doc="Automatically created module for IPython interactive environment")
1044 doc="Automatically created module for IPython interactive environment")
1043
1045
1044 # We must ensure that __builtin__ (without the final 's') is always
1046 # We must ensure that __builtin__ (without the final 's') is always
1045 # available and pointing to the __builtin__ *module*. For more details:
1047 # available and pointing to the __builtin__ *module*. For more details:
1046 # http://mail.python.org/pipermail/python-dev/2001-April/014068.html
1048 # http://mail.python.org/pipermail/python-dev/2001-April/014068.html
1047 user_module.__dict__.setdefault('__builtin__', builtin_mod)
1049 user_module.__dict__.setdefault('__builtin__', builtin_mod)
1048 user_module.__dict__.setdefault('__builtins__', builtin_mod)
1050 user_module.__dict__.setdefault('__builtins__', builtin_mod)
1049
1051
1050 if user_ns is None:
1052 if user_ns is None:
1051 user_ns = user_module.__dict__
1053 user_ns = user_module.__dict__
1052
1054
1053 return user_module, user_ns
1055 return user_module, user_ns
1054
1056
1055 def init_sys_modules(self):
1057 def init_sys_modules(self):
1056 # We need to insert into sys.modules something that looks like a
1058 # We need to insert into sys.modules something that looks like a
1057 # module but which accesses the IPython namespace, for shelve and
1059 # module but which accesses the IPython namespace, for shelve and
1058 # pickle to work interactively. Normally they rely on getting
1060 # pickle to work interactively. Normally they rely on getting
1059 # everything out of __main__, but for embedding purposes each IPython
1061 # everything out of __main__, but for embedding purposes each IPython
1060 # instance has its own private namespace, so we can't go shoving
1062 # instance has its own private namespace, so we can't go shoving
1061 # everything into __main__.
1063 # everything into __main__.
1062
1064
1063 # note, however, that we should only do this for non-embedded
1065 # note, however, that we should only do this for non-embedded
1064 # ipythons, which really mimic the __main__.__dict__ with their own
1066 # ipythons, which really mimic the __main__.__dict__ with their own
1065 # namespace. Embedded instances, on the other hand, should not do
1067 # namespace. Embedded instances, on the other hand, should not do
1066 # this because they need to manage the user local/global namespaces
1068 # this because they need to manage the user local/global namespaces
1067 # only, but they live within a 'normal' __main__ (meaning, they
1069 # only, but they live within a 'normal' __main__ (meaning, they
1068 # shouldn't overtake the execution environment of the script they're
1070 # shouldn't overtake the execution environment of the script they're
1069 # embedded in).
1071 # embedded in).
1070
1072
1071 # This is overridden in the InteractiveShellEmbed subclass to a no-op.
1073 # This is overridden in the InteractiveShellEmbed subclass to a no-op.
1072 main_name = self.user_module.__name__
1074 main_name = self.user_module.__name__
1073 sys.modules[main_name] = self.user_module
1075 sys.modules[main_name] = self.user_module
1074
1076
1075 def init_user_ns(self):
1077 def init_user_ns(self):
1076 """Initialize all user-visible namespaces to their minimum defaults.
1078 """Initialize all user-visible namespaces to their minimum defaults.
1077
1079
1078 Certain history lists are also initialized here, as they effectively
1080 Certain history lists are also initialized here, as they effectively
1079 act as user namespaces.
1081 act as user namespaces.
1080
1082
1081 Notes
1083 Notes
1082 -----
1084 -----
1083 All data structures here are only filled in, they are NOT reset by this
1085 All data structures here are only filled in, they are NOT reset by this
1084 method. If they were not empty before, data will simply be added to
1086 method. If they were not empty before, data will simply be added to
1085 therm.
1087 therm.
1086 """
1088 """
1087 # This function works in two parts: first we put a few things in
1089 # This function works in two parts: first we put a few things in
1088 # user_ns, and we sync that contents into user_ns_hidden so that these
1090 # user_ns, and we sync that contents into user_ns_hidden so that these
1089 # initial variables aren't shown by %who. After the sync, we add the
1091 # initial variables aren't shown by %who. After the sync, we add the
1090 # rest of what we *do* want the user to see with %who even on a new
1092 # rest of what we *do* want the user to see with %who even on a new
1091 # session (probably nothing, so theye really only see their own stuff)
1093 # session (probably nothing, so theye really only see their own stuff)
1092
1094
1093 # The user dict must *always* have a __builtin__ reference to the
1095 # The user dict must *always* have a __builtin__ reference to the
1094 # Python standard __builtin__ namespace, which must be imported.
1096 # Python standard __builtin__ namespace, which must be imported.
1095 # This is so that certain operations in prompt evaluation can be
1097 # This is so that certain operations in prompt evaluation can be
1096 # reliably executed with builtins. Note that we can NOT use
1098 # reliably executed with builtins. Note that we can NOT use
1097 # __builtins__ (note the 's'), because that can either be a dict or a
1099 # __builtins__ (note the 's'), because that can either be a dict or a
1098 # module, and can even mutate at runtime, depending on the context
1100 # module, and can even mutate at runtime, depending on the context
1099 # (Python makes no guarantees on it). In contrast, __builtin__ is
1101 # (Python makes no guarantees on it). In contrast, __builtin__ is
1100 # always a module object, though it must be explicitly imported.
1102 # always a module object, though it must be explicitly imported.
1101
1103
1102 # For more details:
1104 # For more details:
1103 # http://mail.python.org/pipermail/python-dev/2001-April/014068.html
1105 # http://mail.python.org/pipermail/python-dev/2001-April/014068.html
1104 ns = dict()
1106 ns = dict()
1105
1107
1106 # Put 'help' in the user namespace
1108 # Put 'help' in the user namespace
1107 try:
1109 try:
1108 from site import _Helper
1110 from site import _Helper
1109 ns['help'] = _Helper()
1111 ns['help'] = _Helper()
1110 except ImportError:
1112 except ImportError:
1111 warn('help() not available - check site.py')
1113 warn('help() not available - check site.py')
1112
1114
1113 # make global variables for user access to the histories
1115 # make global variables for user access to the histories
1114 ns['_ih'] = self.history_manager.input_hist_parsed
1116 ns['_ih'] = self.history_manager.input_hist_parsed
1115 ns['_oh'] = self.history_manager.output_hist
1117 ns['_oh'] = self.history_manager.output_hist
1116 ns['_dh'] = self.history_manager.dir_hist
1118 ns['_dh'] = self.history_manager.dir_hist
1117
1119
1118 ns['_sh'] = shadowns
1120 ns['_sh'] = shadowns
1119
1121
1120 # user aliases to input and output histories. These shouldn't show up
1122 # user aliases to input and output histories. These shouldn't show up
1121 # in %who, as they can have very large reprs.
1123 # in %who, as they can have very large reprs.
1122 ns['In'] = self.history_manager.input_hist_parsed
1124 ns['In'] = self.history_manager.input_hist_parsed
1123 ns['Out'] = self.history_manager.output_hist
1125 ns['Out'] = self.history_manager.output_hist
1124
1126
1125 # Store myself as the public api!!!
1127 # Store myself as the public api!!!
1126 ns['get_ipython'] = self.get_ipython
1128 ns['get_ipython'] = self.get_ipython
1127
1129
1128 ns['exit'] = self.exiter
1130 ns['exit'] = self.exiter
1129 ns['quit'] = self.exiter
1131 ns['quit'] = self.exiter
1130
1132
1131 # Sync what we've added so far to user_ns_hidden so these aren't seen
1133 # Sync what we've added so far to user_ns_hidden so these aren't seen
1132 # by %who
1134 # by %who
1133 self.user_ns_hidden.update(ns)
1135 self.user_ns_hidden.update(ns)
1134
1136
1135 # Anything put into ns now would show up in %who. Think twice before
1137 # Anything put into ns now would show up in %who. Think twice before
1136 # putting anything here, as we really want %who to show the user their
1138 # putting anything here, as we really want %who to show the user their
1137 # stuff, not our variables.
1139 # stuff, not our variables.
1138
1140
1139 # Finally, update the real user's namespace
1141 # Finally, update the real user's namespace
1140 self.user_ns.update(ns)
1142 self.user_ns.update(ns)
1141
1143
1142 @property
1144 @property
1143 def all_ns_refs(self):
1145 def all_ns_refs(self):
1144 """Get a list of references to all the namespace dictionaries in which
1146 """Get a list of references to all the namespace dictionaries in which
1145 IPython might store a user-created object.
1147 IPython might store a user-created object.
1146
1148
1147 Note that this does not include the displayhook, which also caches
1149 Note that this does not include the displayhook, which also caches
1148 objects from the output."""
1150 objects from the output."""
1149 return [self.user_ns, self.user_global_ns,
1151 return [self.user_ns, self.user_global_ns,
1150 self._user_main_module.__dict__] + self._main_ns_cache.values()
1152 self._user_main_module.__dict__] + self._main_ns_cache.values()
1151
1153
    def reset(self, new_session=True):
        """Clear all internal namespaces, and attempt to release references to
        user objects.

        If new_session is True, a new history session will be opened.

        Parameters
        ----------
        new_session : bool, optional
            If True (default), the history manager opens a new session and
            the execution counter restarts at 1.
        """
        # Clear histories
        self.history_manager.reset(new_session)
        # Reset counter used to index all histories
        if new_session:
            self.execution_count = 1

        # Flush cached output items
        if self.displayhook.do_full_cache:
            self.displayhook.flush()

        # The main execution namespaces must be cleared very carefully,
        # skipping the deletion of the builtin-related keys, because doing so
        # would cause errors in many object's __del__ methods.
        if self.user_ns is not self.user_global_ns:
            self.user_ns.clear()
        ns = self.user_global_ns
        # Snapshot the keys into a set first so we never mutate the dict
        # while iterating over it, then drop the builtin-related keys.
        drop_keys = set(ns.keys())
        drop_keys.discard('__builtin__')
        drop_keys.discard('__builtins__')
        drop_keys.discard('__name__')
        for k in drop_keys:
            del ns[k]

        # Nothing is hidden from %who anymore; init_user_ns below repopulates.
        self.user_ns_hidden.clear()

        # Restore the user namespaces to minimal usability
        self.init_user_ns()

        # Restore the default and user aliases
        self.alias_manager.clear_aliases()
        self.alias_manager.init_aliases()

        # Flush the private list of module references kept for script
        # execution protection
        self.clear_main_mod_cache()

        # Clear out the namespace from the last %run
        self.new_main_mod()
1196
1198
    def del_var(self, varname, by_name=False):
        """Delete a variable from the various namespaces, so that, as
        far as possible, we're not keeping any hidden references to it.

        Parameters
        ----------
        varname : str
            The name of the variable to delete.
        by_name : bool
            If True, delete variables with the given name in each
            namespace. If False (default), find the variable in the user
            namespace, and delete references to it.

        Raises
        ------
        ValueError
            If varname names the builtins (deleting those would break the
            interpreter).
        NameError
            If ``by_name`` is False and varname is not bound in user_ns.
        """
        if varname in ('__builtin__', '__builtins__'):
            raise ValueError("Refusing to delete %s" % varname)

        ns_refs = self.all_ns_refs

        if by_name:                    # Delete by name
            for ns in ns_refs:
                try:
                    del ns[varname]
                except KeyError:
                    pass
        else:                          # Delete by object
            try:
                obj = self.user_ns[varname]
            except KeyError:
                raise NameError("name '%s' is not defined" % varname)
            # Also check in output history
            ns_refs.append(self.history_manager.output_hist)
            for ns in ns_refs:
                # Collect every name bound to the same object first, so the
                # dict is not mutated while it is being iterated.
                to_delete = [n for n, o in ns.iteritems() if o is obj]
                for name in to_delete:
                    del ns[name]

            # displayhook keeps extra references, but not in a dictionary
            for name in ('_', '__', '___'):
                if getattr(self.displayhook, name) is obj:
                    setattr(self.displayhook, name, None)
1237
1239
1238 def reset_selective(self, regex=None):
1240 def reset_selective(self, regex=None):
1239 """Clear selective variables from internal namespaces based on a
1241 """Clear selective variables from internal namespaces based on a
1240 specified regular expression.
1242 specified regular expression.
1241
1243
1242 Parameters
1244 Parameters
1243 ----------
1245 ----------
1244 regex : string or compiled pattern, optional
1246 regex : string or compiled pattern, optional
1245 A regular expression pattern that will be used in searching
1247 A regular expression pattern that will be used in searching
1246 variable names in the users namespaces.
1248 variable names in the users namespaces.
1247 """
1249 """
1248 if regex is not None:
1250 if regex is not None:
1249 try:
1251 try:
1250 m = re.compile(regex)
1252 m = re.compile(regex)
1251 except TypeError:
1253 except TypeError:
1252 raise TypeError('regex must be a string or compiled pattern')
1254 raise TypeError('regex must be a string or compiled pattern')
1253 # Search for keys in each namespace that match the given regex
1255 # Search for keys in each namespace that match the given regex
1254 # If a match is found, delete the key/value pair.
1256 # If a match is found, delete the key/value pair.
1255 for ns in self.all_ns_refs:
1257 for ns in self.all_ns_refs:
1256 for var in ns:
1258 for var in ns:
1257 if m.search(var):
1259 if m.search(var):
1258 del ns[var]
1260 del ns[var]
1259
1261
    def push(self, variables, interactive=True):
        """Inject a group of variables into the IPython user namespace.

        Parameters
        ----------
        variables : dict, str or list/tuple of str
            The variables to inject into the user's namespace. If a dict, a
            simple update is done. If a str, the string is assumed to have
            variable names separated by spaces. A list/tuple of str can also
            be used to give the variable names. If just the variable names
            are given (list/tuple/str) then the variable values are looked
            up in the caller's frame.
        interactive : bool
            If True (default), the variables will be listed with the ``who``
            magic.

        Raises
        ------
        ValueError
            If variables is none of dict/str/list/tuple.
        """
        vdict = None

        # We need a dict of name/value pairs to do namespace updates.
        if isinstance(variables, dict):
            vdict = variables
        elif isinstance(variables, (basestring, list, tuple)):
            if isinstance(variables, basestring):
                vlist = variables.split()
            else:
                vlist = variables
            vdict = {}
            # sys._getframe(1) is the immediate caller's frame; the names
            # are evaluated against its globals/locals.  NOTE: this makes
            # the lookup sensitive to call depth — do not wrap this code in
            # another function.
            cf = sys._getframe(1)
            for name in vlist:
                try:
                    vdict[name] = eval(name, cf.f_globals, cf.f_locals)
                except:
                    # Best-effort: report and continue with the other names.
                    print ('Could not get variable %s from %s' %
                           (name,cf.f_code.co_name))
        else:
            raise ValueError('variables must be a dict/str/list/tuple')

        # Propagate variables to user namespace
        self.user_ns.update(vdict)

        # And configure interactive visibility
        user_ns_hidden = self.user_ns_hidden
        if interactive:
            # Make the names visible to %who by removing them from the
            # hidden set.
            user_ns_hidden.difference_update(vdict)
        else:
            user_ns_hidden.update(vdict)
1306
1308
1307 def drop_by_id(self, variables):
1309 def drop_by_id(self, variables):
1308 """Remove a dict of variables from the user namespace, if they are the
1310 """Remove a dict of variables from the user namespace, if they are the
1309 same as the values in the dictionary.
1311 same as the values in the dictionary.
1310
1312
1311 This is intended for use by extensions: variables that they've added can
1313 This is intended for use by extensions: variables that they've added can
1312 be taken back out if they are unloaded, without removing any that the
1314 be taken back out if they are unloaded, without removing any that the
1313 user has overwritten.
1315 user has overwritten.
1314
1316
1315 Parameters
1317 Parameters
1316 ----------
1318 ----------
1317 variables : dict
1319 variables : dict
1318 A dictionary mapping object names (as strings) to the objects.
1320 A dictionary mapping object names (as strings) to the objects.
1319 """
1321 """
1320 for name, obj in variables.iteritems():
1322 for name, obj in variables.iteritems():
1321 if name in self.user_ns and self.user_ns[name] is obj:
1323 if name in self.user_ns and self.user_ns[name] is obj:
1322 del self.user_ns[name]
1324 del self.user_ns[name]
1323 self.user_ns_hidden.discard(name)
1325 self.user_ns_hidden.discard(name)
1324
1326
1325 #-------------------------------------------------------------------------
1327 #-------------------------------------------------------------------------
1326 # Things related to object introspection
1328 # Things related to object introspection
1327 #-------------------------------------------------------------------------
1329 #-------------------------------------------------------------------------
1328
1330
    def _ofind(self, oname, namespaces=None):
        """Find an object in the available namespaces.

        self._ofind(oname) -> dict with keys:
        found, obj, namespace, ismagic, isalias, parent

        Has special code to detect magic functions.

        Parameters
        ----------
        oname : str
            Name to look up; may be dotted (``a.b.c``) and may carry the
            magic escape prefix.
        namespaces : list of (label, dict) pairs, optional
            Namespaces to search, in priority order.  Defaults to user,
            global, builtin and alias namespaces.
        """
        oname = oname.strip()
        #print '1- oname: <%r>' % oname  # dbg
        # Bail out early on anything that is not a (dotted) identifier,
        # ignoring a leading magic-escape character.
        if not py3compat.isidentifier(oname.lstrip(ESC_MAGIC), dotted=True):
            return dict(found=False)

        alias_ns = None
        if namespaces is None:
            # Namespaces to search in:
            # Put them in a list. The order is important so that we
            # find things in the same order that Python finds them.
            namespaces = [ ('Interactive', self.user_ns),
                           ('Interactive (global)', self.user_global_ns),
                           ('Python builtin', builtin_mod.__dict__),
                           ('Alias', self.alias_manager.alias_table),
                           ]
            alias_ns = self.alias_manager.alias_table

        # initialize results to 'null' (ds is initialized but not used below)
        found = False; obj = None; ospace = None; ds = None;
        ismagic = False; isalias = False; parent = None

        # We need to special-case 'print', which as of python2.6 registers as a
        # function but should only be treated as one if print_function was
        # loaded with a future import. In this case, just bail.
        if (oname == 'print' and not py3compat.PY3 and not \
            (self.compile.compiler_flags & __future__.CO_FUTURE_PRINT_FUNCTION)):
            return {'found':found, 'obj':obj, 'namespace':ospace,
                    'ismagic':ismagic, 'isalias':isalias, 'parent':parent}

        # Look for the given name by splitting it in parts. If the head is
        # found, then we look for all the remaining parts as members, and only
        # declare success if we can find them all.
        oname_parts = oname.split('.')
        oname_head, oname_rest = oname_parts[0],oname_parts[1:]
        for nsname,ns in namespaces:
            try:
                obj = ns[oname_head]
            except KeyError:
                continue
            else:
                #print 'oname_rest:', oname_rest  # dbg
                for part in oname_rest:
                    try:
                        # Remember the owner of the attribute we are about
                        # to fetch; callers use it for property lookup.
                        parent = obj
                        obj = getattr(obj,part)
                    except:
                        # Blanket except b/c some badly implemented objects
                        # allow __getattr__ to raise exceptions other than
                        # AttributeError, which then crashes IPython.
                        break
                else:
                    # If we finish the for loop (no break), we got all members
                    found = True
                    ospace = nsname
                    if ns == alias_ns:
                        isalias = True
                    break  # namespace loop

        # Try to see if it's magic
        if not found:
            if oname.startswith(ESC_MAGIC):
                oname = oname[1:]
            obj = getattr(self,'magic_'+oname,None)
            if obj is not None:
                found = True
                ospace = 'IPython internal'
                ismagic = True

        # Last try: special-case some literals like '', [], {}, etc:
        if not found and oname_head in ["''",'""','[]','{}','()']:
            # eval is safe here: oname_head is restricted to the literal
            # strings listed above.
            obj = eval(oname_head)
            found = True
            ospace = 'Interactive'

        return {'found':found, 'obj':obj, 'namespace':ospace,
                'ismagic':ismagic, 'isalias':isalias, 'parent':parent}
1412
1414
    def _ofind_property(self, oname, info):
        """Second part of object finding, to look for property details.

        If *oname* resolves to something whose class defines it as a
        property, re-run the lookup via ``<root>.__class__.<attr>`` so that
        the property object itself (with its docstring) is returned instead
        of the computed value.

        Parameters
        ----------
        oname : str
            The (possibly dotted) name that was looked up.
        info : Struct
            The result of a previous :meth:`_ofind` call, wrapped in a
            Struct.

        Returns
        -------
        Struct-like mapping — either refreshed info for the property, or
        the unmodified input.
        """
        if info.found:
            # Get the docstring of the class property if it exists.
            path = oname.split('.')
            root = '.'.join(path[:-1])
            if info.parent is not None:
                try:
                    target = getattr(info.parent, '__class__')
                    # The object belongs to a class instance.
                    try:
                        target = getattr(target, path[-1])
                        # The class defines the object.
                        if isinstance(target, property):
                            oname = root + '.__class__.' + path[-1]
                            info = Struct(self._ofind(oname))
                    # Missing attributes simply mean "not a property";
                    # fall through and return the original info.
                    except AttributeError: pass
                except AttributeError: pass

        # We return either the new info or the unmodified input if the object
        # hadn't been found
        return info
1435
1437
1436 def _object_find(self, oname, namespaces=None):
1438 def _object_find(self, oname, namespaces=None):
1437 """Find an object and return a struct with info about it."""
1439 """Find an object and return a struct with info about it."""
1438 inf = Struct(self._ofind(oname, namespaces))
1440 inf = Struct(self._ofind(oname, namespaces))
1439 return Struct(self._ofind_property(oname, inf))
1441 return Struct(self._ofind_property(oname, inf))
1440
1442
1441 def _inspect(self, meth, oname, namespaces=None, **kw):
1443 def _inspect(self, meth, oname, namespaces=None, **kw):
1442 """Generic interface to the inspector system.
1444 """Generic interface to the inspector system.
1443
1445
1444 This function is meant to be called by pdef, pdoc & friends."""
1446 This function is meant to be called by pdef, pdoc & friends."""
1445 info = self._object_find(oname)
1447 info = self._object_find(oname)
1446 if info.found:
1448 if info.found:
1447 pmethod = getattr(self.inspector, meth)
1449 pmethod = getattr(self.inspector, meth)
1448 formatter = format_screen if info.ismagic else None
1450 formatter = format_screen if info.ismagic else None
1449 if meth == 'pdoc':
1451 if meth == 'pdoc':
1450 pmethod(info.obj, oname, formatter)
1452 pmethod(info.obj, oname, formatter)
1451 elif meth == 'pinfo':
1453 elif meth == 'pinfo':
1452 pmethod(info.obj, oname, formatter, info, **kw)
1454 pmethod(info.obj, oname, formatter, info, **kw)
1453 else:
1455 else:
1454 pmethod(info.obj, oname)
1456 pmethod(info.obj, oname)
1455 else:
1457 else:
1456 print 'Object `%s` not found.' % oname
1458 print 'Object `%s` not found.' % oname
1457 return 'not found' # so callers can take other action
1459 return 'not found' # so callers can take other action
1458
1460
1459 def object_inspect(self, oname, detail_level=0):
1461 def object_inspect(self, oname, detail_level=0):
1460 with self.builtin_trap:
1462 with self.builtin_trap:
1461 info = self._object_find(oname)
1463 info = self._object_find(oname)
1462 if info.found:
1464 if info.found:
1463 return self.inspector.info(info.obj, oname, info=info,
1465 return self.inspector.info(info.obj, oname, info=info,
1464 detail_level=detail_level
1466 detail_level=detail_level
1465 )
1467 )
1466 else:
1468 else:
1467 return oinspect.object_info(name=oname, found=False)
1469 return oinspect.object_info(name=oname, found=False)
1468
1470
1469 #-------------------------------------------------------------------------
1471 #-------------------------------------------------------------------------
1470 # Things related to history management
1472 # Things related to history management
1471 #-------------------------------------------------------------------------
1473 #-------------------------------------------------------------------------
1472
1474
    def init_history(self):
        """Sets up the command history, and starts regular autosaves."""
        self.history_manager = HistoryManager(shell=self, config=self.config)
        # Register the manager with the shell's configurable objects so it
        # participates in the configuration system.
        self.configurables.append(self.history_manager)
1477
1479
1478 #-------------------------------------------------------------------------
1480 #-------------------------------------------------------------------------
1479 # Things related to exception handling and tracebacks (not debugging)
1481 # Things related to exception handling and tracebacks (not debugging)
1480 #-------------------------------------------------------------------------
1482 #-------------------------------------------------------------------------
1481
1483
    def init_traceback_handlers(self, custom_exceptions):
        """Create the syntax, interactive and custom traceback handlers.

        Parameters
        ----------
        custom_exceptions : sequence
            Star-unpacked into :meth:`set_custom_exc`, so it is expected to
            be an ``(exc_tuple, handler)`` pair.
        """
        # Syntax error handler.
        self.SyntaxTB = ultratb.SyntaxTB(color_scheme='NoColor')

        # The interactive one is initialized with an offset, meaning we always
        # want to remove the topmost item in the traceback, which is our own
        # internal code. Valid modes: ['Plain','Context','Verbose']
        self.InteractiveTB = ultratb.AutoFormattedTB(mode = 'Plain',
                                                     color_scheme='NoColor',
                                                     tb_offset = 1,
                                   check_cache=self.compile.check_cache)

        # The instance will store a pointer to the system-wide exception hook,
        # so that runtime code (such as magics) can access it. This is because
        # during the read-eval loop, it may get temporarily overwritten.
        self.sys_excepthook = sys.excepthook

        # and add any custom exception handlers the user may have specified
        self.set_custom_exc(*custom_exceptions)

        # Set the exception mode
        self.InteractiveTB.set_mode(mode=self.xmode)
1504
1506
    def set_custom_exc(self, exc_tuple, handler):
        """set_custom_exc(exc_tuple,handler)

        Set a custom exception handler, which will be called if any of the
        exceptions in exc_tuple occur in the mainloop (specifically, in the
        run_code() method).

        Parameters
        ----------

        exc_tuple : tuple of exception classes
            A *tuple* of exception classes, for which to call the defined
            handler. It is very important that you use a tuple, and NOT A
            LIST here, because of the way Python's except statement works. If
            you only want to trap a single exception, use a singleton tuple::

                exc_tuple == (MyCustomException,)

        handler : callable
            handler must have the following signature::

                def my_handler(self, etype, value, tb, tb_offset=None):
                    ...
                    return structured_traceback

            Your handler must return a structured traceback (a list of strings),
            or None.

            This will be made into an instance method (via types.MethodType)
            of IPython itself, and it will be called if any of the exceptions
            listed in the exc_tuple are caught. If the handler is None, an
            internal basic one is used, which just prints basic info.

        To protect IPython from crashes, if your handler ever raises an
        exception or returns an invalid result, it will be immediately
        disabled.

        WARNING: by putting in your own exception handler into IPython's main
        execution loop, you run a very good chance of nasty crashes. This
        facility should only be used if you really know what you are doing."""

        assert type(exc_tuple)==type(()) , \
               "The custom exceptions must be given AS A TUPLE."

        def dummy_handler(self,etype,value,tb,tb_offset=None):
            # Fallback used when handler is None: print basic exception
            # info with no formatting.
            print '*** Simple custom exception handler ***'
            print 'Exception type :',etype
            print 'Exception value:',value
            print 'Traceback :',tb
            #print 'Source code :','\n'.join(self.buffer)

        def validate_stb(stb):
            """validate structured traceback return type

            return type of CustomTB *should* be a list of strings, but allow
            single strings or None, which are harmless.

            This function will *always* return a list of strings,
            and will raise a TypeError if stb is inappropriate.
            """
            msg = "CustomTB must return list of strings, not %r" % stb
            if stb is None:
                return []
            elif isinstance(stb, basestring):
                return [stb]
            elif not isinstance(stb, list):
                raise TypeError(msg)
            # it's a list
            for line in stb:
                # check every element
                if not isinstance(line, basestring):
                    raise TypeError(msg)
            return stb

        if handler is None:
            wrapped = dummy_handler
        else:
            def wrapped(self,etype,value,tb,tb_offset=None):
                """wrap CustomTB handler, to protect IPython from user code

                This makes it harder (but not impossible) for custom exception
                handlers to crash IPython.
                """
                try:
                    stb = handler(self,etype,value,tb,tb_offset=tb_offset)
                    return validate_stb(stb)
                except:
                    # clear custom handler immediately, so a faulty handler
                    # cannot be invoked again while we report its failure
                    self.set_custom_exc((), None)
                    print >> io.stderr, "Custom TB Handler failed, unregistering"
                    # show the exception in handler first
                    stb = self.InteractiveTB.structured_traceback(*sys.exc_info())
                    print >> io.stdout, self.InteractiveTB.stb2text(stb)
                    print >> io.stdout, "The original exception:"
                    # then fall back to the default formatter for the
                    # exception the user code actually raised
                    stb = self.InteractiveTB.structured_traceback(
                                            (etype,value,tb), tb_offset=tb_offset
                    )
                    return stb

        # Bind the wrapper as an instance method and remember which
        # exception types it should be used for.
        self.CustomTB = types.MethodType(wrapped,self)
        self.custom_exceptions = exc_tuple
1606
1608
1607 def excepthook(self, etype, value, tb):
1609 def excepthook(self, etype, value, tb):
1608 """One more defense for GUI apps that call sys.excepthook.
1610 """One more defense for GUI apps that call sys.excepthook.
1609
1611
1610 GUI frameworks like wxPython trap exceptions and call
1612 GUI frameworks like wxPython trap exceptions and call
1611 sys.excepthook themselves. I guess this is a feature that
1613 sys.excepthook themselves. I guess this is a feature that
1612 enables them to keep running after exceptions that would
1614 enables them to keep running after exceptions that would
1613 otherwise kill their mainloop. This is a bother for IPython
1615 otherwise kill their mainloop. This is a bother for IPython
1614 which excepts to catch all of the program exceptions with a try:
1616 which excepts to catch all of the program exceptions with a try:
1615 except: statement.
1617 except: statement.
1616
1618
1617 Normally, IPython sets sys.excepthook to a CrashHandler instance, so if
1619 Normally, IPython sets sys.excepthook to a CrashHandler instance, so if
1618 any app directly invokes sys.excepthook, it will look to the user like
1620 any app directly invokes sys.excepthook, it will look to the user like
1619 IPython crashed. In order to work around this, we can disable the
1621 IPython crashed. In order to work around this, we can disable the
1620 CrashHandler and replace it with this excepthook instead, which prints a
1622 CrashHandler and replace it with this excepthook instead, which prints a
1621 regular traceback using our InteractiveTB. In this fashion, apps which
1623 regular traceback using our InteractiveTB. In this fashion, apps which
1622 call sys.excepthook will generate a regular-looking exception from
1624 call sys.excepthook will generate a regular-looking exception from
1623 IPython, and the CrashHandler will only be triggered by real IPython
1625 IPython, and the CrashHandler will only be triggered by real IPython
1624 crashes.
1626 crashes.
1625
1627
1626 This hook should be used sparingly, only in places which are not likely
1628 This hook should be used sparingly, only in places which are not likely
1627 to be true IPython errors.
1629 to be true IPython errors.
1628 """
1630 """
1629 self.showtraceback((etype,value,tb),tb_offset=0)
1631 self.showtraceback((etype,value,tb),tb_offset=0)
1630
1632
1631 def _get_exc_info(self, exc_tuple=None):
1633 def _get_exc_info(self, exc_tuple=None):
1632 """get exc_info from a given tuple, sys.exc_info() or sys.last_type etc.
1634 """get exc_info from a given tuple, sys.exc_info() or sys.last_type etc.
1633
1635
1634 Ensures sys.last_type,value,traceback hold the exc_info we found,
1636 Ensures sys.last_type,value,traceback hold the exc_info we found,
1635 from whichever source.
1637 from whichever source.
1636
1638
1637 raises ValueError if none of these contain any information
1639 raises ValueError if none of these contain any information
1638 """
1640 """
1639 if exc_tuple is None:
1641 if exc_tuple is None:
1640 etype, value, tb = sys.exc_info()
1642 etype, value, tb = sys.exc_info()
1641 else:
1643 else:
1642 etype, value, tb = exc_tuple
1644 etype, value, tb = exc_tuple
1643
1645
1644 if etype is None:
1646 if etype is None:
1645 if hasattr(sys, 'last_type'):
1647 if hasattr(sys, 'last_type'):
1646 etype, value, tb = sys.last_type, sys.last_value, \
1648 etype, value, tb = sys.last_type, sys.last_value, \
1647 sys.last_traceback
1649 sys.last_traceback
1648
1650
1649 if etype is None:
1651 if etype is None:
1650 raise ValueError("No exception to find")
1652 raise ValueError("No exception to find")
1651
1653
1652 # Now store the exception info in sys.last_type etc.
1654 # Now store the exception info in sys.last_type etc.
1653 # WARNING: these variables are somewhat deprecated and not
1655 # WARNING: these variables are somewhat deprecated and not
1654 # necessarily safe to use in a threaded environment, but tools
1656 # necessarily safe to use in a threaded environment, but tools
1655 # like pdb depend on their existence, so let's set them. If we
1657 # like pdb depend on their existence, so let's set them. If we
1656 # find problems in the field, we'll need to revisit their use.
1658 # find problems in the field, we'll need to revisit their use.
1657 sys.last_type = etype
1659 sys.last_type = etype
1658 sys.last_value = value
1660 sys.last_value = value
1659 sys.last_traceback = tb
1661 sys.last_traceback = tb
1660
1662
1661 return etype, value, tb
1663 return etype, value, tb
1662
1664
1663
1665
1664 def showtraceback(self,exc_tuple = None,filename=None,tb_offset=None,
1666 def showtraceback(self,exc_tuple = None,filename=None,tb_offset=None,
1665 exception_only=False):
1667 exception_only=False):
1666 """Display the exception that just occurred.
1668 """Display the exception that just occurred.
1667
1669
1668 If nothing is known about the exception, this is the method which
1670 If nothing is known about the exception, this is the method which
1669 should be used throughout the code for presenting user tracebacks,
1671 should be used throughout the code for presenting user tracebacks,
1670 rather than directly invoking the InteractiveTB object.
1672 rather than directly invoking the InteractiveTB object.
1671
1673
1672 A specific showsyntaxerror() also exists, but this method can take
1674 A specific showsyntaxerror() also exists, but this method can take
1673 care of calling it if needed, so unless you are explicitly catching a
1675 care of calling it if needed, so unless you are explicitly catching a
1674 SyntaxError exception, don't try to analyze the stack manually and
1676 SyntaxError exception, don't try to analyze the stack manually and
1675 simply call this method."""
1677 simply call this method."""
1676
1678
1677 try:
1679 try:
1678 try:
1680 try:
1679 etype, value, tb = self._get_exc_info(exc_tuple)
1681 etype, value, tb = self._get_exc_info(exc_tuple)
1680 except ValueError:
1682 except ValueError:
1681 self.write_err('No traceback available to show.\n')
1683 self.write_err('No traceback available to show.\n')
1682 return
1684 return
1683
1685
1684 if etype is SyntaxError:
1686 if etype is SyntaxError:
1685 # Though this won't be called by syntax errors in the input
1687 # Though this won't be called by syntax errors in the input
1686 # line, there may be SyntaxError cases with imported code.
1688 # line, there may be SyntaxError cases with imported code.
1687 self.showsyntaxerror(filename)
1689 self.showsyntaxerror(filename)
1688 elif etype is UsageError:
1690 elif etype is UsageError:
1689 self.write_err("UsageError: %s" % value)
1691 self.write_err("UsageError: %s" % value)
1690 else:
1692 else:
1691 if etype in self.custom_exceptions:
1693 if etype in self.custom_exceptions:
1692 stb = self.CustomTB(etype, value, tb, tb_offset)
1694 stb = self.CustomTB(etype, value, tb, tb_offset)
1693 else:
1695 else:
1694 if exception_only:
1696 if exception_only:
1695 stb = ['An exception has occurred, use %tb to see '
1697 stb = ['An exception has occurred, use %tb to see '
1696 'the full traceback.\n']
1698 'the full traceback.\n']
1697 stb.extend(self.InteractiveTB.get_exception_only(etype,
1699 stb.extend(self.InteractiveTB.get_exception_only(etype,
1698 value))
1700 value))
1699 else:
1701 else:
1700 stb = self.InteractiveTB.structured_traceback(etype,
1702 stb = self.InteractiveTB.structured_traceback(etype,
1701 value, tb, tb_offset=tb_offset)
1703 value, tb, tb_offset=tb_offset)
1702
1704
1703 self._showtraceback(etype, value, stb)
1705 self._showtraceback(etype, value, stb)
1704 if self.call_pdb:
1706 if self.call_pdb:
1705 # drop into debugger
1707 # drop into debugger
1706 self.debugger(force=True)
1708 self.debugger(force=True)
1707 return
1709 return
1708
1710
1709 # Actually show the traceback
1711 # Actually show the traceback
1710 self._showtraceback(etype, value, stb)
1712 self._showtraceback(etype, value, stb)
1711
1713
1712 except KeyboardInterrupt:
1714 except KeyboardInterrupt:
1713 self.write_err("\nKeyboardInterrupt\n")
1715 self.write_err("\nKeyboardInterrupt\n")
1714
1716
1715 def _showtraceback(self, etype, evalue, stb):
1717 def _showtraceback(self, etype, evalue, stb):
1716 """Actually show a traceback.
1718 """Actually show a traceback.
1717
1719
1718 Subclasses may override this method to put the traceback on a different
1720 Subclasses may override this method to put the traceback on a different
1719 place, like a side channel.
1721 place, like a side channel.
1720 """
1722 """
1721 print >> io.stdout, self.InteractiveTB.stb2text(stb)
1723 print >> io.stdout, self.InteractiveTB.stb2text(stb)
1722
1724
1723 def showsyntaxerror(self, filename=None):
1725 def showsyntaxerror(self, filename=None):
1724 """Display the syntax error that just occurred.
1726 """Display the syntax error that just occurred.
1725
1727
1726 This doesn't display a stack trace because there isn't one.
1728 This doesn't display a stack trace because there isn't one.
1727
1729
1728 If a filename is given, it is stuffed in the exception instead
1730 If a filename is given, it is stuffed in the exception instead
1729 of what was there before (because Python's parser always uses
1731 of what was there before (because Python's parser always uses
1730 "<string>" when reading from a string).
1732 "<string>" when reading from a string).
1731 """
1733 """
1732 etype, value, last_traceback = self._get_exc_info()
1734 etype, value, last_traceback = self._get_exc_info()
1733
1735
1734 if filename and etype is SyntaxError:
1736 if filename and etype is SyntaxError:
1735 try:
1737 try:
1736 value.filename = filename
1738 value.filename = filename
1737 except:
1739 except:
1738 # Not the format we expect; leave it alone
1740 # Not the format we expect; leave it alone
1739 pass
1741 pass
1740
1742
1741 stb = self.SyntaxTB.structured_traceback(etype, value, [])
1743 stb = self.SyntaxTB.structured_traceback(etype, value, [])
1742 self._showtraceback(etype, value, stb)
1744 self._showtraceback(etype, value, stb)
1743
1745
1744 # This is overridden in TerminalInteractiveShell to show a message about
1746 # This is overridden in TerminalInteractiveShell to show a message about
1745 # the %paste magic.
1747 # the %paste magic.
1746 def showindentationerror(self):
1748 def showindentationerror(self):
1747 """Called by run_cell when there's an IndentationError in code entered
1749 """Called by run_cell when there's an IndentationError in code entered
1748 at the prompt.
1750 at the prompt.
1749
1751
1750 This is overridden in TerminalInteractiveShell to show a message about
1752 This is overridden in TerminalInteractiveShell to show a message about
1751 the %paste magic."""
1753 the %paste magic."""
1752 self.showsyntaxerror()
1754 self.showsyntaxerror()
1753
1755
1754 #-------------------------------------------------------------------------
1756 #-------------------------------------------------------------------------
1755 # Things related to readline
1757 # Things related to readline
1756 #-------------------------------------------------------------------------
1758 #-------------------------------------------------------------------------
1757
1759
1758 def init_readline(self):
1760 def init_readline(self):
1759 """Command history completion/saving/reloading."""
1761 """Command history completion/saving/reloading."""
1760
1762
1761 if self.readline_use:
1763 if self.readline_use:
1762 import IPython.utils.rlineimpl as readline
1764 import IPython.utils.rlineimpl as readline
1763
1765
1764 self.rl_next_input = None
1766 self.rl_next_input = None
1765 self.rl_do_indent = False
1767 self.rl_do_indent = False
1766
1768
1767 if not self.readline_use or not readline.have_readline:
1769 if not self.readline_use or not readline.have_readline:
1768 self.has_readline = False
1770 self.has_readline = False
1769 self.readline = None
1771 self.readline = None
1770 # Set a number of methods that depend on readline to be no-op
1772 # Set a number of methods that depend on readline to be no-op
1771 self.readline_no_record = no_op_context
1773 self.readline_no_record = no_op_context
1772 self.set_readline_completer = no_op
1774 self.set_readline_completer = no_op
1773 self.set_custom_completer = no_op
1775 self.set_custom_completer = no_op
1774 self.set_completer_frame = no_op
1776 self.set_completer_frame = no_op
1775 if self.readline_use:
1777 if self.readline_use:
1776 warn('Readline services not available or not loaded.')
1778 warn('Readline services not available or not loaded.')
1777 else:
1779 else:
1778 self.has_readline = True
1780 self.has_readline = True
1779 self.readline = readline
1781 self.readline = readline
1780 sys.modules['readline'] = readline
1782 sys.modules['readline'] = readline
1781
1783
1782 # Platform-specific configuration
1784 # Platform-specific configuration
1783 if os.name == 'nt':
1785 if os.name == 'nt':
1784 # FIXME - check with Frederick to see if we can harmonize
1786 # FIXME - check with Frederick to see if we can harmonize
1785 # naming conventions with pyreadline to avoid this
1787 # naming conventions with pyreadline to avoid this
1786 # platform-dependent check
1788 # platform-dependent check
1787 self.readline_startup_hook = readline.set_pre_input_hook
1789 self.readline_startup_hook = readline.set_pre_input_hook
1788 else:
1790 else:
1789 self.readline_startup_hook = readline.set_startup_hook
1791 self.readline_startup_hook = readline.set_startup_hook
1790
1792
1791 # Load user's initrc file (readline config)
1793 # Load user's initrc file (readline config)
1792 # Or if libedit is used, load editrc.
1794 # Or if libedit is used, load editrc.
1793 inputrc_name = os.environ.get('INPUTRC')
1795 inputrc_name = os.environ.get('INPUTRC')
1794 if inputrc_name is None:
1796 if inputrc_name is None:
1795 inputrc_name = '.inputrc'
1797 inputrc_name = '.inputrc'
1796 if readline.uses_libedit:
1798 if readline.uses_libedit:
1797 inputrc_name = '.editrc'
1799 inputrc_name = '.editrc'
1798 inputrc_name = os.path.join(self.home_dir, inputrc_name)
1800 inputrc_name = os.path.join(self.home_dir, inputrc_name)
1799 if os.path.isfile(inputrc_name):
1801 if os.path.isfile(inputrc_name):
1800 try:
1802 try:
1801 readline.read_init_file(inputrc_name)
1803 readline.read_init_file(inputrc_name)
1802 except:
1804 except:
1803 warn('Problems reading readline initialization file <%s>'
1805 warn('Problems reading readline initialization file <%s>'
1804 % inputrc_name)
1806 % inputrc_name)
1805
1807
1806 # Configure readline according to user's prefs
1808 # Configure readline according to user's prefs
1807 # This is only done if GNU readline is being used. If libedit
1809 # This is only done if GNU readline is being used. If libedit
1808 # is being used (as on Leopard) the readline config is
1810 # is being used (as on Leopard) the readline config is
1809 # not run as the syntax for libedit is different.
1811 # not run as the syntax for libedit is different.
1810 if not readline.uses_libedit:
1812 if not readline.uses_libedit:
1811 for rlcommand in self.readline_parse_and_bind:
1813 for rlcommand in self.readline_parse_and_bind:
1812 #print "loading rl:",rlcommand # dbg
1814 #print "loading rl:",rlcommand # dbg
1813 readline.parse_and_bind(rlcommand)
1815 readline.parse_and_bind(rlcommand)
1814
1816
1815 # Remove some chars from the delimiters list. If we encounter
1817 # Remove some chars from the delimiters list. If we encounter
1816 # unicode chars, discard them.
1818 # unicode chars, discard them.
1817 delims = readline.get_completer_delims()
1819 delims = readline.get_completer_delims()
1818 if not py3compat.PY3:
1820 if not py3compat.PY3:
1819 delims = delims.encode("ascii", "ignore")
1821 delims = delims.encode("ascii", "ignore")
1820 for d in self.readline_remove_delims:
1822 for d in self.readline_remove_delims:
1821 delims = delims.replace(d, "")
1823 delims = delims.replace(d, "")
1822 delims = delims.replace(ESC_MAGIC, '')
1824 delims = delims.replace(ESC_MAGIC, '')
1823 readline.set_completer_delims(delims)
1825 readline.set_completer_delims(delims)
1824 # otherwise we end up with a monster history after a while:
1826 # otherwise we end up with a monster history after a while:
1825 readline.set_history_length(self.history_length)
1827 readline.set_history_length(self.history_length)
1826
1828
1827 self.refill_readline_hist()
1829 self.refill_readline_hist()
1828 self.readline_no_record = ReadlineNoRecord(self)
1830 self.readline_no_record = ReadlineNoRecord(self)
1829
1831
1830 # Configure auto-indent for all platforms
1832 # Configure auto-indent for all platforms
1831 self.set_autoindent(self.autoindent)
1833 self.set_autoindent(self.autoindent)
1832
1834
1833 def refill_readline_hist(self):
1835 def refill_readline_hist(self):
1834 # Load the last 1000 lines from history
1836 # Load the last 1000 lines from history
1835 self.readline.clear_history()
1837 self.readline.clear_history()
1836 stdin_encoding = sys.stdin.encoding or "utf-8"
1838 stdin_encoding = sys.stdin.encoding or "utf-8"
1837 last_cell = u""
1839 last_cell = u""
1838 for _, _, cell in self.history_manager.get_tail(1000,
1840 for _, _, cell in self.history_manager.get_tail(1000,
1839 include_latest=True):
1841 include_latest=True):
1840 # Ignore blank lines and consecutive duplicates
1842 # Ignore blank lines and consecutive duplicates
1841 cell = cell.rstrip()
1843 cell = cell.rstrip()
1842 if cell and (cell != last_cell):
1844 if cell and (cell != last_cell):
1843 if self.multiline_history:
1845 if self.multiline_history:
1844 self.readline.add_history(py3compat.unicode_to_str(cell,
1846 self.readline.add_history(py3compat.unicode_to_str(cell,
1845 stdin_encoding))
1847 stdin_encoding))
1846 else:
1848 else:
1847 for line in cell.splitlines():
1849 for line in cell.splitlines():
1848 self.readline.add_history(py3compat.unicode_to_str(line,
1850 self.readline.add_history(py3compat.unicode_to_str(line,
1849 stdin_encoding))
1851 stdin_encoding))
1850 last_cell = cell
1852 last_cell = cell
1851
1853
1852 def set_next_input(self, s):
1854 def set_next_input(self, s):
1853 """ Sets the 'default' input string for the next command line.
1855 """ Sets the 'default' input string for the next command line.
1854
1856
1855 Requires readline.
1857 Requires readline.
1856
1858
1857 Example:
1859 Example:
1858
1860
1859 [D:\ipython]|1> _ip.set_next_input("Hello Word")
1861 [D:\ipython]|1> _ip.set_next_input("Hello Word")
1860 [D:\ipython]|2> Hello Word_ # cursor is here
1862 [D:\ipython]|2> Hello Word_ # cursor is here
1861 """
1863 """
1862 self.rl_next_input = py3compat.cast_bytes_py2(s)
1864 self.rl_next_input = py3compat.cast_bytes_py2(s)
1863
1865
1864 # Maybe move this to the terminal subclass?
1866 # Maybe move this to the terminal subclass?
1865 def pre_readline(self):
1867 def pre_readline(self):
1866 """readline hook to be used at the start of each line.
1868 """readline hook to be used at the start of each line.
1867
1869
1868 Currently it handles auto-indent only."""
1870 Currently it handles auto-indent only."""
1869
1871
1870 if self.rl_do_indent:
1872 if self.rl_do_indent:
1871 self.readline.insert_text(self._indent_current_str())
1873 self.readline.insert_text(self._indent_current_str())
1872 if self.rl_next_input is not None:
1874 if self.rl_next_input is not None:
1873 self.readline.insert_text(self.rl_next_input)
1875 self.readline.insert_text(self.rl_next_input)
1874 self.rl_next_input = None
1876 self.rl_next_input = None
1875
1877
1876 def _indent_current_str(self):
1878 def _indent_current_str(self):
1877 """return the current level of indentation as a string"""
1879 """return the current level of indentation as a string"""
1878 return self.input_splitter.indent_spaces * ' '
1880 return self.input_splitter.indent_spaces * ' '
1879
1881
1880 #-------------------------------------------------------------------------
1882 #-------------------------------------------------------------------------
1881 # Things related to text completion
1883 # Things related to text completion
1882 #-------------------------------------------------------------------------
1884 #-------------------------------------------------------------------------
1883
1885
1884 def init_completer(self):
1886 def init_completer(self):
1885 """Initialize the completion machinery.
1887 """Initialize the completion machinery.
1886
1888
1887 This creates completion machinery that can be used by client code,
1889 This creates completion machinery that can be used by client code,
1888 either interactively in-process (typically triggered by the readline
1890 either interactively in-process (typically triggered by the readline
1889 library), programatically (such as in test suites) or out-of-prcess
1891 library), programatically (such as in test suites) or out-of-prcess
1890 (typically over the network by remote frontends).
1892 (typically over the network by remote frontends).
1891 """
1893 """
1892 from IPython.core.completer import IPCompleter
1894 from IPython.core.completer import IPCompleter
1893 from IPython.core.completerlib import (module_completer,
1895 from IPython.core.completerlib import (module_completer,
1894 magic_run_completer, cd_completer, reset_completer)
1896 magic_run_completer, cd_completer, reset_completer)
1895
1897
1896 self.Completer = IPCompleter(shell=self,
1898 self.Completer = IPCompleter(shell=self,
1897 namespace=self.user_ns,
1899 namespace=self.user_ns,
1898 global_namespace=self.user_global_ns,
1900 global_namespace=self.user_global_ns,
1899 alias_table=self.alias_manager.alias_table,
1901 alias_table=self.alias_manager.alias_table,
1900 use_readline=self.has_readline,
1902 use_readline=self.has_readline,
1901 config=self.config,
1903 config=self.config,
1902 )
1904 )
1903 self.configurables.append(self.Completer)
1905 self.configurables.append(self.Completer)
1904
1906
1905 # Add custom completers to the basic ones built into IPCompleter
1907 # Add custom completers to the basic ones built into IPCompleter
1906 sdisp = self.strdispatchers.get('complete_command', StrDispatch())
1908 sdisp = self.strdispatchers.get('complete_command', StrDispatch())
1907 self.strdispatchers['complete_command'] = sdisp
1909 self.strdispatchers['complete_command'] = sdisp
1908 self.Completer.custom_completers = sdisp
1910 self.Completer.custom_completers = sdisp
1909
1911
1910 self.set_hook('complete_command', module_completer, str_key = 'import')
1912 self.set_hook('complete_command', module_completer, str_key = 'import')
1911 self.set_hook('complete_command', module_completer, str_key = 'from')
1913 self.set_hook('complete_command', module_completer, str_key = 'from')
1912 self.set_hook('complete_command', magic_run_completer, str_key = '%run')
1914 self.set_hook('complete_command', magic_run_completer, str_key = '%run')
1913 self.set_hook('complete_command', cd_completer, str_key = '%cd')
1915 self.set_hook('complete_command', cd_completer, str_key = '%cd')
1914 self.set_hook('complete_command', reset_completer, str_key = '%reset')
1916 self.set_hook('complete_command', reset_completer, str_key = '%reset')
1915
1917
1916 # Only configure readline if we truly are using readline. IPython can
1918 # Only configure readline if we truly are using readline. IPython can
1917 # do tab-completion over the network, in GUIs, etc, where readline
1919 # do tab-completion over the network, in GUIs, etc, where readline
1918 # itself may be absent
1920 # itself may be absent
1919 if self.has_readline:
1921 if self.has_readline:
1920 self.set_readline_completer()
1922 self.set_readline_completer()
1921
1923
1922 def complete(self, text, line=None, cursor_pos=None):
1924 def complete(self, text, line=None, cursor_pos=None):
1923 """Return the completed text and a list of completions.
1925 """Return the completed text and a list of completions.
1924
1926
1925 Parameters
1927 Parameters
1926 ----------
1928 ----------
1927
1929
1928 text : string
1930 text : string
1929 A string of text to be completed on. It can be given as empty and
1931 A string of text to be completed on. It can be given as empty and
1930 instead a line/position pair are given. In this case, the
1932 instead a line/position pair are given. In this case, the
1931 completer itself will split the line like readline does.
1933 completer itself will split the line like readline does.
1932
1934
1933 line : string, optional
1935 line : string, optional
1934 The complete line that text is part of.
1936 The complete line that text is part of.
1935
1937
1936 cursor_pos : int, optional
1938 cursor_pos : int, optional
1937 The position of the cursor on the input line.
1939 The position of the cursor on the input line.
1938
1940
1939 Returns
1941 Returns
1940 -------
1942 -------
1941 text : string
1943 text : string
1942 The actual text that was completed.
1944 The actual text that was completed.
1943
1945
1944 matches : list
1946 matches : list
1945 A sorted list with all possible completions.
1947 A sorted list with all possible completions.
1946
1948
1947 The optional arguments allow the completion to take more context into
1949 The optional arguments allow the completion to take more context into
1948 account, and are part of the low-level completion API.
1950 account, and are part of the low-level completion API.
1949
1951
1950 This is a wrapper around the completion mechanism, similar to what
1952 This is a wrapper around the completion mechanism, similar to what
1951 readline does at the command line when the TAB key is hit. By
1953 readline does at the command line when the TAB key is hit. By
1952 exposing it as a method, it can be used by other non-readline
1954 exposing it as a method, it can be used by other non-readline
1953 environments (such as GUIs) for text completion.
1955 environments (such as GUIs) for text completion.
1954
1956
1955 Simple usage example:
1957 Simple usage example:
1956
1958
1957 In [1]: x = 'hello'
1959 In [1]: x = 'hello'
1958
1960
1959 In [2]: _ip.complete('x.l')
1961 In [2]: _ip.complete('x.l')
1960 Out[2]: ('x.l', ['x.ljust', 'x.lower', 'x.lstrip'])
1962 Out[2]: ('x.l', ['x.ljust', 'x.lower', 'x.lstrip'])
1961 """
1963 """
1962
1964
1963 # Inject names into __builtin__ so we can complete on the added names.
1965 # Inject names into __builtin__ so we can complete on the added names.
1964 with self.builtin_trap:
1966 with self.builtin_trap:
1965 return self.Completer.complete(text, line, cursor_pos)
1967 return self.Completer.complete(text, line, cursor_pos)
1966
1968
1967 def set_custom_completer(self, completer, pos=0):
1969 def set_custom_completer(self, completer, pos=0):
1968 """Adds a new custom completer function.
1970 """Adds a new custom completer function.
1969
1971
1970 The position argument (defaults to 0) is the index in the completers
1972 The position argument (defaults to 0) is the index in the completers
1971 list where you want the completer to be inserted."""
1973 list where you want the completer to be inserted."""
1972
1974
1973 newcomp = types.MethodType(completer,self.Completer)
1975 newcomp = types.MethodType(completer,self.Completer)
1974 self.Completer.matchers.insert(pos,newcomp)
1976 self.Completer.matchers.insert(pos,newcomp)
1975
1977
1976 def set_readline_completer(self):
1978 def set_readline_completer(self):
1977 """Reset readline's completer to be our own."""
1979 """Reset readline's completer to be our own."""
1978 self.readline.set_completer(self.Completer.rlcomplete)
1980 self.readline.set_completer(self.Completer.rlcomplete)
1979
1981
1980 def set_completer_frame(self, frame=None):
1982 def set_completer_frame(self, frame=None):
1981 """Set the frame of the completer."""
1983 """Set the frame of the completer."""
1982 if frame:
1984 if frame:
1983 self.Completer.namespace = frame.f_locals
1985 self.Completer.namespace = frame.f_locals
1984 self.Completer.global_namespace = frame.f_globals
1986 self.Completer.global_namespace = frame.f_globals
1985 else:
1987 else:
1986 self.Completer.namespace = self.user_ns
1988 self.Completer.namespace = self.user_ns
1987 self.Completer.global_namespace = self.user_global_ns
1989 self.Completer.global_namespace = self.user_global_ns
1988
1990
1989 #-------------------------------------------------------------------------
1991 #-------------------------------------------------------------------------
1990 # Things related to magics
1992 # Things related to magics
1991 #-------------------------------------------------------------------------
1993 #-------------------------------------------------------------------------
1992
1994
1993 def init_magics(self):
1995 def init_magics(self):
1994 # FIXME: Move the color initialization to the DisplayHook, which
1996 # FIXME: Move the color initialization to the DisplayHook, which
1995 # should be split into a prompt manager and displayhook. We probably
1997 # should be split into a prompt manager and displayhook. We probably
1996 # even need a centralize colors management object.
1998 # even need a centralize colors management object.
1997 self.magic_colors(self.colors)
1999 self.magic_colors(self.colors)
1998 # History was moved to a separate module
2000 # History was moved to a separate module
1999 from IPython.core import history
2001 from IPython.core import history
2000 history.init_ipython(self)
2002 history.init_ipython(self)
2001
2003
2002 def magic(self, arg_s, next_input=None):
2004 def magic(self, arg_s, next_input=None):
2003 """Call a magic function by name.
2005 """Call a magic function by name.
2004
2006
2005 Input: a string containing the name of the magic function to call and
2007 Input: a string containing the name of the magic function to call and
2006 any additional arguments to be passed to the magic.
2008 any additional arguments to be passed to the magic.
2007
2009
2008 magic('name -opt foo bar') is equivalent to typing at the ipython
2010 magic('name -opt foo bar') is equivalent to typing at the ipython
2009 prompt:
2011 prompt:
2010
2012
2011 In[1]: %name -opt foo bar
2013 In[1]: %name -opt foo bar
2012
2014
2013 To call a magic without arguments, simply use magic('name').
2015 To call a magic without arguments, simply use magic('name').
2014
2016
2015 This provides a proper Python function to call IPython's magics in any
2017 This provides a proper Python function to call IPython's magics in any
2016 valid Python code you can type at the interpreter, including loops and
2018 valid Python code you can type at the interpreter, including loops and
2017 compound statements.
2019 compound statements.
2018 """
2020 """
2019 # Allow setting the next input - this is used if the user does `a=abs?`.
2021 # Allow setting the next input - this is used if the user does `a=abs?`.
2020 # We do this first so that magic functions can override it.
2022 # We do this first so that magic functions can override it.
2021 if next_input:
2023 if next_input:
2022 self.set_next_input(next_input)
2024 self.set_next_input(next_input)
2023
2025
2024 magic_name, _, magic_args = arg_s.partition(' ')
2026 magic_name, _, magic_args = arg_s.partition(' ')
2025 magic_name = magic_name.lstrip(prefilter.ESC_MAGIC)
2027 magic_name = magic_name.lstrip(prefilter.ESC_MAGIC)
2026
2028
2027 fn = getattr(self,'magic_'+magic_name,None)
2029 fn = getattr(self,'magic_'+magic_name,None)
2028 if fn is None:
2030 if fn is None:
2029 error("Magic function `%s` not found." % magic_name)
2031 error("Magic function `%s` not found." % magic_name)
2030 else:
2032 else:
2031 magic_args = self.var_expand(magic_args,1)
2033 magic_args = self.var_expand(magic_args,1)
2032 # Grab local namespace if we need it:
2034 # Grab local namespace if we need it:
2033 if getattr(fn, "needs_local_scope", False):
2035 if getattr(fn, "needs_local_scope", False):
2034 self._magic_locals = sys._getframe(1).f_locals
2036 self._magic_locals = sys._getframe(1).f_locals
2035 with self.builtin_trap:
2037 with self.builtin_trap:
2036 result = fn(magic_args)
2038 result = fn(magic_args)
2037 # Ensure we're not keeping object references around:
2039 # Ensure we're not keeping object references around:
2038 self._magic_locals = {}
2040 self._magic_locals = {}
2039 return result
2041 return result
2040
2042
2041 def define_magic(self, magicname, func):
2043 def define_magic(self, magicname, func):
2042 """Expose own function as magic function for ipython
2044 """Expose own function as magic function for ipython
2043
2045
2044 Example::
2046 Example::
2045
2047
2046 def foo_impl(self,parameter_s=''):
2048 def foo_impl(self,parameter_s=''):
2047 'My very own magic!. (Use docstrings, IPython reads them).'
2049 'My very own magic!. (Use docstrings, IPython reads them).'
2048 print 'Magic function. Passed parameter is between < >:'
2050 print 'Magic function. Passed parameter is between < >:'
2049 print '<%s>' % parameter_s
2051 print '<%s>' % parameter_s
2050 print 'The self object is:', self
2052 print 'The self object is:', self
2051
2053
2052 ip.define_magic('foo',foo_impl)
2054 ip.define_magic('foo',foo_impl)
2053 """
2055 """
2054 im = types.MethodType(func,self)
2056 im = types.MethodType(func,self)
2055 old = getattr(self, "magic_" + magicname, None)
2057 old = getattr(self, "magic_" + magicname, None)
2056 setattr(self, "magic_" + magicname, im)
2058 setattr(self, "magic_" + magicname, im)
2057 return old
2059 return old
2058
2060
2059 #-------------------------------------------------------------------------
2061 #-------------------------------------------------------------------------
2060 # Things related to macros
2062 # Things related to macros
2061 #-------------------------------------------------------------------------
2063 #-------------------------------------------------------------------------
2062
2064
2063 def define_macro(self, name, themacro):
2065 def define_macro(self, name, themacro):
2064 """Define a new macro
2066 """Define a new macro
2065
2067
2066 Parameters
2068 Parameters
2067 ----------
2069 ----------
2068 name : str
2070 name : str
2069 The name of the macro.
2071 The name of the macro.
2070 themacro : str or Macro
2072 themacro : str or Macro
2071 The action to do upon invoking the macro. If a string, a new
2073 The action to do upon invoking the macro. If a string, a new
2072 Macro object is created by passing the string to it.
2074 Macro object is created by passing the string to it.
2073 """
2075 """
2074
2076
2075 from IPython.core import macro
2077 from IPython.core import macro
2076
2078
2077 if isinstance(themacro, basestring):
2079 if isinstance(themacro, basestring):
2078 themacro = macro.Macro(themacro)
2080 themacro = macro.Macro(themacro)
2079 if not isinstance(themacro, macro.Macro):
2081 if not isinstance(themacro, macro.Macro):
2080 raise ValueError('A macro must be a string or a Macro instance.')
2082 raise ValueError('A macro must be a string or a Macro instance.')
2081 self.user_ns[name] = themacro
2083 self.user_ns[name] = themacro
2082
2084
2083 #-------------------------------------------------------------------------
2085 #-------------------------------------------------------------------------
2084 # Things related to the running of system commands
2086 # Things related to the running of system commands
2085 #-------------------------------------------------------------------------
2087 #-------------------------------------------------------------------------
2086
2088
2087 def system_piped(self, cmd):
2089 def system_piped(self, cmd):
2088 """Call the given cmd in a subprocess, piping stdout/err
2090 """Call the given cmd in a subprocess, piping stdout/err
2089
2091
2090 Parameters
2092 Parameters
2091 ----------
2093 ----------
2092 cmd : str
2094 cmd : str
2093 Command to execute (can not end in '&', as background processes are
2095 Command to execute (can not end in '&', as background processes are
2094 not supported. Should not be a command that expects input
2096 not supported. Should not be a command that expects input
2095 other than simple text.
2097 other than simple text.
2096 """
2098 """
2097 if cmd.rstrip().endswith('&'):
2099 if cmd.rstrip().endswith('&'):
2098 # this is *far* from a rigorous test
2100 # this is *far* from a rigorous test
2099 # We do not support backgrounding processes because we either use
2101 # We do not support backgrounding processes because we either use
2100 # pexpect or pipes to read from. Users can always just call
2102 # pexpect or pipes to read from. Users can always just call
2101 # os.system() or use ip.system=ip.system_raw
2103 # os.system() or use ip.system=ip.system_raw
2102 # if they really want a background process.
2104 # if they really want a background process.
2103 raise OSError("Background processes not supported.")
2105 raise OSError("Background processes not supported.")
2104
2106
2105 # we explicitly do NOT return the subprocess status code, because
2107 # we explicitly do NOT return the subprocess status code, because
2106 # a non-None value would trigger :func:`sys.displayhook` calls.
2108 # a non-None value would trigger :func:`sys.displayhook` calls.
2107 # Instead, we store the exit_code in user_ns.
2109 # Instead, we store the exit_code in user_ns.
2108 self.user_ns['_exit_code'] = system(self.var_expand(cmd, depth=2))
2110 self.user_ns['_exit_code'] = system(self.var_expand(cmd, depth=2))
2109
2111
2110 def system_raw(self, cmd):
2112 def system_raw(self, cmd):
2111 """Call the given cmd in a subprocess using os.system
2113 """Call the given cmd in a subprocess using os.system
2112
2114
2113 Parameters
2115 Parameters
2114 ----------
2116 ----------
2115 cmd : str
2117 cmd : str
2116 Command to execute.
2118 Command to execute.
2117 """
2119 """
2118 cmd = self.var_expand(cmd, depth=2)
2120 cmd = self.var_expand(cmd, depth=2)
2119 # protect os.system from UNC paths on Windows, which it can't handle:
2121 # protect os.system from UNC paths on Windows, which it can't handle:
2120 if sys.platform == 'win32':
2122 if sys.platform == 'win32':
2121 from IPython.utils._process_win32 import AvoidUNCPath
2123 from IPython.utils._process_win32 import AvoidUNCPath
2122 with AvoidUNCPath() as path:
2124 with AvoidUNCPath() as path:
2123 if path is not None:
2125 if path is not None:
2124 cmd = '"pushd %s &&"%s' % (path, cmd)
2126 cmd = '"pushd %s &&"%s' % (path, cmd)
2125 cmd = py3compat.unicode_to_str(cmd)
2127 cmd = py3compat.unicode_to_str(cmd)
2126 ec = os.system(cmd)
2128 ec = os.system(cmd)
2127 else:
2129 else:
2128 cmd = py3compat.unicode_to_str(cmd)
2130 cmd = py3compat.unicode_to_str(cmd)
2129 ec = os.system(cmd)
2131 ec = os.system(cmd)
2130
2132
2131 # We explicitly do NOT return the subprocess status code, because
2133 # We explicitly do NOT return the subprocess status code, because
2132 # a non-None value would trigger :func:`sys.displayhook` calls.
2134 # a non-None value would trigger :func:`sys.displayhook` calls.
2133 # Instead, we store the exit_code in user_ns.
2135 # Instead, we store the exit_code in user_ns.
2134 self.user_ns['_exit_code'] = ec
2136 self.user_ns['_exit_code'] = ec
2135
2137
2136 # use piped system by default, because it is better behaved
2138 # use piped system by default, because it is better behaved
2137 system = system_piped
2139 system = system_piped
2138
2140
2139 def getoutput(self, cmd, split=True):
2141 def getoutput(self, cmd, split=True):
2140 """Get output (possibly including stderr) from a subprocess.
2142 """Get output (possibly including stderr) from a subprocess.
2141
2143
2142 Parameters
2144 Parameters
2143 ----------
2145 ----------
2144 cmd : str
2146 cmd : str
2145 Command to execute (can not end in '&', as background processes are
2147 Command to execute (can not end in '&', as background processes are
2146 not supported.
2148 not supported.
2147 split : bool, optional
2149 split : bool, optional
2148
2150
2149 If True, split the output into an IPython SList. Otherwise, an
2151 If True, split the output into an IPython SList. Otherwise, an
2150 IPython LSString is returned. These are objects similar to normal
2152 IPython LSString is returned. These are objects similar to normal
2151 lists and strings, with a few convenience attributes for easier
2153 lists and strings, with a few convenience attributes for easier
2152 manipulation of line-based output. You can use '?' on them for
2154 manipulation of line-based output. You can use '?' on them for
2153 details.
2155 details.
2154 """
2156 """
2155 if cmd.rstrip().endswith('&'):
2157 if cmd.rstrip().endswith('&'):
2156 # this is *far* from a rigorous test
2158 # this is *far* from a rigorous test
2157 raise OSError("Background processes not supported.")
2159 raise OSError("Background processes not supported.")
2158 out = getoutput(self.var_expand(cmd, depth=2))
2160 out = getoutput(self.var_expand(cmd, depth=2))
2159 if split:
2161 if split:
2160 out = SList(out.splitlines())
2162 out = SList(out.splitlines())
2161 else:
2163 else:
2162 out = LSString(out)
2164 out = LSString(out)
2163 return out
2165 return out
2164
2166
2165 #-------------------------------------------------------------------------
2167 #-------------------------------------------------------------------------
2166 # Things related to aliases
2168 # Things related to aliases
2167 #-------------------------------------------------------------------------
2169 #-------------------------------------------------------------------------
2168
2170
2169 def init_alias(self):
2171 def init_alias(self):
2170 self.alias_manager = AliasManager(shell=self, config=self.config)
2172 self.alias_manager = AliasManager(shell=self, config=self.config)
2171 self.configurables.append(self.alias_manager)
2173 self.configurables.append(self.alias_manager)
2172 self.ns_table['alias'] = self.alias_manager.alias_table,
2174 self.ns_table['alias'] = self.alias_manager.alias_table,
2173
2175
2174 #-------------------------------------------------------------------------
2176 #-------------------------------------------------------------------------
2175 # Things related to extensions and plugins
2177 # Things related to extensions and plugins
2176 #-------------------------------------------------------------------------
2178 #-------------------------------------------------------------------------
2177
2179
2178 def init_extension_manager(self):
2180 def init_extension_manager(self):
2179 self.extension_manager = ExtensionManager(shell=self, config=self.config)
2181 self.extension_manager = ExtensionManager(shell=self, config=self.config)
2180 self.configurables.append(self.extension_manager)
2182 self.configurables.append(self.extension_manager)
2181
2183
2182 def init_plugin_manager(self):
2184 def init_plugin_manager(self):
2183 self.plugin_manager = PluginManager(config=self.config)
2185 self.plugin_manager = PluginManager(config=self.config)
2184 self.configurables.append(self.plugin_manager)
2186 self.configurables.append(self.plugin_manager)
2185
2187
2186
2188
2187 #-------------------------------------------------------------------------
2189 #-------------------------------------------------------------------------
2188 # Things related to payloads
2190 # Things related to payloads
2189 #-------------------------------------------------------------------------
2191 #-------------------------------------------------------------------------
2190
2192
2191 def init_payload(self):
2193 def init_payload(self):
2192 self.payload_manager = PayloadManager(config=self.config)
2194 self.payload_manager = PayloadManager(config=self.config)
2193 self.configurables.append(self.payload_manager)
2195 self.configurables.append(self.payload_manager)
2194
2196
2195 #-------------------------------------------------------------------------
2197 #-------------------------------------------------------------------------
2196 # Things related to the prefilter
2198 # Things related to the prefilter
2197 #-------------------------------------------------------------------------
2199 #-------------------------------------------------------------------------
2198
2200
2199 def init_prefilter(self):
2201 def init_prefilter(self):
2200 self.prefilter_manager = PrefilterManager(shell=self, config=self.config)
2202 self.prefilter_manager = PrefilterManager(shell=self, config=self.config)
2201 self.configurables.append(self.prefilter_manager)
2203 self.configurables.append(self.prefilter_manager)
2202 # Ultimately this will be refactored in the new interpreter code, but
2204 # Ultimately this will be refactored in the new interpreter code, but
2203 # for now, we should expose the main prefilter method (there's legacy
2205 # for now, we should expose the main prefilter method (there's legacy
2204 # code out there that may rely on this).
2206 # code out there that may rely on this).
2205 self.prefilter = self.prefilter_manager.prefilter_lines
2207 self.prefilter = self.prefilter_manager.prefilter_lines
2206
2208
2207 def auto_rewrite_input(self, cmd):
2209 def auto_rewrite_input(self, cmd):
2208 """Print to the screen the rewritten form of the user's command.
2210 """Print to the screen the rewritten form of the user's command.
2209
2211
2210 This shows visual feedback by rewriting input lines that cause
2212 This shows visual feedback by rewriting input lines that cause
2211 automatic calling to kick in, like::
2213 automatic calling to kick in, like::
2212
2214
2213 /f x
2215 /f x
2214
2216
2215 into::
2217 into::
2216
2218
2217 ------> f(x)
2219 ------> f(x)
2218
2220
2219 after the user's input prompt. This helps the user understand that the
2221 after the user's input prompt. This helps the user understand that the
2220 input line was transformed automatically by IPython.
2222 input line was transformed automatically by IPython.
2221 """
2223 """
2222 if not self.show_rewritten_input:
2224 if not self.show_rewritten_input:
2223 return
2225 return
2224
2226
2225 rw = self.prompt_manager.render('rewrite') + cmd
2227 rw = self.prompt_manager.render('rewrite') + cmd
2226
2228
2227 try:
2229 try:
2228 # plain ascii works better w/ pyreadline, on some machines, so
2230 # plain ascii works better w/ pyreadline, on some machines, so
2229 # we use it and only print uncolored rewrite if we have unicode
2231 # we use it and only print uncolored rewrite if we have unicode
2230 rw = str(rw)
2232 rw = str(rw)
2231 print >> io.stdout, rw
2233 print >> io.stdout, rw
2232 except UnicodeEncodeError:
2234 except UnicodeEncodeError:
2233 print "------> " + cmd
2235 print "------> " + cmd
2234
2236
2235 #-------------------------------------------------------------------------
2237 #-------------------------------------------------------------------------
2236 # Things related to extracting values/expressions from kernel and user_ns
2238 # Things related to extracting values/expressions from kernel and user_ns
2237 #-------------------------------------------------------------------------
2239 #-------------------------------------------------------------------------
2238
2240
2239 def _simple_error(self):
2241 def _simple_error(self):
2240 etype, value = sys.exc_info()[:2]
2242 etype, value = sys.exc_info()[:2]
2241 return u'[ERROR] {e.__name__}: {v}'.format(e=etype, v=value)
2243 return u'[ERROR] {e.__name__}: {v}'.format(e=etype, v=value)
2242
2244
2243 def user_variables(self, names):
2245 def user_variables(self, names):
2244 """Get a list of variable names from the user's namespace.
2246 """Get a list of variable names from the user's namespace.
2245
2247
2246 Parameters
2248 Parameters
2247 ----------
2249 ----------
2248 names : list of strings
2250 names : list of strings
2249 A list of names of variables to be read from the user namespace.
2251 A list of names of variables to be read from the user namespace.
2250
2252
2251 Returns
2253 Returns
2252 -------
2254 -------
2253 A dict, keyed by the input names and with the repr() of each value.
2255 A dict, keyed by the input names and with the repr() of each value.
2254 """
2256 """
2255 out = {}
2257 out = {}
2256 user_ns = self.user_ns
2258 user_ns = self.user_ns
2257 for varname in names:
2259 for varname in names:
2258 try:
2260 try:
2259 value = repr(user_ns[varname])
2261 value = repr(user_ns[varname])
2260 except:
2262 except:
2261 value = self._simple_error()
2263 value = self._simple_error()
2262 out[varname] = value
2264 out[varname] = value
2263 return out
2265 return out
2264
2266
2265 def user_expressions(self, expressions):
2267 def user_expressions(self, expressions):
2266 """Evaluate a dict of expressions in the user's namespace.
2268 """Evaluate a dict of expressions in the user's namespace.
2267
2269
2268 Parameters
2270 Parameters
2269 ----------
2271 ----------
2270 expressions : dict
2272 expressions : dict
2271 A dict with string keys and string values. The expression values
2273 A dict with string keys and string values. The expression values
2272 should be valid Python expressions, each of which will be evaluated
2274 should be valid Python expressions, each of which will be evaluated
2273 in the user namespace.
2275 in the user namespace.
2274
2276
2275 Returns
2277 Returns
2276 -------
2278 -------
2277 A dict, keyed like the input expressions dict, with the repr() of each
2279 A dict, keyed like the input expressions dict, with the repr() of each
2278 value.
2280 value.
2279 """
2281 """
2280 out = {}
2282 out = {}
2281 user_ns = self.user_ns
2283 user_ns = self.user_ns
2282 global_ns = self.user_global_ns
2284 global_ns = self.user_global_ns
2283 for key, expr in expressions.iteritems():
2285 for key, expr in expressions.iteritems():
2284 try:
2286 try:
2285 value = repr(eval(expr, global_ns, user_ns))
2287 value = repr(eval(expr, global_ns, user_ns))
2286 except:
2288 except:
2287 value = self._simple_error()
2289 value = self._simple_error()
2288 out[key] = value
2290 out[key] = value
2289 return out
2291 return out
2290
2292
2291 #-------------------------------------------------------------------------
2293 #-------------------------------------------------------------------------
2292 # Things related to the running of code
2294 # Things related to the running of code
2293 #-------------------------------------------------------------------------
2295 #-------------------------------------------------------------------------
2294
2296
2295 def ex(self, cmd):
2297 def ex(self, cmd):
2296 """Execute a normal python statement in user namespace."""
2298 """Execute a normal python statement in user namespace."""
2297 with self.builtin_trap:
2299 with self.builtin_trap:
2298 exec cmd in self.user_global_ns, self.user_ns
2300 exec cmd in self.user_global_ns, self.user_ns
2299
2301
2300 def ev(self, expr):
2302 def ev(self, expr):
2301 """Evaluate python expression expr in user namespace.
2303 """Evaluate python expression expr in user namespace.
2302
2304
2303 Returns the result of evaluation
2305 Returns the result of evaluation
2304 """
2306 """
2305 with self.builtin_trap:
2307 with self.builtin_trap:
2306 return eval(expr, self.user_global_ns, self.user_ns)
2308 return eval(expr, self.user_global_ns, self.user_ns)
2307
2309
    def safe_execfile(self, fname, *where, **kw):
        """A safe version of the builtin execfile().

        This version will never throw an exception, but instead print
        helpful error messages to the screen.  This only works on pure
        Python files with the .py extension.

        Parameters
        ----------
        fname : string
          The name of the file to be executed.
        where : tuple
          One or two namespaces, passed to execfile() as (globals,locals).
          If only one is given, it is passed as both.
        exit_ignore : bool (False)
          If True, then silence SystemExit for non-zero status (it is always
          silenced for zero status, as it is so common).
        raise_exceptions : bool (False)
          If True raise exceptions everywhere. Meant for testing.

        """
        kw.setdefault('exit_ignore', False)
        kw.setdefault('raise_exceptions', False)

        fname = os.path.abspath(os.path.expanduser(fname))

        # Make sure we can open the file
        try:
            with open(fname) as thefile:
                pass
        except:
            warn('Could not open file <%s> for safe execution.' % fname)
            return

        # Find things also in current directory.  This is needed to mimic the
        # behavior of running a script from the system command line, where
        # Python inserts the script's directory into sys.path
        dname = os.path.dirname(fname)

        with prepended_to_syspath(dname):
            try:
                py3compat.execfile(fname,*where)
            except SystemExit, status:
                # If the call was made with 0 or None exit status (sys.exit(0)
                # or sys.exit() ), don't bother showing a traceback, as both of
                # these are considered normal by the OS:
                # > python -c'import sys;sys.exit(0)'; echo $?
                # 0
                # > python -c'import sys;sys.exit()'; echo $?
                # 0
                # For other exit status, we show the exception unless
                # explicitly silenced, but only in short form.
                if kw['raise_exceptions']:
                    raise
                if status.code not in (0, None) and not kw['exit_ignore']:
                    # exception_only keeps the report short: no full traceback
                    # for what is essentially a deliberate exit.
                    self.showtraceback(exception_only=True)
            except:
                # NOTE(review): deliberately a bare except so *nothing* escapes
                # to the caller ("safe" contract above), unless the caller
                # opted into raise_exceptions for testing.
                if kw['raise_exceptions']:
                    raise
                self.showtraceback()
2368
2370
2369 def safe_execfile_ipy(self, fname):
2371 def safe_execfile_ipy(self, fname):
2370 """Like safe_execfile, but for .ipy files with IPython syntax.
2372 """Like safe_execfile, but for .ipy files with IPython syntax.
2371
2373
2372 Parameters
2374 Parameters
2373 ----------
2375 ----------
2374 fname : str
2376 fname : str
2375 The name of the file to execute. The filename must have a
2377 The name of the file to execute. The filename must have a
2376 .ipy extension.
2378 .ipy extension.
2377 """
2379 """
2378 fname = os.path.abspath(os.path.expanduser(fname))
2380 fname = os.path.abspath(os.path.expanduser(fname))
2379
2381
2380 # Make sure we can open the file
2382 # Make sure we can open the file
2381 try:
2383 try:
2382 with open(fname) as thefile:
2384 with open(fname) as thefile:
2383 pass
2385 pass
2384 except:
2386 except:
2385 warn('Could not open file <%s> for safe execution.' % fname)
2387 warn('Could not open file <%s> for safe execution.' % fname)
2386 return
2388 return
2387
2389
2388 # Find things also in current directory. This is needed to mimic the
2390 # Find things also in current directory. This is needed to mimic the
2389 # behavior of running a script from the system command line, where
2391 # behavior of running a script from the system command line, where
2390 # Python inserts the script's directory into sys.path
2392 # Python inserts the script's directory into sys.path
2391 dname = os.path.dirname(fname)
2393 dname = os.path.dirname(fname)
2392
2394
2393 with prepended_to_syspath(dname):
2395 with prepended_to_syspath(dname):
2394 try:
2396 try:
2395 with open(fname) as thefile:
2397 with open(fname) as thefile:
2396 # self.run_cell currently captures all exceptions
2398 # self.run_cell currently captures all exceptions
2397 # raised in user code. It would be nice if there were
2399 # raised in user code. It would be nice if there were
2398 # versions of runlines, execfile that did raise, so
2400 # versions of runlines, execfile that did raise, so
2399 # we could catch the errors.
2401 # we could catch the errors.
2400 self.run_cell(thefile.read(), store_history=False)
2402 self.run_cell(thefile.read(), store_history=False)
2401 except:
2403 except:
2402 self.showtraceback()
2404 self.showtraceback()
2403 warn('Unknown failure executing file: <%s>' % fname)
2405 warn('Unknown failure executing file: <%s>' % fname)
2404
2406
2405 def safe_run_module(self, mod_name, where):
2407 def safe_run_module(self, mod_name, where):
2406 """A safe version of runpy.run_module().
2408 """A safe version of runpy.run_module().
2407
2409
2408 This version will never throw an exception, but instead print
2410 This version will never throw an exception, but instead print
2409 helpful error messages to the screen.
2411 helpful error messages to the screen.
2410
2412
2411 Parameters
2413 Parameters
2412 ----------
2414 ----------
2413 mod_name : string
2415 mod_name : string
2414 The name of the module to be executed.
2416 The name of the module to be executed.
2415 where : dict
2417 where : dict
2416 The globals namespace.
2418 The globals namespace.
2417 """
2419 """
2418 try:
2420 try:
2419 where.update(
2421 where.update(
2420 runpy.run_module(str(mod_name), run_name="__main__",
2422 runpy.run_module(str(mod_name), run_name="__main__",
2421 alter_sys=True)
2423 alter_sys=True)
2422 )
2424 )
2423 except:
2425 except:
2424 self.showtraceback()
2426 self.showtraceback()
2425 warn('Unknown failure executing module: <%s>' % mod_name)
2427 warn('Unknown failure executing module: <%s>' % mod_name)
2426
2428
    def run_cell(self, raw_cell, store_history=False, silent=False):
        """Run a complete IPython cell.

        Parameters
        ----------
        raw_cell : str
          The code (including IPython code such as %magic functions) to run.
        store_history : bool
          If True, the raw and translated cell will be stored in IPython's
          history. For user code calling back into IPython's machinery, this
          should be set to False.
        silent : bool
          If True, avoid side-effects, such as implicit displayhooks, history,
          and logging. silent=True forces store_history=False.
        """
        # Empty/whitespace-only cells are a no-op.
        if (not raw_cell) or raw_cell.isspace():
            return

        # silent implies no history: a silent cell must leave no trace.
        if silent:
            store_history = False

        # Feed the raw cell through the input splitter to get the translated
        # (IPython-syntax-expanded) source.
        for line in raw_cell.splitlines():
            self.input_splitter.push(line)
        cell = self.input_splitter.source_reset()

        with self.builtin_trap:
            prefilter_failed = False
            # Prefiltering (aliases, autocall, ...) only applies to
            # single-line cells.
            if len(cell.splitlines()) == 1:
                try:
                    # use prefilter_lines to handle trailing newlines
                    # restore trailing newline for ast.parse
                    cell = self.prefilter_manager.prefilter_lines(cell) + '\n'
                except AliasError as e:
                    error(e)
                    prefilter_failed = True
                except Exception:
                    # don't allow prefilter errors to crash IPython
                    self.showtraceback()
                    prefilter_failed = True

            # Store raw and processed history
            if store_history:
                self.history_manager.store_inputs(self.execution_count,
                                                  cell, raw_cell)
            if not silent:
                self.logger.log(cell, raw_cell)

            if not prefilter_failed:
                # don't run if prefilter failed
                cell_name = self.compile.cache(cell, self.execution_count)

                with self.display_trap:
                    try:
                        code_ast = self.compile.ast_parse(cell, filename=cell_name)
                    except IndentationError:
                        self.showindentationerror()
                        # A failed parse still consumes a prompt number when
                        # history is being recorded.
                        if store_history:
                            self.execution_count += 1
                        return None
                    except (OverflowError, SyntaxError, ValueError, TypeError,
                            MemoryError):
                        self.showsyntaxerror()
                        if store_history:
                            self.execution_count += 1
                        return None

                    # silent cells suppress the implicit displayhook on the
                    # last expression.
                    interactivity = "none" if silent else "last_expr"
                    self.run_ast_nodes(code_ast.body, cell_name,
                                       interactivity=interactivity)

                    # Execute any registered post-execution functions.
                    # unless we are silent
                    post_exec = [] if silent else self._post_execute.iteritems()

                    for func, status in post_exec:
                        # Skip functions that previously failed, if so configured.
                        if self.disable_failing_post_execute and not status:
                            continue
                        try:
                            func()
                        except KeyboardInterrupt:
                            print >> io.stderr, "\nKeyboardInterrupt"
                        except Exception:
                            # register as failing:
                            self._post_execute[func] = False
                            self.showtraceback()
                            print >> io.stderr, '\n'.join([
                                "post-execution function %r produced an error." % func,
                                "If this problem persists, you can disable failing post-exec functions with:",
                                "",
                                "  get_ipython().disable_failing_post_execute = True"
                            ])

        if store_history:
            # Write output to the database. Does nothing unless
            # history output logging is enabled.
            self.history_manager.store_output(self.execution_count)
            # Each cell is a *single* input, regardless of how many lines it has
            self.execution_count += 1
2515
2527
2516 def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'):
2528 def run_ast_nodes(self, nodelist, cell_name, interactivity='last_expr'):
2517 """Run a sequence of AST nodes. The execution mode depends on the
2529 """Run a sequence of AST nodes. The execution mode depends on the
2518 interactivity parameter.
2530 interactivity parameter.
2519
2531
2520 Parameters
2532 Parameters
2521 ----------
2533 ----------
2522 nodelist : list
2534 nodelist : list
2523 A sequence of AST nodes to run.
2535 A sequence of AST nodes to run.
2524 cell_name : str
2536 cell_name : str
2525 Will be passed to the compiler as the filename of the cell. Typically
2537 Will be passed to the compiler as the filename of the cell. Typically
2526 the value returned by ip.compile.cache(cell).
2538 the value returned by ip.compile.cache(cell).
2527 interactivity : str
2539 interactivity : str
2528 'all', 'last', 'last_expr' or 'none', specifying which nodes should be
2540 'all', 'last', 'last_expr' or 'none', specifying which nodes should be
2529 run interactively (displaying output from expressions). 'last_expr'
2541 run interactively (displaying output from expressions). 'last_expr'
2530 will run the last node interactively only if it is an expression (i.e.
2542 will run the last node interactively only if it is an expression (i.e.
2531 expressions in loops or other blocks are not displayed. Other values
2543 expressions in loops or other blocks are not displayed. Other values
2532 for this parameter will raise a ValueError.
2544 for this parameter will raise a ValueError.
2533 """
2545 """
2534 if not nodelist:
2546 if not nodelist:
2535 return
2547 return
2536
2548
2537 if interactivity == 'last_expr':
2549 if interactivity == 'last_expr':
2538 if isinstance(nodelist[-1], ast.Expr):
2550 if isinstance(nodelist[-1], ast.Expr):
2539 interactivity = "last"
2551 interactivity = "last"
2540 else:
2552 else:
2541 interactivity = "none"
2553 interactivity = "none"
2542
2554
2543 if interactivity == 'none':
2555 if interactivity == 'none':
2544 to_run_exec, to_run_interactive = nodelist, []
2556 to_run_exec, to_run_interactive = nodelist, []
2545 elif interactivity == 'last':
2557 elif interactivity == 'last':
2546 to_run_exec, to_run_interactive = nodelist[:-1], nodelist[-1:]
2558 to_run_exec, to_run_interactive = nodelist[:-1], nodelist[-1:]
2547 elif interactivity == 'all':
2559 elif interactivity == 'all':
2548 to_run_exec, to_run_interactive = [], nodelist
2560 to_run_exec, to_run_interactive = [], nodelist
2549 else:
2561 else:
2550 raise ValueError("Interactivity was %r" % interactivity)
2562 raise ValueError("Interactivity was %r" % interactivity)
2551
2563
2552 exec_count = self.execution_count
2564 exec_count = self.execution_count
2553
2565
2554 try:
2566 try:
2555 for i, node in enumerate(to_run_exec):
2567 for i, node in enumerate(to_run_exec):
2556 mod = ast.Module([node])
2568 mod = ast.Module([node])
2557 code = self.compile(mod, cell_name, "exec")
2569 code = self.compile(mod, cell_name, "exec")
2558 if self.run_code(code):
2570 if self.run_code(code):
2559 return True
2571 return True
2560
2572
2561 for i, node in enumerate(to_run_interactive):
2573 for i, node in enumerate(to_run_interactive):
2562 mod = ast.Interactive([node])
2574 mod = ast.Interactive([node])
2563 code = self.compile(mod, cell_name, "single")
2575 code = self.compile(mod, cell_name, "single")
2564 if self.run_code(code):
2576 if self.run_code(code):
2565 return True
2577 return True
2566
2578
2567 # Flush softspace
2579 # Flush softspace
2568 if softspace(sys.stdout, 0):
2580 if softspace(sys.stdout, 0):
2569 print
2581 print
2570
2582
2571 except:
2583 except:
2572 # It's possible to have exceptions raised here, typically by
2584 # It's possible to have exceptions raised here, typically by
2573 # compilation of odd code (such as a naked 'return' outside a
2585 # compilation of odd code (such as a naked 'return' outside a
2574 # function) that did parse but isn't valid. Typically the exception
2586 # function) that did parse but isn't valid. Typically the exception
2575 # is a SyntaxError, but it's safest just to catch anything and show
2587 # is a SyntaxError, but it's safest just to catch anything and show
2576 # the user a traceback.
2588 # the user a traceback.
2577
2589
2578 # We do only one try/except outside the loop to minimize the impact
2590 # We do only one try/except outside the loop to minimize the impact
2579 # on runtime, and also because if any node in the node list is
2591 # on runtime, and also because if any node in the node list is
2580 # broken, we should stop execution completely.
2592 # broken, we should stop execution completely.
2581 self.showtraceback()
2593 self.showtraceback()
2582
2594
2583 return False
2595 return False
2584
2596
2585 def run_code(self, code_obj):
2597 def run_code(self, code_obj):
2586 """Execute a code object.
2598 """Execute a code object.
2587
2599
2588 When an exception occurs, self.showtraceback() is called to display a
2600 When an exception occurs, self.showtraceback() is called to display a
2589 traceback.
2601 traceback.
2590
2602
2591 Parameters
2603 Parameters
2592 ----------
2604 ----------
2593 code_obj : code object
2605 code_obj : code object
2594 A compiled code object, to be executed
2606 A compiled code object, to be executed
2595
2607
2596 Returns
2608 Returns
2597 -------
2609 -------
2598 False : successful execution.
2610 False : successful execution.
2599 True : an error occurred.
2611 True : an error occurred.
2600 """
2612 """
2601
2613
2602 # Set our own excepthook in case the user code tries to call it
2614 # Set our own excepthook in case the user code tries to call it
2603 # directly, so that the IPython crash handler doesn't get triggered
2615 # directly, so that the IPython crash handler doesn't get triggered
2604 old_excepthook,sys.excepthook = sys.excepthook, self.excepthook
2616 old_excepthook,sys.excepthook = sys.excepthook, self.excepthook
2605
2617
2606 # we save the original sys.excepthook in the instance, in case config
2618 # we save the original sys.excepthook in the instance, in case config
2607 # code (such as magics) needs access to it.
2619 # code (such as magics) needs access to it.
2608 self.sys_excepthook = old_excepthook
2620 self.sys_excepthook = old_excepthook
2609 outflag = 1 # happens in more places, so it's easier as default
2621 outflag = 1 # happens in more places, so it's easier as default
2610 try:
2622 try:
2611 try:
2623 try:
2612 self.hooks.pre_run_code_hook()
2624 self.hooks.pre_run_code_hook()
2613 #rprint('Running code', repr(code_obj)) # dbg
2625 #rprint('Running code', repr(code_obj)) # dbg
2614 exec code_obj in self.user_global_ns, self.user_ns
2626 exec code_obj in self.user_global_ns, self.user_ns
2615 finally:
2627 finally:
2616 # Reset our crash handler in place
2628 # Reset our crash handler in place
2617 sys.excepthook = old_excepthook
2629 sys.excepthook = old_excepthook
2618 except SystemExit:
2630 except SystemExit:
2619 self.showtraceback(exception_only=True)
2631 self.showtraceback(exception_only=True)
2620 warn("To exit: use 'exit', 'quit', or Ctrl-D.", level=1)
2632 warn("To exit: use 'exit', 'quit', or Ctrl-D.", level=1)
2621 except self.custom_exceptions:
2633 except self.custom_exceptions:
2622 etype,value,tb = sys.exc_info()
2634 etype,value,tb = sys.exc_info()
2623 self.CustomTB(etype,value,tb)
2635 self.CustomTB(etype,value,tb)
2624 except:
2636 except:
2625 self.showtraceback()
2637 self.showtraceback()
2626 else:
2638 else:
2627 outflag = 0
2639 outflag = 0
2628 return outflag
2640 return outflag
2629
2641
2630 # For backwards compatibility
2642 # For backwards compatibility
2631 runcode = run_code
2643 runcode = run_code
2632
2644
2633 #-------------------------------------------------------------------------
2645 #-------------------------------------------------------------------------
2634 # Things related to GUI support and pylab
2646 # Things related to GUI support and pylab
2635 #-------------------------------------------------------------------------
2647 #-------------------------------------------------------------------------
2636
2648
2637 def enable_gui(self, gui=None):
2649 def enable_gui(self, gui=None):
2638 raise NotImplementedError('Implement enable_gui in a subclass')
2650 raise NotImplementedError('Implement enable_gui in a subclass')
2639
2651
2640 def enable_pylab(self, gui=None, import_all=True):
2652 def enable_pylab(self, gui=None, import_all=True):
2641 """Activate pylab support at runtime.
2653 """Activate pylab support at runtime.
2642
2654
2643 This turns on support for matplotlib, preloads into the interactive
2655 This turns on support for matplotlib, preloads into the interactive
2644 namespace all of numpy and pylab, and configures IPython to correctly
2656 namespace all of numpy and pylab, and configures IPython to correctly
2645 interact with the GUI event loop. The GUI backend to be used can be
2657 interact with the GUI event loop. The GUI backend to be used can be
2646 optionally selected with the optional :param:`gui` argument.
2658 optionally selected with the optional :param:`gui` argument.
2647
2659
2648 Parameters
2660 Parameters
2649 ----------
2661 ----------
2650 gui : optional, string
2662 gui : optional, string
2651
2663
2652 If given, dictates the choice of matplotlib GUI backend to use
2664 If given, dictates the choice of matplotlib GUI backend to use
2653 (should be one of IPython's supported backends, 'qt', 'osx', 'tk',
2665 (should be one of IPython's supported backends, 'qt', 'osx', 'tk',
2654 'gtk', 'wx' or 'inline'), otherwise we use the default chosen by
2666 'gtk', 'wx' or 'inline'), otherwise we use the default chosen by
2655 matplotlib (as dictated by the matplotlib build-time options plus the
2667 matplotlib (as dictated by the matplotlib build-time options plus the
2656 user's matplotlibrc configuration file). Note that not all backends
2668 user's matplotlibrc configuration file). Note that not all backends
2657 make sense in all contexts, for example a terminal ipython can't
2669 make sense in all contexts, for example a terminal ipython can't
2658 display figures inline.
2670 display figures inline.
2659 """
2671 """
2660
2672
2661 # We want to prevent the loading of pylab to pollute the user's
2673 # We want to prevent the loading of pylab to pollute the user's
2662 # namespace as shown by the %who* magics, so we execute the activation
2674 # namespace as shown by the %who* magics, so we execute the activation
2663 # code in an empty namespace, and we update *both* user_ns and
2675 # code in an empty namespace, and we update *both* user_ns and
2664 # user_ns_hidden with this information.
2676 # user_ns_hidden with this information.
2665 ns = {}
2677 ns = {}
2666 try:
2678 try:
2667 gui = pylab_activate(ns, gui, import_all, self)
2679 gui = pylab_activate(ns, gui, import_all, self)
2668 except KeyError:
2680 except KeyError:
2669 error("Backend %r not supported" % gui)
2681 error("Backend %r not supported" % gui)
2670 return
2682 return
2671 self.user_ns.update(ns)
2683 self.user_ns.update(ns)
2672 self.user_ns_hidden.update(ns)
2684 self.user_ns_hidden.update(ns)
2673 # Now we must activate the gui pylab wants to use, and fix %run to take
2685 # Now we must activate the gui pylab wants to use, and fix %run to take
2674 # plot updates into account
2686 # plot updates into account
2675 self.enable_gui(gui)
2687 self.enable_gui(gui)
2676 self.magic_run = self._pylab_magic_run
2688 self.magic_run = self._pylab_magic_run
2677
2689
2678 #-------------------------------------------------------------------------
2690 #-------------------------------------------------------------------------
2679 # Utilities
2691 # Utilities
2680 #-------------------------------------------------------------------------
2692 #-------------------------------------------------------------------------
2681
2693
2682 def var_expand(self, cmd, depth=0, formatter=DollarFormatter()):
2694 def var_expand(self, cmd, depth=0, formatter=DollarFormatter()):
2683 """Expand python variables in a string.
2695 """Expand python variables in a string.
2684
2696
2685 The depth argument indicates how many frames above the caller should
2697 The depth argument indicates how many frames above the caller should
2686 be walked to look for the local namespace where to expand variables.
2698 be walked to look for the local namespace where to expand variables.
2687
2699
2688 The global namespace for expansion is always the user's interactive
2700 The global namespace for expansion is always the user's interactive
2689 namespace.
2701 namespace.
2690 """
2702 """
2691 ns = self.user_ns.copy()
2703 ns = self.user_ns.copy()
2692 ns.update(sys._getframe(depth+1).f_locals)
2704 ns.update(sys._getframe(depth+1).f_locals)
2693 ns.pop('self', None)
2705 ns.pop('self', None)
2694 try:
2706 try:
2695 cmd = formatter.format(cmd, **ns)
2707 cmd = formatter.format(cmd, **ns)
2696 except Exception:
2708 except Exception:
2697 # if formatter couldn't format, just let it go untransformed
2709 # if formatter couldn't format, just let it go untransformed
2698 pass
2710 pass
2699 return cmd
2711 return cmd
2700
2712
2701 def mktempfile(self, data=None, prefix='ipython_edit_'):
2713 def mktempfile(self, data=None, prefix='ipython_edit_'):
2702 """Make a new tempfile and return its filename.
2714 """Make a new tempfile and return its filename.
2703
2715
2704 This makes a call to tempfile.mktemp, but it registers the created
2716 This makes a call to tempfile.mktemp, but it registers the created
2705 filename internally so ipython cleans it up at exit time.
2717 filename internally so ipython cleans it up at exit time.
2706
2718
2707 Optional inputs:
2719 Optional inputs:
2708
2720
2709 - data(None): if data is given, it gets written out to the temp file
2721 - data(None): if data is given, it gets written out to the temp file
2710 immediately, and the file is closed again."""
2722 immediately, and the file is closed again."""
2711
2723
2712 filename = tempfile.mktemp('.py', prefix)
2724 filename = tempfile.mktemp('.py', prefix)
2713 self.tempfiles.append(filename)
2725 self.tempfiles.append(filename)
2714
2726
2715 if data:
2727 if data:
2716 tmp_file = open(filename,'w')
2728 tmp_file = open(filename,'w')
2717 tmp_file.write(data)
2729 tmp_file.write(data)
2718 tmp_file.close()
2730 tmp_file.close()
2719 return filename
2731 return filename
2720
2732
2721 # TODO: This should be removed when Term is refactored.
2733 # TODO: This should be removed when Term is refactored.
2722 def write(self,data):
2734 def write(self,data):
2723 """Write a string to the default output"""
2735 """Write a string to the default output"""
2724 io.stdout.write(data)
2736 io.stdout.write(data)
2725
2737
2726 # TODO: This should be removed when Term is refactored.
2738 # TODO: This should be removed when Term is refactored.
2727 def write_err(self,data):
2739 def write_err(self,data):
2728 """Write a string to the default error output"""
2740 """Write a string to the default error output"""
2729 io.stderr.write(data)
2741 io.stderr.write(data)
2730
2742
2731 def ask_yes_no(self, prompt, default=None):
2743 def ask_yes_no(self, prompt, default=None):
2732 if self.quiet:
2744 if self.quiet:
2733 return True
2745 return True
2734 return ask_yes_no(prompt,default)
2746 return ask_yes_no(prompt,default)
2735
2747
2736 def show_usage(self):
2748 def show_usage(self):
2737 """Show a usage message"""
2749 """Show a usage message"""
2738 page.page(IPython.core.usage.interactive_usage)
2750 page.page(IPython.core.usage.interactive_usage)
2739
2751
2740 def find_user_code(self, target, raw=True, py_only=False):
2752 def find_user_code(self, target, raw=True, py_only=False):
2741 """Get a code string from history, file, url, or a string or macro.
2753 """Get a code string from history, file, url, or a string or macro.
2742
2754
2743 This is mainly used by magic functions.
2755 This is mainly used by magic functions.
2744
2756
2745 Parameters
2757 Parameters
2746 ----------
2758 ----------
2747
2759
2748 target : str
2760 target : str
2749
2761
2750 A string specifying code to retrieve. This will be tried respectively
2762 A string specifying code to retrieve. This will be tried respectively
2751 as: ranges of input history (see %history for syntax), url,
2763 as: ranges of input history (see %history for syntax), url,
2752 correspnding .py file, filename, or an expression evaluating to a
2764 correspnding .py file, filename, or an expression evaluating to a
2753 string or Macro in the user namespace.
2765 string or Macro in the user namespace.
2754
2766
2755 raw : bool
2767 raw : bool
2756 If true (default), retrieve raw history. Has no effect on the other
2768 If true (default), retrieve raw history. Has no effect on the other
2757 retrieval mechanisms.
2769 retrieval mechanisms.
2758
2770
2759 py_only : bool (default False)
2771 py_only : bool (default False)
2760 Only try to fetch python code, do not try alternative methods to decode file
2772 Only try to fetch python code, do not try alternative methods to decode file
2761 if unicode fails.
2773 if unicode fails.
2762
2774
2763 Returns
2775 Returns
2764 -------
2776 -------
2765 A string of code.
2777 A string of code.
2766
2778
2767 ValueError is raised if nothing is found, and TypeError if it evaluates
2779 ValueError is raised if nothing is found, and TypeError if it evaluates
2768 to an object of another type. In each case, .args[0] is a printable
2780 to an object of another type. In each case, .args[0] is a printable
2769 message.
2781 message.
2770 """
2782 """
2771 code = self.extract_input_lines(target, raw=raw) # Grab history
2783 code = self.extract_input_lines(target, raw=raw) # Grab history
2772 if code:
2784 if code:
2773 return code
2785 return code
2774 utarget = unquote_filename(target)
2786 utarget = unquote_filename(target)
2775 try:
2787 try:
2776 if utarget.startswith(('http://', 'https://')):
2788 if utarget.startswith(('http://', 'https://')):
2777 return openpy.read_py_url(utarget, skip_encoding_cookie=True)
2789 return openpy.read_py_url(utarget, skip_encoding_cookie=True)
2778 except UnicodeDecodeError:
2790 except UnicodeDecodeError:
2779 if not py_only :
2791 if not py_only :
2780 response = urllib.urlopen(target)
2792 response = urllib.urlopen(target)
2781 return response.read().decode('latin1')
2793 return response.read().decode('latin1')
2782 raise ValueError(("'%s' seem to be unreadable.") % utarget)
2794 raise ValueError(("'%s' seem to be unreadable.") % utarget)
2783
2795
2784 potential_target = [target]
2796 potential_target = [target]
2785 try :
2797 try :
2786 potential_target.insert(0,get_py_filename(target))
2798 potential_target.insert(0,get_py_filename(target))
2787 except IOError:
2799 except IOError:
2788 pass
2800 pass
2789
2801
2790 for tgt in potential_target :
2802 for tgt in potential_target :
2791 if os.path.isfile(tgt): # Read file
2803 if os.path.isfile(tgt): # Read file
2792 try :
2804 try :
2793 return openpy.read_py_file(tgt, skip_encoding_cookie=True)
2805 return openpy.read_py_file(tgt, skip_encoding_cookie=True)
2794 except UnicodeDecodeError :
2806 except UnicodeDecodeError :
2795 if not py_only :
2807 if not py_only :
2796 with io_open(tgt,'r', encoding='latin1') as f :
2808 with io_open(tgt,'r', encoding='latin1') as f :
2797 return f.read()
2809 return f.read()
2798 raise ValueError(("'%s' seem to be unreadable.") % target)
2810 raise ValueError(("'%s' seem to be unreadable.") % target)
2799
2811
2800 try: # User namespace
2812 try: # User namespace
2801 codeobj = eval(target, self.user_ns)
2813 codeobj = eval(target, self.user_ns)
2802 except Exception:
2814 except Exception:
2803 raise ValueError(("'%s' was not found in history, as a file, url, "
2815 raise ValueError(("'%s' was not found in history, as a file, url, "
2804 "nor in the user namespace.") % target)
2816 "nor in the user namespace.") % target)
2805 if isinstance(codeobj, basestring):
2817 if isinstance(codeobj, basestring):
2806 return codeobj
2818 return codeobj
2807 elif isinstance(codeobj, Macro):
2819 elif isinstance(codeobj, Macro):
2808 return codeobj.value
2820 return codeobj.value
2809
2821
2810 raise TypeError("%s is neither a string nor a macro." % target,
2822 raise TypeError("%s is neither a string nor a macro." % target,
2811 codeobj)
2823 codeobj)
2812
2824
2813 #-------------------------------------------------------------------------
2825 #-------------------------------------------------------------------------
2814 # Things related to IPython exiting
2826 # Things related to IPython exiting
2815 #-------------------------------------------------------------------------
2827 #-------------------------------------------------------------------------
2816 def atexit_operations(self):
2828 def atexit_operations(self):
2817 """This will be executed at the time of exit.
2829 """This will be executed at the time of exit.
2818
2830
2819 Cleanup operations and saving of persistent data that is done
2831 Cleanup operations and saving of persistent data that is done
2820 unconditionally by IPython should be performed here.
2832 unconditionally by IPython should be performed here.
2821
2833
2822 For things that may depend on startup flags or platform specifics (such
2834 For things that may depend on startup flags or platform specifics (such
2823 as having readline or not), register a separate atexit function in the
2835 as having readline or not), register a separate atexit function in the
2824 code that has the appropriate information, rather than trying to
2836 code that has the appropriate information, rather than trying to
2825 clutter
2837 clutter
2826 """
2838 """
2827 # Close the history session (this stores the end time and line count)
2839 # Close the history session (this stores the end time and line count)
2828 # this must be *before* the tempfile cleanup, in case of temporary
2840 # this must be *before* the tempfile cleanup, in case of temporary
2829 # history db
2841 # history db
2830 self.history_manager.end_session()
2842 self.history_manager.end_session()
2831
2843
2832 # Cleanup all tempfiles left around
2844 # Cleanup all tempfiles left around
2833 for tfile in self.tempfiles:
2845 for tfile in self.tempfiles:
2834 try:
2846 try:
2835 os.unlink(tfile)
2847 os.unlink(tfile)
2836 except OSError:
2848 except OSError:
2837 pass
2849 pass
2838
2850
2839 # Clear all user namespaces to release all references cleanly.
2851 # Clear all user namespaces to release all references cleanly.
2840 self.reset(new_session=False)
2852 self.reset(new_session=False)
2841
2853
2842 # Run user hooks
2854 # Run user hooks
2843 self.hooks.shutdown_hook()
2855 self.hooks.shutdown_hook()
2844
2856
2845 def cleanup(self):
2857 def cleanup(self):
2846 self.restore_sys_module_state()
2858 self.restore_sys_module_state()
2847
2859
2848
2860
2849 class InteractiveShellABC(object):
2861 class InteractiveShellABC(object):
2850 """An abstract base class for InteractiveShell."""
2862 """An abstract base class for InteractiveShell."""
2851 __metaclass__ = abc.ABCMeta
2863 __metaclass__ = abc.ABCMeta
2852
2864
2853 InteractiveShellABC.register(InteractiveShell)
2865 InteractiveShellABC.register(InteractiveShell)
@@ -1,304 +1,359 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """Tests for the key interactiveshell module.
2 """Tests for the key interactiveshell module.
3
3
4 Historically the main classes in interactiveshell have been under-tested. This
4 Historically the main classes in interactiveshell have been under-tested. This
5 module should grow as many single-method tests as possible to trap many of the
5 module should grow as many single-method tests as possible to trap many of the
6 recurring bugs we seem to encounter with high-level interaction.
6 recurring bugs we seem to encounter with high-level interaction.
7
7
8 Authors
8 Authors
9 -------
9 -------
10 * Fernando Perez
10 * Fernando Perez
11 """
11 """
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2011 The IPython Development Team
13 # Copyright (C) 2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22 # stdlib
22 # stdlib
23 import os
23 import os
24 import shutil
24 import shutil
25 import tempfile
25 import tempfile
26 import unittest
26 import unittest
27 from os.path import join
27 from os.path import join
28 import sys
28 import sys
29 from StringIO import StringIO
29 from StringIO import StringIO
30
30
31 from IPython.testing.decorators import skipif
31 from IPython.testing.decorators import skipif
32 from IPython.utils import io
32 from IPython.utils import io
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # Tests
35 # Tests
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 class InteractiveShellTestCase(unittest.TestCase):
38 class InteractiveShellTestCase(unittest.TestCase):
39 def test_naked_string_cells(self):
39 def test_naked_string_cells(self):
40 """Test that cells with only naked strings are fully executed"""
40 """Test that cells with only naked strings are fully executed"""
41 ip = get_ipython()
41 ip = get_ipython()
42 # First, single-line inputs
42 # First, single-line inputs
43 ip.run_cell('"a"\n')
43 ip.run_cell('"a"\n')
44 self.assertEquals(ip.user_ns['_'], 'a')
44 self.assertEquals(ip.user_ns['_'], 'a')
45 # And also multi-line cells
45 # And also multi-line cells
46 ip.run_cell('"""a\nb"""\n')
46 ip.run_cell('"""a\nb"""\n')
47 self.assertEquals(ip.user_ns['_'], 'a\nb')
47 self.assertEquals(ip.user_ns['_'], 'a\nb')
48
48
49 def test_run_empty_cell(self):
49 def test_run_empty_cell(self):
50 """Just make sure we don't get a horrible error with a blank
50 """Just make sure we don't get a horrible error with a blank
51 cell of input. Yes, I did overlook that."""
51 cell of input. Yes, I did overlook that."""
52 ip = get_ipython()
52 ip = get_ipython()
53 old_xc = ip.execution_count
53 old_xc = ip.execution_count
54 ip.run_cell('')
54 ip.run_cell('')
55 self.assertEquals(ip.execution_count, old_xc)
55 self.assertEquals(ip.execution_count, old_xc)
56
56
57 def test_run_cell_multiline(self):
57 def test_run_cell_multiline(self):
58 """Multi-block, multi-line cells must execute correctly.
58 """Multi-block, multi-line cells must execute correctly.
59 """
59 """
60 ip = get_ipython()
60 ip = get_ipython()
61 src = '\n'.join(["x=1",
61 src = '\n'.join(["x=1",
62 "y=2",
62 "y=2",
63 "if 1:",
63 "if 1:",
64 " x += 1",
64 " x += 1",
65 " y += 1",])
65 " y += 1",])
66 ip.run_cell(src)
66 ip.run_cell(src)
67 self.assertEquals(ip.user_ns['x'], 2)
67 self.assertEquals(ip.user_ns['x'], 2)
68 self.assertEquals(ip.user_ns['y'], 3)
68 self.assertEquals(ip.user_ns['y'], 3)
69
69
70 def test_multiline_string_cells(self):
70 def test_multiline_string_cells(self):
71 "Code sprinkled with multiline strings should execute (GH-306)"
71 "Code sprinkled with multiline strings should execute (GH-306)"
72 ip = get_ipython()
72 ip = get_ipython()
73 ip.run_cell('tmp=0')
73 ip.run_cell('tmp=0')
74 self.assertEquals(ip.user_ns['tmp'], 0)
74 self.assertEquals(ip.user_ns['tmp'], 0)
75 ip.run_cell('tmp=1;"""a\nb"""\n')
75 ip.run_cell('tmp=1;"""a\nb"""\n')
76 self.assertEquals(ip.user_ns['tmp'], 1)
76 self.assertEquals(ip.user_ns['tmp'], 1)
77
77
78 def test_dont_cache_with_semicolon(self):
78 def test_dont_cache_with_semicolon(self):
79 "Ending a line with semicolon should not cache the returned object (GH-307)"
79 "Ending a line with semicolon should not cache the returned object (GH-307)"
80 ip = get_ipython()
80 ip = get_ipython()
81 oldlen = len(ip.user_ns['Out'])
81 oldlen = len(ip.user_ns['Out'])
82 a = ip.run_cell('1;', store_history=True)
82 a = ip.run_cell('1;', store_history=True)
83 newlen = len(ip.user_ns['Out'])
83 newlen = len(ip.user_ns['Out'])
84 self.assertEquals(oldlen, newlen)
84 self.assertEquals(oldlen, newlen)
85 #also test the default caching behavior
85 #also test the default caching behavior
86 ip.run_cell('1', store_history=True)
86 ip.run_cell('1', store_history=True)
87 newlen = len(ip.user_ns['Out'])
87 newlen = len(ip.user_ns['Out'])
88 self.assertEquals(oldlen+1, newlen)
88 self.assertEquals(oldlen+1, newlen)
89
89
90 def test_In_variable(self):
90 def test_In_variable(self):
91 "Verify that In variable grows with user input (GH-284)"
91 "Verify that In variable grows with user input (GH-284)"
92 ip = get_ipython()
92 ip = get_ipython()
93 oldlen = len(ip.user_ns['In'])
93 oldlen = len(ip.user_ns['In'])
94 ip.run_cell('1;', store_history=True)
94 ip.run_cell('1;', store_history=True)
95 newlen = len(ip.user_ns['In'])
95 newlen = len(ip.user_ns['In'])
96 self.assertEquals(oldlen+1, newlen)
96 self.assertEquals(oldlen+1, newlen)
97 self.assertEquals(ip.user_ns['In'][-1],'1;')
97 self.assertEquals(ip.user_ns['In'][-1],'1;')
98
98
99 def test_magic_names_in_string(self):
99 def test_magic_names_in_string(self):
100 ip = get_ipython()
100 ip = get_ipython()
101 ip.run_cell('a = """\n%exit\n"""')
101 ip.run_cell('a = """\n%exit\n"""')
102 self.assertEquals(ip.user_ns['a'], '\n%exit\n')
102 self.assertEquals(ip.user_ns['a'], '\n%exit\n')
103
103
104 def test_alias_crash(self):
104 def test_alias_crash(self):
105 """Errors in prefilter can't crash IPython"""
105 """Errors in prefilter can't crash IPython"""
106 ip = get_ipython()
106 ip = get_ipython()
107 ip.run_cell('%alias parts echo first %s second %s')
107 ip.run_cell('%alias parts echo first %s second %s')
108 # capture stderr:
108 # capture stderr:
109 save_err = io.stderr
109 save_err = io.stderr
110 io.stderr = StringIO()
110 io.stderr = StringIO()
111 ip.run_cell('parts 1')
111 ip.run_cell('parts 1')
112 err = io.stderr.getvalue()
112 err = io.stderr.getvalue()
113 io.stderr = save_err
113 io.stderr = save_err
114 self.assertEquals(err.split(':')[0], 'ERROR')
114 self.assertEquals(err.split(':')[0], 'ERROR')
115
115
116 def test_trailing_newline(self):
116 def test_trailing_newline(self):
117 """test that running !(command) does not raise a SyntaxError"""
117 """test that running !(command) does not raise a SyntaxError"""
118 ip = get_ipython()
118 ip = get_ipython()
119 ip.run_cell('!(true)\n', False)
119 ip.run_cell('!(true)\n', False)
120 ip.run_cell('!(true)\n\n\n', False)
120 ip.run_cell('!(true)\n\n\n', False)
121
121
122 def test_gh_597(self):
122 def test_gh_597(self):
123 """Pretty-printing lists of objects with non-ascii reprs may cause
123 """Pretty-printing lists of objects with non-ascii reprs may cause
124 problems."""
124 problems."""
125 class Spam(object):
125 class Spam(object):
126 def __repr__(self):
126 def __repr__(self):
127 return "\xe9"*50
127 return "\xe9"*50
128 import IPython.core.formatters
128 import IPython.core.formatters
129 f = IPython.core.formatters.PlainTextFormatter()
129 f = IPython.core.formatters.PlainTextFormatter()
130 f([Spam(),Spam()])
130 f([Spam(),Spam()])
131
131
132
132
133 def test_future_flags(self):
133 def test_future_flags(self):
134 """Check that future flags are used for parsing code (gh-777)"""
134 """Check that future flags are used for parsing code (gh-777)"""
135 ip = get_ipython()
135 ip = get_ipython()
136 ip.run_cell('from __future__ import print_function')
136 ip.run_cell('from __future__ import print_function')
137 try:
137 try:
138 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
138 ip.run_cell('prfunc_return_val = print(1,2, sep=" ")')
139 assert 'prfunc_return_val' in ip.user_ns
139 assert 'prfunc_return_val' in ip.user_ns
140 finally:
140 finally:
141 # Reset compiler flags so we don't mess up other tests.
141 # Reset compiler flags so we don't mess up other tests.
142 ip.compile.reset_compiler_flags()
142 ip.compile.reset_compiler_flags()
143
143
144 def test_future_unicode(self):
144 def test_future_unicode(self):
145 """Check that unicode_literals is imported from __future__ (gh #786)"""
145 """Check that unicode_literals is imported from __future__ (gh #786)"""
146 ip = get_ipython()
146 ip = get_ipython()
147 try:
147 try:
148 ip.run_cell(u'byte_str = "a"')
148 ip.run_cell(u'byte_str = "a"')
149 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
149 assert isinstance(ip.user_ns['byte_str'], str) # string literals are byte strings by default
150 ip.run_cell('from __future__ import unicode_literals')
150 ip.run_cell('from __future__ import unicode_literals')
151 ip.run_cell(u'unicode_str = "a"')
151 ip.run_cell(u'unicode_str = "a"')
152 assert isinstance(ip.user_ns['unicode_str'], unicode) # strings literals are now unicode
152 assert isinstance(ip.user_ns['unicode_str'], unicode) # strings literals are now unicode
153 finally:
153 finally:
154 # Reset compiler flags so we don't mess up other tests.
154 # Reset compiler flags so we don't mess up other tests.
155 ip.compile.reset_compiler_flags()
155 ip.compile.reset_compiler_flags()
156
156
157 def test_can_pickle(self):
157 def test_can_pickle(self):
158 "Can we pickle objects defined interactively (GH-29)"
158 "Can we pickle objects defined interactively (GH-29)"
159 ip = get_ipython()
159 ip = get_ipython()
160 ip.reset()
160 ip.reset()
161 ip.run_cell(("class Mylist(list):\n"
161 ip.run_cell(("class Mylist(list):\n"
162 " def __init__(self,x=[]):\n"
162 " def __init__(self,x=[]):\n"
163 " list.__init__(self,x)"))
163 " list.__init__(self,x)"))
164 ip.run_cell("w=Mylist([1,2,3])")
164 ip.run_cell("w=Mylist([1,2,3])")
165
165
166 from cPickle import dumps
166 from cPickle import dumps
167
167
168 # We need to swap in our main module - this is only necessary
168 # We need to swap in our main module - this is only necessary
169 # inside the test framework, because IPython puts the interactive module
169 # inside the test framework, because IPython puts the interactive module
170 # in place (but the test framework undoes this).
170 # in place (but the test framework undoes this).
171 _main = sys.modules['__main__']
171 _main = sys.modules['__main__']
172 sys.modules['__main__'] = ip.user_module
172 sys.modules['__main__'] = ip.user_module
173 try:
173 try:
174 res = dumps(ip.user_ns["w"])
174 res = dumps(ip.user_ns["w"])
175 finally:
175 finally:
176 sys.modules['__main__'] = _main
176 sys.modules['__main__'] = _main
177 self.assertTrue(isinstance(res, bytes))
177 self.assertTrue(isinstance(res, bytes))
178
178
179 def test_global_ns(self):
179 def test_global_ns(self):
180 "Code in functions must be able to access variables outside them."
180 "Code in functions must be able to access variables outside them."
181 ip = get_ipython()
181 ip = get_ipython()
182 ip.run_cell("a = 10")
182 ip.run_cell("a = 10")
183 ip.run_cell(("def f(x):\n"
183 ip.run_cell(("def f(x):\n"
184 " return x + a"))
184 " return x + a"))
185 ip.run_cell("b = f(12)")
185 ip.run_cell("b = f(12)")
186 self.assertEqual(ip.user_ns["b"], 22)
186 self.assertEqual(ip.user_ns["b"], 22)
187
187
188 def test_bad_custom_tb(self):
188 def test_bad_custom_tb(self):
189 """Check that InteractiveShell is protected from bad custom exception handlers"""
189 """Check that InteractiveShell is protected from bad custom exception handlers"""
190 ip = get_ipython()
190 ip = get_ipython()
191 from IPython.utils import io
191 from IPython.utils import io
192 save_stderr = io.stderr
192 save_stderr = io.stderr
193 try:
193 try:
194 # capture stderr
194 # capture stderr
195 io.stderr = StringIO()
195 io.stderr = StringIO()
196 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
196 ip.set_custom_exc((IOError,), lambda etype,value,tb: 1/0)
197 self.assertEquals(ip.custom_exceptions, (IOError,))
197 self.assertEquals(ip.custom_exceptions, (IOError,))
198 ip.run_cell(u'raise IOError("foo")')
198 ip.run_cell(u'raise IOError("foo")')
199 self.assertEquals(ip.custom_exceptions, ())
199 self.assertEquals(ip.custom_exceptions, ())
200 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
200 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
201 finally:
201 finally:
202 io.stderr = save_stderr
202 io.stderr = save_stderr
203
203
204 def test_bad_custom_tb_return(self):
204 def test_bad_custom_tb_return(self):
205 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
205 """Check that InteractiveShell is protected from bad return types in custom exception handlers"""
206 ip = get_ipython()
206 ip = get_ipython()
207 from IPython.utils import io
207 from IPython.utils import io
208 save_stderr = io.stderr
208 save_stderr = io.stderr
209 try:
209 try:
210 # capture stderr
210 # capture stderr
211 io.stderr = StringIO()
211 io.stderr = StringIO()
212 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
212 ip.set_custom_exc((NameError,),lambda etype,value,tb, tb_offset=None: 1)
213 self.assertEquals(ip.custom_exceptions, (NameError,))
213 self.assertEquals(ip.custom_exceptions, (NameError,))
214 ip.run_cell(u'a=abracadabra')
214 ip.run_cell(u'a=abracadabra')
215 self.assertEquals(ip.custom_exceptions, ())
215 self.assertEquals(ip.custom_exceptions, ())
216 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
216 self.assertTrue("Custom TB Handler failed" in io.stderr.getvalue())
217 finally:
217 finally:
218 io.stderr = save_stderr
218 io.stderr = save_stderr
219
219
220 def test_drop_by_id(self):
220 def test_drop_by_id(self):
221 ip = get_ipython()
221 ip = get_ipython()
222 myvars = {"a":object(), "b":object(), "c": object()}
222 myvars = {"a":object(), "b":object(), "c": object()}
223 ip.push(myvars, interactive=False)
223 ip.push(myvars, interactive=False)
224 for name in myvars:
224 for name in myvars:
225 assert name in ip.user_ns, name
225 assert name in ip.user_ns, name
226 assert name in ip.user_ns_hidden, name
226 assert name in ip.user_ns_hidden, name
227 ip.user_ns['b'] = 12
227 ip.user_ns['b'] = 12
228 ip.drop_by_id(myvars)
228 ip.drop_by_id(myvars)
229 for name in ["a", "c"]:
229 for name in ["a", "c"]:
230 assert name not in ip.user_ns, name
230 assert name not in ip.user_ns, name
231 assert name not in ip.user_ns_hidden, name
231 assert name not in ip.user_ns_hidden, name
232 assert ip.user_ns['b'] == 12
232 assert ip.user_ns['b'] == 12
233 ip.reset()
233 ip.reset()
234
234
235 def test_var_expand(self):
235 def test_var_expand(self):
236 ip = get_ipython()
236 ip = get_ipython()
237 ip.user_ns['f'] = u'Ca\xf1o'
237 ip.user_ns['f'] = u'Ca\xf1o'
238 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
238 self.assertEqual(ip.var_expand(u'echo $f'), u'echo Ca\xf1o')
239 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
239 self.assertEqual(ip.var_expand(u'echo {f}'), u'echo Ca\xf1o')
240 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
240 self.assertEqual(ip.var_expand(u'echo {f[:-1]}'), u'echo Ca\xf1')
241 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
241 self.assertEqual(ip.var_expand(u'echo {1*2}'), u'echo 2')
242
242
243 ip.user_ns['f'] = b'Ca\xc3\xb1o'
243 ip.user_ns['f'] = b'Ca\xc3\xb1o'
244 # This should not raise any exception:
244 # This should not raise any exception:
245 ip.var_expand(u'echo $f')
245 ip.var_expand(u'echo $f')
246
246
247 def test_bad_var_expand(self):
247 def test_bad_var_expand(self):
248 """var_expand on invalid formats shouldn't raise"""
248 """var_expand on invalid formats shouldn't raise"""
249 ip = get_ipython()
249 ip = get_ipython()
250
250
251 # SyntaxError
251 # SyntaxError
252 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
252 self.assertEqual(ip.var_expand(u"{'a':5}"), u"{'a':5}")
253 # NameError
253 # NameError
254 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
254 self.assertEqual(ip.var_expand(u"{asdf}"), u"{asdf}")
255 # ZeroDivisionError
255 # ZeroDivisionError
256 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
256 self.assertEqual(ip.var_expand(u"{1/0}"), u"{1/0}")
257
258 def test_silent_nopostexec(self):
259 """run_cell(silent=True) doesn't invoke post-exec funcs"""
260 ip = get_ipython()
261
262 d = dict(called=False)
263 def set_called():
264 d['called'] = True
265
266 ip.register_post_execute(set_called)
267 ip.run_cell("1", silent=True)
268 self.assertFalse(d['called'])
269 # double-check that non-silent exec did what we expected
270 # silent to avoid
271 ip.run_cell("1")
272 self.assertTrue(d['called'])
273 # remove post-exec
274 ip._post_execute.pop(set_called)
275
276 def test_silent_noadvance(self):
277 """run_cell(silent=True) doesn't advance execution_count"""
278 ip = get_ipython()
279
280 ec = ip.execution_count
281 # silent should force store_history=False
282 ip.run_cell("1", store_history=True, silent=True)
283
284 self.assertEquals(ec, ip.execution_count)
285 # double-check that non-silent exec did what we expected
286 # silent to avoid
287 ip.run_cell("1", store_history=True)
288 self.assertEquals(ec+1, ip.execution_count)
289
290 def test_silent_nodisplayhook(self):
291 """run_cell(silent=True) doesn't trigger displayhook"""
292 ip = get_ipython()
293
294 d = dict(called=False)
295
296 trap = ip.display_trap
297 save_hook = trap.hook
298
299 def failing_hook(*args, **kwargs):
300 d['called'] = True
301
302 try:
303 trap.hook = failing_hook
304 ip.run_cell("1", silent=True)
305 self.assertFalse(d['called'])
306 # double-check that non-silent exec did what we expected
307 # silent to avoid
308 ip.run_cell("1")
309 self.assertTrue(d['called'])
310 finally:
311 trap.hook = save_hook
257
312
258 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
313 @skipif(sys.version_info[0] >= 3, "softspace removed in py3")
259 def test_print_softspace(self):
314 def test_print_softspace(self):
260 """Verify that softspace is handled correctly when executing multiple
315 """Verify that softspace is handled correctly when executing multiple
261 statements.
316 statements.
262
317
263 In [1]: print 1; print 2
318 In [1]: print 1; print 2
264 1
319 1
265 2
320 2
266
321
267 In [2]: print 1,; print 2
322 In [2]: print 1,; print 2
268 1 2
323 1 2
269 """
324 """
270
325
271 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
326 class TestSafeExecfileNonAsciiPath(unittest.TestCase):
272
327
273 def setUp(self):
328 def setUp(self):
274 self.BASETESTDIR = tempfile.mkdtemp()
329 self.BASETESTDIR = tempfile.mkdtemp()
275 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
330 self.TESTDIR = join(self.BASETESTDIR, u"åäö")
276 os.mkdir(self.TESTDIR)
331 os.mkdir(self.TESTDIR)
277 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
332 with open(join(self.TESTDIR, u"åäötestscript.py"), "w") as sfile:
278 sfile.write("pass\n")
333 sfile.write("pass\n")
279 self.oldpath = os.getcwdu()
334 self.oldpath = os.getcwdu()
280 os.chdir(self.TESTDIR)
335 os.chdir(self.TESTDIR)
281 self.fname = u"åäötestscript.py"
336 self.fname = u"åäötestscript.py"
282
337
283
338
284 def tearDown(self):
339 def tearDown(self):
285 os.chdir(self.oldpath)
340 os.chdir(self.oldpath)
286 shutil.rmtree(self.BASETESTDIR)
341 shutil.rmtree(self.BASETESTDIR)
287
342
288 def test_1(self):
343 def test_1(self):
289 """Test safe_execfile with non-ascii path
344 """Test safe_execfile with non-ascii path
290 """
345 """
291 _ip.shell.safe_execfile(self.fname, {}, raise_exceptions=True)
346 _ip.shell.safe_execfile(self.fname, {}, raise_exceptions=True)
292
347
293
348
294 class TestSystemRaw(unittest.TestCase):
349 class TestSystemRaw(unittest.TestCase):
295 def test_1(self):
350 def test_1(self):
296 """Test system_raw with non-ascii cmd
351 """Test system_raw with non-ascii cmd
297 """
352 """
298 cmd = ur'''python -c "'åäö'" '''
353 cmd = ur'''python -c "'åäö'" '''
299 _ip.shell.system_raw(cmd)
354 _ip.shell.system_raw(cmd)
300
355
301
356
302 def test__IPYTHON__():
357 def test__IPYTHON__():
303 # This shouldn't raise a NameError, that's all
358 # This shouldn't raise a NameError, that's all
304 __IPYTHON__
359 __IPYTHON__
@@ -1,325 +1,325 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 =============
3 =============
4 parallelmagic
4 parallelmagic
5 =============
5 =============
6
6
7 Magic command interface for interactive parallel work.
7 Magic command interface for interactive parallel work.
8
8
9 Usage
9 Usage
10 =====
10 =====
11
11
12 ``%autopx``
12 ``%autopx``
13
13
14 @AUTOPX_DOC@
14 @AUTOPX_DOC@
15
15
16 ``%px``
16 ``%px``
17
17
18 @PX_DOC@
18 @PX_DOC@
19
19
20 ``%result``
20 ``%result``
21
21
22 @RESULT_DOC@
22 @RESULT_DOC@
23
23
24 """
24 """
25
25
26 #-----------------------------------------------------------------------------
26 #-----------------------------------------------------------------------------
27 # Copyright (C) 2008-2011 The IPython Development Team
27 # Copyright (C) 2008-2011 The IPython Development Team
28 #
28 #
29 # Distributed under the terms of the BSD License. The full license is in
29 # Distributed under the terms of the BSD License. The full license is in
30 # the file COPYING, distributed as part of this software.
30 # the file COPYING, distributed as part of this software.
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32
32
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34 # Imports
34 # Imports
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36
36
37 import ast
37 import ast
38 import re
38 import re
39
39
40 from IPython.core.plugin import Plugin
40 from IPython.core.plugin import Plugin
41 from IPython.utils.traitlets import Bool, Any, Instance
41 from IPython.utils.traitlets import Bool, Any, Instance
42 from IPython.testing.skipdoctest import skip_doctest
42 from IPython.testing.skipdoctest import skip_doctest
43
43
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45 # Definitions of magic functions for use with IPython
45 # Definitions of magic functions for use with IPython
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47
47
48
48
49 NO_ACTIVE_VIEW = """
49 NO_ACTIVE_VIEW = """
50 Use activate() on a DirectView object to activate it for magics.
50 Use activate() on a DirectView object to activate it for magics.
51 """
51 """
52
52
53
53
54 class ParalleMagic(Plugin):
54 class ParalleMagic(Plugin):
55 """A component to manage the %result, %px and %autopx magics."""
55 """A component to manage the %result, %px and %autopx magics."""
56
56
57 active_view = Instance('IPython.parallel.client.view.DirectView')
57 active_view = Instance('IPython.parallel.client.view.DirectView')
58 verbose = Bool(False, config=True)
58 verbose = Bool(False, config=True)
59 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
59 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
60
60
61 def __init__(self, shell=None, config=None):
61 def __init__(self, shell=None, config=None):
62 super(ParalleMagic, self).__init__(shell=shell, config=config)
62 super(ParalleMagic, self).__init__(shell=shell, config=config)
63 self._define_magics()
63 self._define_magics()
64 # A flag showing if autopx is activated or not
64 # A flag showing if autopx is activated or not
65 self.autopx = False
65 self.autopx = False
66
66
67 def _define_magics(self):
67 def _define_magics(self):
68 """Define the magic functions."""
68 """Define the magic functions."""
69 self.shell.define_magic('result', self.magic_result)
69 self.shell.define_magic('result', self.magic_result)
70 self.shell.define_magic('px', self.magic_px)
70 self.shell.define_magic('px', self.magic_px)
71 self.shell.define_magic('autopx', self.magic_autopx)
71 self.shell.define_magic('autopx', self.magic_autopx)
72
72
73 @skip_doctest
73 @skip_doctest
74 def magic_result(self, ipself, parameter_s=''):
74 def magic_result(self, ipself, parameter_s=''):
75 """Print the result of command i on all engines..
75 """Print the result of command i on all engines..
76
76
77 To use this a :class:`DirectView` instance must be created
77 To use this a :class:`DirectView` instance must be created
78 and then activated by calling its :meth:`activate` method.
78 and then activated by calling its :meth:`activate` method.
79
79
80 Then you can do the following::
80 Then you can do the following::
81
81
82 In [23]: %result
82 In [23]: %result
83 Out[23]:
83 Out[23]:
84 <Results List>
84 <Results List>
85 [0] In [6]: a = 10
85 [0] In [6]: a = 10
86 [1] In [6]: a = 10
86 [1] In [6]: a = 10
87
87
88 In [22]: %result 6
88 In [22]: %result 6
89 Out[22]:
89 Out[22]:
90 <Results List>
90 <Results List>
91 [0] In [6]: a = 10
91 [0] In [6]: a = 10
92 [1] In [6]: a = 10
92 [1] In [6]: a = 10
93 """
93 """
94 if self.active_view is None:
94 if self.active_view is None:
95 print NO_ACTIVE_VIEW
95 print NO_ACTIVE_VIEW
96 return
96 return
97
97
98 try:
98 try:
99 index = int(parameter_s)
99 index = int(parameter_s)
100 except:
100 except:
101 index = None
101 index = None
102 result = self.active_view.get_result(index)
102 result = self.active_view.get_result(index)
103 return result
103 return result
104
104
105 @skip_doctest
105 @skip_doctest
106 def magic_px(self, ipself, parameter_s=''):
106 def magic_px(self, ipself, parameter_s=''):
107 """Executes the given python command in parallel.
107 """Executes the given python command in parallel.
108
108
109 To use this a :class:`DirectView` instance must be created
109 To use this a :class:`DirectView` instance must be created
110 and then activated by calling its :meth:`activate` method.
110 and then activated by calling its :meth:`activate` method.
111
111
112 Then you can do the following::
112 Then you can do the following::
113
113
114 In [24]: %px a = 5
114 In [24]: %px a = 5
115 Parallel execution on engine(s): all
115 Parallel execution on engine(s): all
116 Out[24]:
116 Out[24]:
117 <Results List>
117 <Results List>
118 [0] In [7]: a = 5
118 [0] In [7]: a = 5
119 [1] In [7]: a = 5
119 [1] In [7]: a = 5
120 """
120 """
121
121
122 if self.active_view is None:
122 if self.active_view is None:
123 print NO_ACTIVE_VIEW
123 print NO_ACTIVE_VIEW
124 return
124 return
125 print "Parallel execution on engine(s): %s" % self.active_view.targets
125 print "Parallel execution on engine(s): %s" % self.active_view.targets
126 result = self.active_view.execute(parameter_s, block=False)
126 result = self.active_view.execute(parameter_s, block=False)
127 if self.active_view.block:
127 if self.active_view.block:
128 result.get()
128 result.get()
129 self._maybe_display_output(result)
129 self._maybe_display_output(result)
130
130
131 @skip_doctest
131 @skip_doctest
132 def magic_autopx(self, ipself, parameter_s=''):
132 def magic_autopx(self, ipself, parameter_s=''):
133 """Toggles auto parallel mode.
133 """Toggles auto parallel mode.
134
134
135 To use this a :class:`DirectView` instance must be created
135 To use this a :class:`DirectView` instance must be created
136 and then activated by calling its :meth:`activate` method. Once this
136 and then activated by calling its :meth:`activate` method. Once this
137 is called, all commands typed at the command line are send to
137 is called, all commands typed at the command line are send to
138 the engines to be executed in parallel. To control which engine
138 the engines to be executed in parallel. To control which engine
139 are used, set the ``targets`` attributed of the multiengine client
139 are used, set the ``targets`` attributed of the multiengine client
140 before entering ``%autopx`` mode.
140 before entering ``%autopx`` mode.
141
141
142 Then you can do the following::
142 Then you can do the following::
143
143
144 In [25]: %autopx
144 In [25]: %autopx
145 %autopx to enabled
145 %autopx to enabled
146
146
147 In [26]: a = 10
147 In [26]: a = 10
148 Parallel execution on engine(s): [0,1,2,3]
148 Parallel execution on engine(s): [0,1,2,3]
149 In [27]: print a
149 In [27]: print a
150 Parallel execution on engine(s): [0,1,2,3]
150 Parallel execution on engine(s): [0,1,2,3]
151 [stdout:0] 10
151 [stdout:0] 10
152 [stdout:1] 10
152 [stdout:1] 10
153 [stdout:2] 10
153 [stdout:2] 10
154 [stdout:3] 10
154 [stdout:3] 10
155
155
156
156
157 In [27]: %autopx
157 In [27]: %autopx
158 %autopx disabled
158 %autopx disabled
159 """
159 """
160 if self.autopx:
160 if self.autopx:
161 self._disable_autopx()
161 self._disable_autopx()
162 else:
162 else:
163 self._enable_autopx()
163 self._enable_autopx()
164
164
165 def _enable_autopx(self):
165 def _enable_autopx(self):
166 """Enable %autopx mode by saving the original run_cell and installing
166 """Enable %autopx mode by saving the original run_cell and installing
167 pxrun_cell.
167 pxrun_cell.
168 """
168 """
169 if self.active_view is None:
169 if self.active_view is None:
170 print NO_ACTIVE_VIEW
170 print NO_ACTIVE_VIEW
171 return
171 return
172
172
173 # override run_cell and run_code
173 # override run_cell and run_code
174 self._original_run_cell = self.shell.run_cell
174 self._original_run_cell = self.shell.run_cell
175 self.shell.run_cell = self.pxrun_cell
175 self.shell.run_cell = self.pxrun_cell
176 self._original_run_code = self.shell.run_code
176 self._original_run_code = self.shell.run_code
177 self.shell.run_code = self.pxrun_code
177 self.shell.run_code = self.pxrun_code
178
178
179 self.autopx = True
179 self.autopx = True
180 print "%autopx enabled"
180 print "%autopx enabled"
181
181
182 def _disable_autopx(self):
182 def _disable_autopx(self):
183 """Disable %autopx by restoring the original InteractiveShell.run_cell.
183 """Disable %autopx by restoring the original InteractiveShell.run_cell.
184 """
184 """
185 if self.autopx:
185 if self.autopx:
186 self.shell.run_cell = self._original_run_cell
186 self.shell.run_cell = self._original_run_cell
187 self.shell.run_code = self._original_run_code
187 self.shell.run_code = self._original_run_code
188 self.autopx = False
188 self.autopx = False
189 print "%autopx disabled"
189 print "%autopx disabled"
190
190
191 def _maybe_display_output(self, result):
191 def _maybe_display_output(self, result):
192 """Maybe display the output of a parallel result.
192 """Maybe display the output of a parallel result.
193
193
194 If self.active_view.block is True, wait for the result
194 If self.active_view.block is True, wait for the result
195 and display the result. Otherwise, this is a noop.
195 and display the result. Otherwise, this is a noop.
196 """
196 """
197 if isinstance(result.stdout, basestring):
197 if isinstance(result.stdout, basestring):
198 # single result
198 # single result
199 stdouts = [result.stdout.rstrip()]
199 stdouts = [result.stdout.rstrip()]
200 else:
200 else:
201 stdouts = [s.rstrip() for s in result.stdout]
201 stdouts = [s.rstrip() for s in result.stdout]
202
202
203 targets = self.active_view.targets
203 targets = self.active_view.targets
204 if isinstance(targets, int):
204 if isinstance(targets, int):
205 targets = [targets]
205 targets = [targets]
206 elif targets == 'all':
206 elif targets == 'all':
207 targets = self.active_view.client.ids
207 targets = self.active_view.client.ids
208
208
209 if any(stdouts):
209 if any(stdouts):
210 for eid,stdout in zip(targets, stdouts):
210 for eid,stdout in zip(targets, stdouts):
211 print '[stdout:%i]'%eid, stdout
211 print '[stdout:%i]'%eid, stdout
212
212
213
213
214 def pxrun_cell(self, raw_cell, store_history=True):
214 def pxrun_cell(self, raw_cell, store_history=False, silent=False):
215 """drop-in replacement for InteractiveShell.run_cell.
215 """drop-in replacement for InteractiveShell.run_cell.
216
216
217 This executes code remotely, instead of in the local namespace.
217 This executes code remotely, instead of in the local namespace.
218
218
219 See InteractiveShell.run_cell for details.
219 See InteractiveShell.run_cell for details.
220 """
220 """
221
221
222 if (not raw_cell) or raw_cell.isspace():
222 if (not raw_cell) or raw_cell.isspace():
223 return
223 return
224
224
225 ipself = self.shell
225 ipself = self.shell
226
226
227 with ipself.builtin_trap:
227 with ipself.builtin_trap:
228 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
228 cell = ipself.prefilter_manager.prefilter_lines(raw_cell)
229
229
230 # Store raw and processed history
230 # Store raw and processed history
231 if store_history:
231 if store_history:
232 ipself.history_manager.store_inputs(ipself.execution_count,
232 ipself.history_manager.store_inputs(ipself.execution_count,
233 cell, raw_cell)
233 cell, raw_cell)
234
234
235 # ipself.logger.log(cell, raw_cell)
235 # ipself.logger.log(cell, raw_cell)
236
236
237 cell_name = ipself.compile.cache(cell, ipself.execution_count)
237 cell_name = ipself.compile.cache(cell, ipself.execution_count)
238
238
239 try:
239 try:
240 code_ast = ast.parse(cell, filename=cell_name)
240 code_ast = ast.parse(cell, filename=cell_name)
241 except (OverflowError, SyntaxError, ValueError, TypeError, MemoryError):
241 except (OverflowError, SyntaxError, ValueError, TypeError, MemoryError):
242 # Case 1
242 # Case 1
243 ipself.showsyntaxerror()
243 ipself.showsyntaxerror()
244 ipself.execution_count += 1
244 ipself.execution_count += 1
245 return None
245 return None
246 except NameError:
246 except NameError:
247 # ignore name errors, because we don't know the remote keys
247 # ignore name errors, because we don't know the remote keys
248 pass
248 pass
249
249
250 if store_history:
250 if store_history:
251 # Write output to the database. Does nothing unless
251 # Write output to the database. Does nothing unless
252 # history output logging is enabled.
252 # history output logging is enabled.
253 ipself.history_manager.store_output(ipself.execution_count)
253 ipself.history_manager.store_output(ipself.execution_count)
254 # Each cell is a *single* input, regardless of how many lines it has
254 # Each cell is a *single* input, regardless of how many lines it has
255 ipself.execution_count += 1
255 ipself.execution_count += 1
256 if re.search(r'get_ipython\(\)\.magic\(u?["\']%?autopx', cell):
256 if re.search(r'get_ipython\(\)\.magic\(u?["\']%?autopx', cell):
257 self._disable_autopx()
257 self._disable_autopx()
258 return False
258 return False
259 else:
259 else:
260 try:
260 try:
261 result = self.active_view.execute(cell, block=False)
261 result = self.active_view.execute(cell, silent=False, block=False)
262 except:
262 except:
263 ipself.showtraceback()
263 ipself.showtraceback()
264 return True
264 return True
265 else:
265 else:
266 if self.active_view.block:
266 if self.active_view.block:
267 try:
267 try:
268 result.get()
268 result.get()
269 except:
269 except:
270 self.shell.showtraceback()
270 self.shell.showtraceback()
271 return True
271 return True
272 else:
272 else:
273 self._maybe_display_output(result)
273 self._maybe_display_output(result)
274 return False
274 return False
275
275
276 def pxrun_code(self, code_obj):
276 def pxrun_code(self, code_obj):
277 """drop-in replacement for InteractiveShell.run_code.
277 """drop-in replacement for InteractiveShell.run_code.
278
278
279 This executes code remotely, instead of in the local namespace.
279 This executes code remotely, instead of in the local namespace.
280
280
281 See InteractiveShell.run_code for details.
281 See InteractiveShell.run_code for details.
282 """
282 """
283 ipself = self.shell
283 ipself = self.shell
284 # check code object for the autopx magic
284 # check code object for the autopx magic
285 if 'get_ipython' in code_obj.co_names and 'magic' in code_obj.co_names and \
285 if 'get_ipython' in code_obj.co_names and 'magic' in code_obj.co_names and \
286 any( [ isinstance(c, basestring) and 'autopx' in c for c in code_obj.co_consts ]):
286 any( [ isinstance(c, basestring) and 'autopx' in c for c in code_obj.co_consts ]):
287 self._disable_autopx()
287 self._disable_autopx()
288 return False
288 return False
289 else:
289 else:
290 try:
290 try:
291 result = self.active_view.execute(code_obj, block=False)
291 result = self.active_view.execute(code_obj, block=False)
292 except:
292 except:
293 ipself.showtraceback()
293 ipself.showtraceback()
294 return True
294 return True
295 else:
295 else:
296 if self.active_view.block:
296 if self.active_view.block:
297 try:
297 try:
298 result.get()
298 result.get()
299 except:
299 except:
300 self.shell.showtraceback()
300 self.shell.showtraceback()
301 return True
301 return True
302 else:
302 else:
303 self._maybe_display_output(result)
303 self._maybe_display_output(result)
304 return False
304 return False
305
305
306
306
307 __doc__ = __doc__.replace('@AUTOPX_DOC@',
307 __doc__ = __doc__.replace('@AUTOPX_DOC@',
308 " " + ParalleMagic.magic_autopx.__doc__)
308 " " + ParalleMagic.magic_autopx.__doc__)
309 __doc__ = __doc__.replace('@PX_DOC@',
309 __doc__ = __doc__.replace('@PX_DOC@',
310 " " + ParalleMagic.magic_px.__doc__)
310 " " + ParalleMagic.magic_px.__doc__)
311 __doc__ = __doc__.replace('@RESULT_DOC@',
311 __doc__ = __doc__.replace('@RESULT_DOC@',
312 " " + ParalleMagic.magic_result.__doc__)
312 " " + ParalleMagic.magic_result.__doc__)
313
313
314
314
315 _loaded = False
315 _loaded = False
316
316
317
317
318 def load_ipython_extension(ip):
318 def load_ipython_extension(ip):
319 """Load the extension in IPython."""
319 """Load the extension in IPython."""
320 global _loaded
320 global _loaded
321 if not _loaded:
321 if not _loaded:
322 plugin = ParalleMagic(shell=ip, config=ip.config)
322 plugin = ParalleMagic(shell=ip, config=ip.config)
323 ip.plugin_manager.register_plugin('parallelmagic', plugin)
323 ip.plugin_manager.register_plugin('parallelmagic', plugin)
324 _loaded = True
324 _loaded = True
325
325
@@ -1,352 +1,348 b''
1 """ A minimal application base mixin for all ZMQ based IPython frontends.
1 """ A minimal application base mixin for all ZMQ based IPython frontends.
2
2
3 This is not a complete console app, as subprocess will not be able to receive
3 This is not a complete console app, as subprocess will not be able to receive
4 input, there is no real readline support, among other limitations. This is a
4 input, there is no real readline support, among other limitations. This is a
5 refactoring of what used to be the IPython/frontend/qt/console/qtconsoleapp.py
5 refactoring of what used to be the IPython/frontend/qt/console/qtconsoleapp.py
6
6
7 Authors:
7 Authors:
8
8
9 * Evan Patterson
9 * Evan Patterson
10 * Min RK
10 * Min RK
11 * Erik Tollerud
11 * Erik Tollerud
12 * Fernando Perez
12 * Fernando Perez
13 * Bussonnier Matthias
13 * Bussonnier Matthias
14 * Thomas Kluyver
14 * Thomas Kluyver
15 * Paul Ivanov
15 * Paul Ivanov
16
16
17 """
17 """
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 # stdlib imports
23 # stdlib imports
24 import atexit
24 import atexit
25 import json
25 import json
26 import os
26 import os
27 import signal
27 import signal
28 import sys
28 import sys
29 import uuid
29 import uuid
30
30
31
31
32 # Local imports
32 # Local imports
33 from IPython.config.application import boolean_flag
33 from IPython.config.application import boolean_flag
34 from IPython.config.configurable import Configurable
34 from IPython.config.configurable import Configurable
35 from IPython.core.profiledir import ProfileDir
35 from IPython.core.profiledir import ProfileDir
36 from IPython.lib.kernel import tunnel_to_kernel, find_connection_file, swallow_argv
36 from IPython.lib.kernel import tunnel_to_kernel, find_connection_file, swallow_argv
37 from IPython.zmq.blockingkernelmanager import BlockingKernelManager
37 from IPython.zmq.blockingkernelmanager import BlockingKernelManager
38 from IPython.utils.path import filefind
38 from IPython.utils.path import filefind
39 from IPython.utils.py3compat import str_to_bytes
39 from IPython.utils.py3compat import str_to_bytes
40 from IPython.utils.traitlets import (
40 from IPython.utils.traitlets import (
41 Dict, List, Unicode, CUnicode, Int, CBool, Any
41 Dict, List, Unicode, CUnicode, Int, CBool, Any
42 )
42 )
43 from IPython.zmq.ipkernel import (
43 from IPython.zmq.ipkernel import (
44 flags as ipkernel_flags,
44 flags as ipkernel_flags,
45 aliases as ipkernel_aliases,
45 aliases as ipkernel_aliases,
46 IPKernelApp
46 IPKernelApp
47 )
47 )
48 from IPython.zmq.session import Session, default_secure
48 from IPython.zmq.session import Session, default_secure
49 from IPython.zmq.zmqshell import ZMQInteractiveShell
49 from IPython.zmq.zmqshell import ZMQInteractiveShell
50
50
51 #-----------------------------------------------------------------------------
51 #-----------------------------------------------------------------------------
52 # Network Constants
52 # Network Constants
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54
54
55 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
55 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
56
56
57 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
58 # Globals
58 # Globals
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60
60
61
61
62 #-----------------------------------------------------------------------------
62 #-----------------------------------------------------------------------------
63 # Aliases and Flags
63 # Aliases and Flags
64 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
65
65
66 flags = dict(ipkernel_flags)
66 flags = dict(ipkernel_flags)
67
67
68 # the flags that are specific to the frontend
68 # the flags that are specific to the frontend
69 # these must be scrubbed before being passed to the kernel,
69 # these must be scrubbed before being passed to the kernel,
70 # or it will raise an error on unrecognized flags
70 # or it will raise an error on unrecognized flags
71 app_flags = {
71 app_flags = {
72 'existing' : ({'IPythonConsoleApp' : {'existing' : 'kernel*.json'}},
72 'existing' : ({'IPythonConsoleApp' : {'existing' : 'kernel*.json'}},
73 "Connect to an existing kernel. If no argument specified, guess most recent"),
73 "Connect to an existing kernel. If no argument specified, guess most recent"),
74 }
74 }
75 app_flags.update(boolean_flag(
75 app_flags.update(boolean_flag(
76 'confirm-exit', 'IPythonConsoleApp.confirm_exit',
76 'confirm-exit', 'IPythonConsoleApp.confirm_exit',
77 """Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
77 """Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
78 to force a direct exit without any confirmation.
78 to force a direct exit without any confirmation.
79 """,
79 """,
80 """Don't prompt the user when exiting. This will terminate the kernel
80 """Don't prompt the user when exiting. This will terminate the kernel
81 if it is owned by the frontend, and leave it alive if it is external.
81 if it is owned by the frontend, and leave it alive if it is external.
82 """
82 """
83 ))
83 ))
84 flags.update(app_flags)
84 flags.update(app_flags)
85
85
86 aliases = dict(ipkernel_aliases)
86 aliases = dict(ipkernel_aliases)
87
87
88 # also scrub aliases from the frontend
88 # also scrub aliases from the frontend
89 app_aliases = dict(
89 app_aliases = dict(
90 hb = 'IPythonConsoleApp.hb_port',
90 hb = 'IPythonConsoleApp.hb_port',
91 shell = 'IPythonConsoleApp.shell_port',
91 shell = 'IPythonConsoleApp.shell_port',
92 iopub = 'IPythonConsoleApp.iopub_port',
92 iopub = 'IPythonConsoleApp.iopub_port',
93 stdin = 'IPythonConsoleApp.stdin_port',
93 stdin = 'IPythonConsoleApp.stdin_port',
94 ip = 'IPythonConsoleApp.ip',
94 ip = 'IPythonConsoleApp.ip',
95 existing = 'IPythonConsoleApp.existing',
95 existing = 'IPythonConsoleApp.existing',
96 f = 'IPythonConsoleApp.connection_file',
96 f = 'IPythonConsoleApp.connection_file',
97
97
98
98
99 ssh = 'IPythonConsoleApp.sshserver',
99 ssh = 'IPythonConsoleApp.sshserver',
100 )
100 )
101 aliases.update(app_aliases)
101 aliases.update(app_aliases)
102
102
103 #-----------------------------------------------------------------------------
103 #-----------------------------------------------------------------------------
104 # Classes
104 # Classes
105 #-----------------------------------------------------------------------------
105 #-----------------------------------------------------------------------------
106
106
107 #-----------------------------------------------------------------------------
107 #-----------------------------------------------------------------------------
108 # IPythonConsole
108 # IPythonConsole
109 #-----------------------------------------------------------------------------
109 #-----------------------------------------------------------------------------
110
110
111
111
112 class IPythonConsoleApp(Configurable):
112 class IPythonConsoleApp(Configurable):
113 name = 'ipython-console-mixin'
113 name = 'ipython-console-mixin'
114 default_config_file_name='ipython_config.py'
114 default_config_file_name='ipython_config.py'
115
115
116 description = """
116 description = """
117 The IPython Mixin Console.
117 The IPython Mixin Console.
118
118
119 This class contains the common portions of console client (QtConsole,
119 This class contains the common portions of console client (QtConsole,
120 ZMQ-based terminal console, etc). It is not a full console, in that
120 ZMQ-based terminal console, etc). It is not a full console, in that
121 launched terminal subprocesses will not be able to accept input.
121 launched terminal subprocesses will not be able to accept input.
122
122
123 The Console using this mixing supports various extra features beyond
123 The Console using this mixing supports various extra features beyond
124 the single-process Terminal IPython shell, such as connecting to
124 the single-process Terminal IPython shell, such as connecting to
125 existing kernel, via:
125 existing kernel, via:
126
126
127 ipython <appname> --existing
127 ipython <appname> --existing
128
128
129 as well as tunnel via SSH
129 as well as tunnel via SSH
130
130
131 """
131 """
132
132
133 classes = [IPKernelApp, ZMQInteractiveShell, ProfileDir, Session]
133 classes = [IPKernelApp, ZMQInteractiveShell, ProfileDir, Session]
134 flags = Dict(flags)
134 flags = Dict(flags)
135 aliases = Dict(aliases)
135 aliases = Dict(aliases)
136 kernel_manager_class = BlockingKernelManager
136 kernel_manager_class = BlockingKernelManager
137
137
138 kernel_argv = List(Unicode)
138 kernel_argv = List(Unicode)
139 # frontend flags&aliases to be stripped when building kernel_argv
139 # frontend flags&aliases to be stripped when building kernel_argv
140 frontend_flags = Any(app_flags)
140 frontend_flags = Any(app_flags)
141 frontend_aliases = Any(app_aliases)
141 frontend_aliases = Any(app_aliases)
142
142
143 pure = CBool(False, config=True,
144 help="Use a pure Python kernel instead of an IPython kernel.")
145 # create requested profiles by default, if they don't exist:
143 # create requested profiles by default, if they don't exist:
146 auto_create = CBool(True)
144 auto_create = CBool(True)
147 # connection info:
145 # connection info:
148 ip = Unicode(LOCALHOST, config=True,
146 ip = Unicode(LOCALHOST, config=True,
149 help="""Set the kernel\'s IP address [default localhost].
147 help="""Set the kernel\'s IP address [default localhost].
150 If the IP address is something other than localhost, then
148 If the IP address is something other than localhost, then
151 Consoles on other machines will be able to connect
149 Consoles on other machines will be able to connect
152 to the Kernel, so be careful!"""
150 to the Kernel, so be careful!"""
153 )
151 )
154
152
155 sshserver = Unicode('', config=True,
153 sshserver = Unicode('', config=True,
156 help="""The SSH server to use to connect to the kernel.""")
154 help="""The SSH server to use to connect to the kernel.""")
157 sshkey = Unicode('', config=True,
155 sshkey = Unicode('', config=True,
158 help="""Path to the ssh key to use for logging in to the ssh server.""")
156 help="""Path to the ssh key to use for logging in to the ssh server.""")
159
157
160 hb_port = Int(0, config=True,
158 hb_port = Int(0, config=True,
161 help="set the heartbeat port [default: random]")
159 help="set the heartbeat port [default: random]")
162 shell_port = Int(0, config=True,
160 shell_port = Int(0, config=True,
163 help="set the shell (XREP) port [default: random]")
161 help="set the shell (XREP) port [default: random]")
164 iopub_port = Int(0, config=True,
162 iopub_port = Int(0, config=True,
165 help="set the iopub (PUB) port [default: random]")
163 help="set the iopub (PUB) port [default: random]")
166 stdin_port = Int(0, config=True,
164 stdin_port = Int(0, config=True,
167 help="set the stdin (XREQ) port [default: random]")
165 help="set the stdin (XREQ) port [default: random]")
168 connection_file = Unicode('', config=True,
166 connection_file = Unicode('', config=True,
169 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
167 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
170
168
171 This file will contain the IP, ports, and authentication key needed to connect
169 This file will contain the IP, ports, and authentication key needed to connect
172 clients to this kernel. By default, this file will be created in the security-dir
170 clients to this kernel. By default, this file will be created in the security-dir
173 of the current profile, but can be specified by absolute path.
171 of the current profile, but can be specified by absolute path.
174 """)
172 """)
175 def _connection_file_default(self):
173 def _connection_file_default(self):
176 return 'kernel-%i.json' % os.getpid()
174 return 'kernel-%i.json' % os.getpid()
177
175
178 existing = CUnicode('', config=True,
176 existing = CUnicode('', config=True,
179 help="""Connect to an already running kernel""")
177 help="""Connect to an already running kernel""")
180
178
181 confirm_exit = CBool(True, config=True,
179 confirm_exit = CBool(True, config=True,
182 help="""
180 help="""
183 Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
181 Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
184 to force a direct exit without any confirmation.""",
182 to force a direct exit without any confirmation.""",
185 )
183 )
186
184
187
185
188 def build_kernel_argv(self, argv=None):
186 def build_kernel_argv(self, argv=None):
189 """build argv to be passed to kernel subprocess"""
187 """build argv to be passed to kernel subprocess"""
190 if argv is None:
188 if argv is None:
191 argv = sys.argv[1:]
189 argv = sys.argv[1:]
192 self.kernel_argv = swallow_argv(argv, self.frontend_aliases, self.frontend_flags)
190 self.kernel_argv = swallow_argv(argv, self.frontend_aliases, self.frontend_flags)
193 # kernel should inherit default config file from frontend
191 # kernel should inherit default config file from frontend
194 self.kernel_argv.append("--KernelApp.parent_appname='%s'"%self.name)
192 self.kernel_argv.append("--KernelApp.parent_appname='%s'"%self.name)
195
193
196 def init_connection_file(self):
194 def init_connection_file(self):
197 """find the connection file, and load the info if found.
195 """find the connection file, and load the info if found.
198
196
199 The current working directory and the current profile's security
197 The current working directory and the current profile's security
200 directory will be searched for the file if it is not given by
198 directory will be searched for the file if it is not given by
201 absolute path.
199 absolute path.
202
200
203 When attempting to connect to an existing kernel and the `--existing`
201 When attempting to connect to an existing kernel and the `--existing`
204 argument does not match an existing file, it will be interpreted as a
202 argument does not match an existing file, it will be interpreted as a
205 fileglob, and the matching file in the current profile's security dir
203 fileglob, and the matching file in the current profile's security dir
206 with the latest access time will be used.
204 with the latest access time will be used.
207
205
208 After this method is called, self.connection_file contains the *full path*
206 After this method is called, self.connection_file contains the *full path*
209 to the connection file, never just its name.
207 to the connection file, never just its name.
210 """
208 """
211 if self.existing:
209 if self.existing:
212 try:
210 try:
213 cf = find_connection_file(self.existing)
211 cf = find_connection_file(self.existing)
214 except Exception:
212 except Exception:
215 self.log.critical("Could not find existing kernel connection file %s", self.existing)
213 self.log.critical("Could not find existing kernel connection file %s", self.existing)
216 self.exit(1)
214 self.exit(1)
217 self.log.info("Connecting to existing kernel: %s" % cf)
215 self.log.info("Connecting to existing kernel: %s" % cf)
218 self.connection_file = cf
216 self.connection_file = cf
219 else:
217 else:
220 # not existing, check if we are going to write the file
218 # not existing, check if we are going to write the file
221 # and ensure that self.connection_file is a full path, not just the shortname
219 # and ensure that self.connection_file is a full path, not just the shortname
222 try:
220 try:
223 cf = find_connection_file(self.connection_file)
221 cf = find_connection_file(self.connection_file)
224 except Exception:
222 except Exception:
225 # file might not exist
223 # file might not exist
226 if self.connection_file == os.path.basename(self.connection_file):
224 if self.connection_file == os.path.basename(self.connection_file):
227 # just shortname, put it in security dir
225 # just shortname, put it in security dir
228 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
226 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
229 else:
227 else:
230 cf = self.connection_file
228 cf = self.connection_file
231 self.connection_file = cf
229 self.connection_file = cf
232
230
233 # should load_connection_file only be used for existing?
231 # should load_connection_file only be used for existing?
234 # as it is now, this allows reusing ports if an existing
232 # as it is now, this allows reusing ports if an existing
235 # file is requested
233 # file is requested
236 try:
234 try:
237 self.load_connection_file()
235 self.load_connection_file()
238 except Exception:
236 except Exception:
239 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
237 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
240 self.exit(1)
238 self.exit(1)
241
239
242 def load_connection_file(self):
240 def load_connection_file(self):
243 """load ip/port/hmac config from JSON connection file"""
241 """load ip/port/hmac config from JSON connection file"""
244 # this is identical to KernelApp.load_connection_file
242 # this is identical to KernelApp.load_connection_file
245 # perhaps it can be centralized somewhere?
243 # perhaps it can be centralized somewhere?
246 try:
244 try:
247 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
245 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
248 except IOError:
246 except IOError:
249 self.log.debug("Connection File not found: %s", self.connection_file)
247 self.log.debug("Connection File not found: %s", self.connection_file)
250 return
248 return
251 self.log.debug(u"Loading connection file %s", fname)
249 self.log.debug(u"Loading connection file %s", fname)
252 with open(fname) as f:
250 with open(fname) as f:
253 s = f.read()
251 s = f.read()
254 cfg = json.loads(s)
252 cfg = json.loads(s)
255 if self.ip == LOCALHOST and 'ip' in cfg:
253 if self.ip == LOCALHOST and 'ip' in cfg:
256 # not overridden by config or cl_args
254 # not overridden by config or cl_args
257 self.ip = cfg['ip']
255 self.ip = cfg['ip']
258 for channel in ('hb', 'shell', 'iopub', 'stdin'):
256 for channel in ('hb', 'shell', 'iopub', 'stdin'):
259 name = channel + '_port'
257 name = channel + '_port'
260 if getattr(self, name) == 0 and name in cfg:
258 if getattr(self, name) == 0 and name in cfg:
261 # not overridden by config or cl_args
259 # not overridden by config or cl_args
262 setattr(self, name, cfg[name])
260 setattr(self, name, cfg[name])
263 if 'key' in cfg:
261 if 'key' in cfg:
264 self.config.Session.key = str_to_bytes(cfg['key'])
262 self.config.Session.key = str_to_bytes(cfg['key'])
265
263
266 def init_ssh(self):
264 def init_ssh(self):
267 """set up ssh tunnels, if needed."""
265 """set up ssh tunnels, if needed."""
268 if not self.sshserver and not self.sshkey:
266 if not self.sshserver and not self.sshkey:
269 return
267 return
270
268
271 if self.sshkey and not self.sshserver:
269 if self.sshkey and not self.sshserver:
272 # specifying just the key implies that we are connecting directly
270 # specifying just the key implies that we are connecting directly
273 self.sshserver = self.ip
271 self.sshserver = self.ip
274 self.ip = LOCALHOST
272 self.ip = LOCALHOST
275
273
276 # build connection dict for tunnels:
274 # build connection dict for tunnels:
277 info = dict(ip=self.ip,
275 info = dict(ip=self.ip,
278 shell_port=self.shell_port,
276 shell_port=self.shell_port,
279 iopub_port=self.iopub_port,
277 iopub_port=self.iopub_port,
280 stdin_port=self.stdin_port,
278 stdin_port=self.stdin_port,
281 hb_port=self.hb_port
279 hb_port=self.hb_port
282 )
280 )
283
281
284 self.log.info("Forwarding connections to %s via %s"%(self.ip, self.sshserver))
282 self.log.info("Forwarding connections to %s via %s"%(self.ip, self.sshserver))
285
283
286 # tunnels return a new set of ports, which will be on localhost:
284 # tunnels return a new set of ports, which will be on localhost:
287 self.ip = LOCALHOST
285 self.ip = LOCALHOST
288 try:
286 try:
289 newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
287 newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
290 except:
288 except:
291 # even catch KeyboardInterrupt
289 # even catch KeyboardInterrupt
292 self.log.error("Could not setup tunnels", exc_info=True)
290 self.log.error("Could not setup tunnels", exc_info=True)
293 self.exit(1)
291 self.exit(1)
294
292
295 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports
293 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports
296
294
297 cf = self.connection_file
295 cf = self.connection_file
298 base,ext = os.path.splitext(cf)
296 base,ext = os.path.splitext(cf)
299 base = os.path.basename(base)
297 base = os.path.basename(base)
300 self.connection_file = os.path.basename(base)+'-ssh'+ext
298 self.connection_file = os.path.basename(base)+'-ssh'+ext
301 self.log.critical("To connect another client via this tunnel, use:")
299 self.log.critical("To connect another client via this tunnel, use:")
302 self.log.critical("--existing %s" % self.connection_file)
300 self.log.critical("--existing %s" % self.connection_file)
303
301
304 def _new_connection_file(self):
302 def _new_connection_file(self):
305 cf = ''
303 cf = ''
306 while not cf:
304 while not cf:
307 # we don't need a 128b id to distinguish kernels, use more readable
305 # we don't need a 128b id to distinguish kernels, use more readable
308 # 48b node segment (12 hex chars). Users running more than 32k simultaneous
306 # 48b node segment (12 hex chars). Users running more than 32k simultaneous
309 # kernels can subclass.
307 # kernels can subclass.
310 ident = str(uuid.uuid4()).split('-')[-1]
308 ident = str(uuid.uuid4()).split('-')[-1]
311 cf = os.path.join(self.profile_dir.security_dir, 'kernel-%s.json' % ident)
309 cf = os.path.join(self.profile_dir.security_dir, 'kernel-%s.json' % ident)
312 # only keep if it's actually new. Protect against unlikely collision
310 # only keep if it's actually new. Protect against unlikely collision
313 # in 48b random search space
311 # in 48b random search space
314 cf = cf if not os.path.exists(cf) else ''
312 cf = cf if not os.path.exists(cf) else ''
315 return cf
313 return cf
316
314
317 def init_kernel_manager(self):
315 def init_kernel_manager(self):
318 # Don't let Qt or ZMQ swallow KeyboardInterupts.
316 # Don't let Qt or ZMQ swallow KeyboardInterupts.
319 signal.signal(signal.SIGINT, signal.SIG_DFL)
317 signal.signal(signal.SIGINT, signal.SIG_DFL)
320
318
321 # Create a KernelManager and start a kernel.
319 # Create a KernelManager and start a kernel.
322 self.kernel_manager = self.kernel_manager_class(
320 self.kernel_manager = self.kernel_manager_class(
323 ip=self.ip,
321 ip=self.ip,
324 shell_port=self.shell_port,
322 shell_port=self.shell_port,
325 iopub_port=self.iopub_port,
323 iopub_port=self.iopub_port,
326 stdin_port=self.stdin_port,
324 stdin_port=self.stdin_port,
327 hb_port=self.hb_port,
325 hb_port=self.hb_port,
328 connection_file=self.connection_file,
326 connection_file=self.connection_file,
329 config=self.config,
327 config=self.config,
330 )
328 )
331 # start the kernel
329 # start the kernel
332 if not self.existing:
330 if not self.existing:
333 kwargs = dict(ipython=not self.pure)
331 self.kernel_manager.start_kernel(extra_arguments=self.kernel_argv)
334 kwargs['extra_arguments'] = self.kernel_argv
335 self.kernel_manager.start_kernel(**kwargs)
336 elif self.sshserver:
332 elif self.sshserver:
337 # ssh, write new connection file
333 # ssh, write new connection file
338 self.kernel_manager.write_connection_file()
334 self.kernel_manager.write_connection_file()
339 atexit.register(self.kernel_manager.cleanup_connection_file)
335 atexit.register(self.kernel_manager.cleanup_connection_file)
340 self.kernel_manager.start_channels()
336 self.kernel_manager.start_channels()
341
337
342
338
343 def initialize(self, argv=None):
339 def initialize(self, argv=None):
344 """
340 """
345 Classes which mix this class in should call:
341 Classes which mix this class in should call:
346 IPythonConsoleApp.initialize(self,argv)
342 IPythonConsoleApp.initialize(self,argv)
347 """
343 """
348 self.init_connection_file()
344 self.init_connection_file()
349 default_secure(self.config)
345 default_secure(self.config)
350 self.init_ssh()
346 self.init_ssh()
351 self.init_kernel_manager()
347 self.init_kernel_manager()
352
348
@@ -1,174 +1,172 b''
1 """Manage IPython.parallel clusters in the notebook.
1 """Manage IPython.parallel clusters in the notebook.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 """
6 """
7
7
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2008-2011 The IPython Development Team
9 # Copyright (C) 2008-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 import os
19 import os
20
20
21 from tornado import web
21 from tornado import web
22 from zmq.eventloop import ioloop
22 from zmq.eventloop import ioloop
23
23
24 from IPython.config.configurable import LoggingConfigurable
24 from IPython.config.configurable import LoggingConfigurable
25 from IPython.config.loader import load_pyconfig_files
25 from IPython.config.loader import load_pyconfig_files
26 from IPython.utils.traitlets import Dict, Instance, CFloat
26 from IPython.utils.traitlets import Dict, Instance, CFloat
27 from IPython.parallel.apps.ipclusterapp import IPClusterStart
27 from IPython.parallel.apps.ipclusterapp import IPClusterStart
28 from IPython.core.profileapp import list_profiles_in
28 from IPython.core.profileapp import list_profiles_in
29 from IPython.core.profiledir import ProfileDir
29 from IPython.core.profiledir import ProfileDir
30 from IPython.utils.path import get_ipython_dir
30 from IPython.utils.path import get_ipython_dir
31 from IPython.utils.sysinfo import num_cpus
31 from IPython.utils.sysinfo import num_cpus
32
32
33
33
34 #-----------------------------------------------------------------------------
34 #-----------------------------------------------------------------------------
35 # Classes
35 # Classes
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38
38
39 class DummyIPClusterStart(IPClusterStart):
39 class DummyIPClusterStart(IPClusterStart):
40 """Dummy subclass to skip init steps that conflict with global app.
40 """Dummy subclass to skip init steps that conflict with global app.
41
41
42 Instantiating and initializing this class should result in fully configured
42 Instantiating and initializing this class should result in fully configured
43 launchers, but no other side effects or state.
43 launchers, but no other side effects or state.
44 """
44 """
45
45
46 def init_signal(self):
46 def init_signal(self):
47 pass
47 pass
48 def init_logging(self):
49 pass
50 def reinit_logging(self):
48 def reinit_logging(self):
51 pass
49 pass
52
50
53
51
54 class ClusterManager(LoggingConfigurable):
52 class ClusterManager(LoggingConfigurable):
55
53
56 profiles = Dict()
54 profiles = Dict()
57
55
58 delay = CFloat(1., config=True,
56 delay = CFloat(1., config=True,
59 help="delay (in s) between starting the controller and the engines")
57 help="delay (in s) between starting the controller and the engines")
60
58
61 loop = Instance('zmq.eventloop.ioloop.IOLoop')
59 loop = Instance('zmq.eventloop.ioloop.IOLoop')
62 def _loop_default(self):
60 def _loop_default(self):
63 from zmq.eventloop.ioloop import IOLoop
61 from zmq.eventloop.ioloop import IOLoop
64 return IOLoop.instance()
62 return IOLoop.instance()
65
63
66 def build_launchers(self, profile_dir):
64 def build_launchers(self, profile_dir):
67 starter = DummyIPClusterStart(log=self.log)
65 starter = DummyIPClusterStart(log=self.log)
68 starter.initialize(['--profile-dir', profile_dir])
66 starter.initialize(['--profile-dir', profile_dir])
69 cl = starter.controller_launcher
67 cl = starter.controller_launcher
70 esl = starter.engine_launcher
68 esl = starter.engine_launcher
71 n = starter.n
69 n = starter.n
72 return cl, esl, n
70 return cl, esl, n
73
71
74 def get_profile_dir(self, name, path):
72 def get_profile_dir(self, name, path):
75 p = ProfileDir.find_profile_dir_by_name(path,name=name)
73 p = ProfileDir.find_profile_dir_by_name(path,name=name)
76 return p.location
74 return p.location
77
75
78 def update_profiles(self):
76 def update_profiles(self):
79 """List all profiles in the ipython_dir and cwd.
77 """List all profiles in the ipython_dir and cwd.
80 """
78 """
81 for path in [get_ipython_dir(), os.getcwdu()]:
79 for path in [get_ipython_dir(), os.getcwdu()]:
82 for profile in list_profiles_in(path):
80 for profile in list_profiles_in(path):
83 pd = self.get_profile_dir(profile, path)
81 pd = self.get_profile_dir(profile, path)
84 if profile not in self.profiles:
82 if profile not in self.profiles:
85 self.log.debug("Overwriting profile %s" % profile)
83 self.log.debug("Overwriting profile %s" % profile)
86 self.profiles[profile] = {
84 self.profiles[profile] = {
87 'profile': profile,
85 'profile': profile,
88 'profile_dir': pd,
86 'profile_dir': pd,
89 'status': 'stopped'
87 'status': 'stopped'
90 }
88 }
91
89
92 def list_profiles(self):
90 def list_profiles(self):
93 self.update_profiles()
91 self.update_profiles()
94 result = [self.profile_info(p) for p in sorted(self.profiles.keys())]
92 result = [self.profile_info(p) for p in sorted(self.profiles.keys())]
95 return result
93 return result
96
94
97 def check_profile(self, profile):
95 def check_profile(self, profile):
98 if profile not in self.profiles:
96 if profile not in self.profiles:
99 raise web.HTTPError(404, u'profile not found')
97 raise web.HTTPError(404, u'profile not found')
100
98
101 def profile_info(self, profile):
99 def profile_info(self, profile):
102 self.check_profile(profile)
100 self.check_profile(profile)
103 result = {}
101 result = {}
104 data = self.profiles.get(profile)
102 data = self.profiles.get(profile)
105 result['profile'] = profile
103 result['profile'] = profile
106 result['profile_dir'] = data['profile_dir']
104 result['profile_dir'] = data['profile_dir']
107 result['status'] = data['status']
105 result['status'] = data['status']
108 if 'n' in data:
106 if 'n' in data:
109 result['n'] = data['n']
107 result['n'] = data['n']
110 return result
108 return result
111
109
112 def start_cluster(self, profile, n=None):
110 def start_cluster(self, profile, n=None):
113 """Start a cluster for a given profile."""
111 """Start a cluster for a given profile."""
114 self.check_profile(profile)
112 self.check_profile(profile)
115 data = self.profiles[profile]
113 data = self.profiles[profile]
116 if data['status'] == 'running':
114 if data['status'] == 'running':
117 raise web.HTTPError(409, u'cluster already running')
115 raise web.HTTPError(409, u'cluster already running')
118 cl, esl, default_n = self.build_launchers(data['profile_dir'])
116 cl, esl, default_n = self.build_launchers(data['profile_dir'])
119 n = n if n is not None else default_n
117 n = n if n is not None else default_n
120 def clean_data():
118 def clean_data():
121 data.pop('controller_launcher',None)
119 data.pop('controller_launcher',None)
122 data.pop('engine_set_launcher',None)
120 data.pop('engine_set_launcher',None)
123 data.pop('n',None)
121 data.pop('n',None)
124 data['status'] = 'stopped'
122 data['status'] = 'stopped'
125 def engines_stopped(r):
123 def engines_stopped(r):
126 self.log.debug('Engines stopped')
124 self.log.debug('Engines stopped')
127 if cl.running:
125 if cl.running:
128 cl.stop()
126 cl.stop()
129 clean_data()
127 clean_data()
130 esl.on_stop(engines_stopped)
128 esl.on_stop(engines_stopped)
131 def controller_stopped(r):
129 def controller_stopped(r):
132 self.log.debug('Controller stopped')
130 self.log.debug('Controller stopped')
133 if esl.running:
131 if esl.running:
134 esl.stop()
132 esl.stop()
135 clean_data()
133 clean_data()
136 cl.on_stop(controller_stopped)
134 cl.on_stop(controller_stopped)
137
135
138 dc = ioloop.DelayedCallback(lambda: cl.start(), 0, self.loop)
136 dc = ioloop.DelayedCallback(lambda: cl.start(), 0, self.loop)
139 dc.start()
137 dc.start()
140 dc = ioloop.DelayedCallback(lambda: esl.start(n), 1000*self.delay, self.loop)
138 dc = ioloop.DelayedCallback(lambda: esl.start(n), 1000*self.delay, self.loop)
141 dc.start()
139 dc.start()
142
140
143 self.log.debug('Cluster started')
141 self.log.debug('Cluster started')
144 data['controller_launcher'] = cl
142 data['controller_launcher'] = cl
145 data['engine_set_launcher'] = esl
143 data['engine_set_launcher'] = esl
146 data['n'] = n
144 data['n'] = n
147 data['status'] = 'running'
145 data['status'] = 'running'
148 return self.profile_info(profile)
146 return self.profile_info(profile)
149
147
150 def stop_cluster(self, profile):
148 def stop_cluster(self, profile):
151 """Stop a cluster for a given profile."""
149 """Stop a cluster for a given profile."""
152 self.check_profile(profile)
150 self.check_profile(profile)
153 data = self.profiles[profile]
151 data = self.profiles[profile]
154 if data['status'] == 'stopped':
152 if data['status'] == 'stopped':
155 raise web.HTTPError(409, u'cluster not running')
153 raise web.HTTPError(409, u'cluster not running')
156 data = self.profiles[profile]
154 data = self.profiles[profile]
157 cl = data['controller_launcher']
155 cl = data['controller_launcher']
158 esl = data['engine_set_launcher']
156 esl = data['engine_set_launcher']
159 if cl.running:
157 if cl.running:
160 cl.stop()
158 cl.stop()
161 if esl.running:
159 if esl.running:
162 esl.stop()
160 esl.stop()
163 # Return a temp info dict, the real one is updated in the on_stop
161 # Return a temp info dict, the real one is updated in the on_stop
164 # logic above.
162 # logic above.
165 result = {
163 result = {
166 'profile': data['profile'],
164 'profile': data['profile'],
167 'profile_dir': data['profile_dir'],
165 'profile_dir': data['profile_dir'],
168 'status': 'stopped'
166 'status': 'stopped'
169 }
167 }
170 return result
168 return result
171
169
172 def stop_all_clusters(self):
170 def stop_all_clusters(self):
173 for p in self.profiles.keys():
171 for p in self.profiles.keys():
174 self.stop_cluster(p)
172 self.stop_cluster(p)
@@ -1,733 +1,734 b''
1 """Tornado handlers for the notebook.
1 """Tornado handlers for the notebook.
2
2
3 Authors:
3 Authors:
4
4
5 * Brian Granger
5 * Brian Granger
6 """
6 """
7
7
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2008-2011 The IPython Development Team
9 # Copyright (C) 2008-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 import logging
19 import logging
20 import Cookie
20 import Cookie
21 import time
21 import time
22 import uuid
22 import uuid
23
23
24 from tornado import web
24 from tornado import web
25 from tornado import websocket
25 from tornado import websocket
26
26
27 from zmq.eventloop import ioloop
27 from zmq.eventloop import ioloop
28 from zmq.utils import jsonapi
28 from zmq.utils import jsonapi
29
29
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31 from IPython.zmq.session import Session
31 from IPython.zmq.session import Session
32 from IPython.lib.security import passwd_check
32 from IPython.lib.security import passwd_check
33 from IPython.utils.jsonutil import date_default
33
34
34 try:
35 try:
35 from docutils.core import publish_string
36 from docutils.core import publish_string
36 except ImportError:
37 except ImportError:
37 publish_string = None
38 publish_string = None
38
39
39 #-----------------------------------------------------------------------------
40 #-----------------------------------------------------------------------------
40 # Monkeypatch for Tornado <= 2.1.1 - Remove when no longer necessary!
41 # Monkeypatch for Tornado <= 2.1.1 - Remove when no longer necessary!
41 #-----------------------------------------------------------------------------
42 #-----------------------------------------------------------------------------
42
43
43 # Google Chrome, as of release 16, changed its websocket protocol number. The
44 # Google Chrome, as of release 16, changed its websocket protocol number. The
44 # parts tornado cares about haven't really changed, so it's OK to continue
45 # parts tornado cares about haven't really changed, so it's OK to continue
45 # accepting Chrome connections, but as of Tornado 2.1.1 (the currently released
46 # accepting Chrome connections, but as of Tornado 2.1.1 (the currently released
46 # version as of Oct 30/2011) the version check fails, see the issue report:
47 # version as of Oct 30/2011) the version check fails, see the issue report:
47
48
48 # https://github.com/facebook/tornado/issues/385
49 # https://github.com/facebook/tornado/issues/385
49
50
50 # This issue has been fixed in Tornado post 2.1.1:
51 # This issue has been fixed in Tornado post 2.1.1:
51
52
52 # https://github.com/facebook/tornado/commit/84d7b458f956727c3b0d6710
53 # https://github.com/facebook/tornado/commit/84d7b458f956727c3b0d6710
53
54
54 # Here we manually apply the same patch as above so that users of IPython can
55 # Here we manually apply the same patch as above so that users of IPython can
55 # continue to work with an officially released Tornado. We make the
56 # continue to work with an officially released Tornado. We make the
56 # monkeypatch version check as narrow as possible to limit its effects; once
57 # monkeypatch version check as narrow as possible to limit its effects; once
57 # Tornado 2.1.1 is no longer found in the wild we'll delete this code.
58 # Tornado 2.1.1 is no longer found in the wild we'll delete this code.
58
59
59 import tornado
60 import tornado
60
61
61 if tornado.version_info <= (2,1,1):
62 if tornado.version_info <= (2,1,1):
62
63
63 def _execute(self, transforms, *args, **kwargs):
64 def _execute(self, transforms, *args, **kwargs):
64 from tornado.websocket import WebSocketProtocol8, WebSocketProtocol76
65 from tornado.websocket import WebSocketProtocol8, WebSocketProtocol76
65
66
66 self.open_args = args
67 self.open_args = args
67 self.open_kwargs = kwargs
68 self.open_kwargs = kwargs
68
69
69 # The difference between version 8 and 13 is that in 8 the
70 # The difference between version 8 and 13 is that in 8 the
70 # client sends a "Sec-Websocket-Origin" header and in 13 it's
71 # client sends a "Sec-Websocket-Origin" header and in 13 it's
71 # simply "Origin".
72 # simply "Origin".
72 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
73 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
73 self.ws_connection = WebSocketProtocol8(self)
74 self.ws_connection = WebSocketProtocol8(self)
74 self.ws_connection.accept_connection()
75 self.ws_connection.accept_connection()
75
76
76 elif self.request.headers.get("Sec-WebSocket-Version"):
77 elif self.request.headers.get("Sec-WebSocket-Version"):
77 self.stream.write(tornado.escape.utf8(
78 self.stream.write(tornado.escape.utf8(
78 "HTTP/1.1 426 Upgrade Required\r\n"
79 "HTTP/1.1 426 Upgrade Required\r\n"
79 "Sec-WebSocket-Version: 8\r\n\r\n"))
80 "Sec-WebSocket-Version: 8\r\n\r\n"))
80 self.stream.close()
81 self.stream.close()
81
82
82 else:
83 else:
83 self.ws_connection = WebSocketProtocol76(self)
84 self.ws_connection = WebSocketProtocol76(self)
84 self.ws_connection.accept_connection()
85 self.ws_connection.accept_connection()
85
86
86 websocket.WebSocketHandler._execute = _execute
87 websocket.WebSocketHandler._execute = _execute
87 del _execute
88 del _execute
88
89
89 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
90 # Decorator for disabling read-only handlers
91 # Decorator for disabling read-only handlers
91 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
92
93
93 @decorator
94 @decorator
94 def not_if_readonly(f, self, *args, **kwargs):
95 def not_if_readonly(f, self, *args, **kwargs):
95 if self.application.read_only:
96 if self.application.read_only:
96 raise web.HTTPError(403, "Notebook server is read-only")
97 raise web.HTTPError(403, "Notebook server is read-only")
97 else:
98 else:
98 return f(self, *args, **kwargs)
99 return f(self, *args, **kwargs)
99
100
100 @decorator
101 @decorator
101 def authenticate_unless_readonly(f, self, *args, **kwargs):
102 def authenticate_unless_readonly(f, self, *args, **kwargs):
102 """authenticate this page *unless* readonly view is active.
103 """authenticate this page *unless* readonly view is active.
103
104
104 In read-only mode, the notebook list and print view should
105 In read-only mode, the notebook list and print view should
105 be accessible without authentication.
106 be accessible without authentication.
106 """
107 """
107
108
108 @web.authenticated
109 @web.authenticated
109 def auth_f(self, *args, **kwargs):
110 def auth_f(self, *args, **kwargs):
110 return f(self, *args, **kwargs)
111 return f(self, *args, **kwargs)
111
112
112 if self.application.read_only:
113 if self.application.read_only:
113 return f(self, *args, **kwargs)
114 return f(self, *args, **kwargs)
114 else:
115 else:
115 return auth_f(self, *args, **kwargs)
116 return auth_f(self, *args, **kwargs)
116
117
117 #-----------------------------------------------------------------------------
118 #-----------------------------------------------------------------------------
118 # Top-level handlers
119 # Top-level handlers
119 #-----------------------------------------------------------------------------
120 #-----------------------------------------------------------------------------
120
121
121 class RequestHandler(web.RequestHandler):
122 class RequestHandler(web.RequestHandler):
122 """RequestHandler with default variable setting."""
123 """RequestHandler with default variable setting."""
123
124
124 def render(*args, **kwargs):
125 def render(*args, **kwargs):
125 kwargs.setdefault('message', '')
126 kwargs.setdefault('message', '')
126 return web.RequestHandler.render(*args, **kwargs)
127 return web.RequestHandler.render(*args, **kwargs)
127
128
128 class AuthenticatedHandler(RequestHandler):
129 class AuthenticatedHandler(RequestHandler):
129 """A RequestHandler with an authenticated user."""
130 """A RequestHandler with an authenticated user."""
130
131
131 def get_current_user(self):
132 def get_current_user(self):
132 user_id = self.get_secure_cookie("username")
133 user_id = self.get_secure_cookie("username")
133 # For now the user_id should not return empty, but it could eventually
134 # For now the user_id should not return empty, but it could eventually
134 if user_id == '':
135 if user_id == '':
135 user_id = 'anonymous'
136 user_id = 'anonymous'
136 if user_id is None:
137 if user_id is None:
137 # prevent extra Invalid cookie sig warnings:
138 # prevent extra Invalid cookie sig warnings:
138 self.clear_cookie('username')
139 self.clear_cookie('username')
139 if not self.application.password and not self.application.read_only:
140 if not self.application.password and not self.application.read_only:
140 user_id = 'anonymous'
141 user_id = 'anonymous'
141 return user_id
142 return user_id
142
143
143 @property
144 @property
144 def logged_in(self):
145 def logged_in(self):
145 """Is a user currently logged in?
146 """Is a user currently logged in?
146
147
147 """
148 """
148 user = self.get_current_user()
149 user = self.get_current_user()
149 return (user and not user == 'anonymous')
150 return (user and not user == 'anonymous')
150
151
151 @property
152 @property
152 def login_available(self):
153 def login_available(self):
153 """May a user proceed to log in?
154 """May a user proceed to log in?
154
155
155 This returns True if login capability is available, irrespective of
156 This returns True if login capability is available, irrespective of
156 whether the user is already logged in or not.
157 whether the user is already logged in or not.
157
158
158 """
159 """
159 return bool(self.application.password)
160 return bool(self.application.password)
160
161
161 @property
162 @property
162 def read_only(self):
163 def read_only(self):
163 """Is the notebook read-only?
164 """Is the notebook read-only?
164
165
165 """
166 """
166 return self.application.read_only
167 return self.application.read_only
167
168
168 @property
169 @property
169 def ws_url(self):
170 def ws_url(self):
170 """websocket url matching the current request
171 """websocket url matching the current request
171
172
172 turns http[s]://host[:port] into
173 turns http[s]://host[:port] into
173 ws[s]://host[:port]
174 ws[s]://host[:port]
174 """
175 """
175 proto = self.request.protocol.replace('http', 'ws')
176 proto = self.request.protocol.replace('http', 'ws')
176 host = self.application.ipython_app.websocket_host # default to config value
177 host = self.application.ipython_app.websocket_host # default to config value
177 if host == '':
178 if host == '':
178 host = self.request.host # get from request
179 host = self.request.host # get from request
179 return "%s://%s" % (proto, host)
180 return "%s://%s" % (proto, host)
180
181
181
182
182 class AuthenticatedFileHandler(AuthenticatedHandler, web.StaticFileHandler):
183 class AuthenticatedFileHandler(AuthenticatedHandler, web.StaticFileHandler):
183 """static files should only be accessible when logged in"""
184 """static files should only be accessible when logged in"""
184
185
185 @authenticate_unless_readonly
186 @authenticate_unless_readonly
186 def get(self, path):
187 def get(self, path):
187 return web.StaticFileHandler.get(self, path)
188 return web.StaticFileHandler.get(self, path)
188
189
189
190
190 class ProjectDashboardHandler(AuthenticatedHandler):
191 class ProjectDashboardHandler(AuthenticatedHandler):
191
192
192 @authenticate_unless_readonly
193 @authenticate_unless_readonly
193 def get(self):
194 def get(self):
194 nbm = self.application.notebook_manager
195 nbm = self.application.notebook_manager
195 project = nbm.notebook_dir
196 project = nbm.notebook_dir
196 self.render(
197 self.render(
197 'projectdashboard.html', project=project,
198 'projectdashboard.html', project=project,
198 base_project_url=self.application.ipython_app.base_project_url,
199 base_project_url=self.application.ipython_app.base_project_url,
199 base_kernel_url=self.application.ipython_app.base_kernel_url,
200 base_kernel_url=self.application.ipython_app.base_kernel_url,
200 read_only=self.read_only,
201 read_only=self.read_only,
201 logged_in=self.logged_in,
202 logged_in=self.logged_in,
202 login_available=self.login_available
203 login_available=self.login_available
203 )
204 )
204
205
205
206
206 class LoginHandler(AuthenticatedHandler):
207 class LoginHandler(AuthenticatedHandler):
207
208
208 def _render(self, message=None):
209 def _render(self, message=None):
209 self.render('login.html',
210 self.render('login.html',
210 next=self.get_argument('next', default='/'),
211 next=self.get_argument('next', default='/'),
211 read_only=self.read_only,
212 read_only=self.read_only,
212 logged_in=self.logged_in,
213 logged_in=self.logged_in,
213 login_available=self.login_available,
214 login_available=self.login_available,
214 base_project_url=self.application.ipython_app.base_project_url,
215 base_project_url=self.application.ipython_app.base_project_url,
215 message=message
216 message=message
216 )
217 )
217
218
218 def get(self):
219 def get(self):
219 if self.current_user:
220 if self.current_user:
220 self.redirect(self.get_argument('next', default='/'))
221 self.redirect(self.get_argument('next', default='/'))
221 else:
222 else:
222 self._render()
223 self._render()
223
224
224 def post(self):
225 def post(self):
225 pwd = self.get_argument('password', default=u'')
226 pwd = self.get_argument('password', default=u'')
226 if self.application.password:
227 if self.application.password:
227 if passwd_check(self.application.password, pwd):
228 if passwd_check(self.application.password, pwd):
228 self.set_secure_cookie('username', str(uuid.uuid4()))
229 self.set_secure_cookie('username', str(uuid.uuid4()))
229 else:
230 else:
230 self._render(message={'error': 'Invalid password'})
231 self._render(message={'error': 'Invalid password'})
231 return
232 return
232
233
233 self.redirect(self.get_argument('next', default='/'))
234 self.redirect(self.get_argument('next', default='/'))
234
235
235
236
236 class LogoutHandler(AuthenticatedHandler):
237 class LogoutHandler(AuthenticatedHandler):
237
238
238 def get(self):
239 def get(self):
239 self.clear_cookie('username')
240 self.clear_cookie('username')
240 if self.login_available:
241 if self.login_available:
241 message = {'info': 'Successfully logged out.'}
242 message = {'info': 'Successfully logged out.'}
242 else:
243 else:
243 message = {'warning': 'Cannot log out. Notebook authentication '
244 message = {'warning': 'Cannot log out. Notebook authentication '
244 'is disabled.'}
245 'is disabled.'}
245
246
246 self.render('logout.html',
247 self.render('logout.html',
247 read_only=self.read_only,
248 read_only=self.read_only,
248 logged_in=self.logged_in,
249 logged_in=self.logged_in,
249 login_available=self.login_available,
250 login_available=self.login_available,
250 base_project_url=self.application.ipython_app.base_project_url,
251 base_project_url=self.application.ipython_app.base_project_url,
251 message=message)
252 message=message)
252
253
253
254
254 class NewHandler(AuthenticatedHandler):
255 class NewHandler(AuthenticatedHandler):
255
256
256 @web.authenticated
257 @web.authenticated
257 def get(self):
258 def get(self):
258 nbm = self.application.notebook_manager
259 nbm = self.application.notebook_manager
259 project = nbm.notebook_dir
260 project = nbm.notebook_dir
260 notebook_id = nbm.new_notebook()
261 notebook_id = nbm.new_notebook()
261 self.render(
262 self.render(
262 'notebook.html', project=project,
263 'notebook.html', project=project,
263 notebook_id=notebook_id,
264 notebook_id=notebook_id,
264 base_project_url=self.application.ipython_app.base_project_url,
265 base_project_url=self.application.ipython_app.base_project_url,
265 base_kernel_url=self.application.ipython_app.base_kernel_url,
266 base_kernel_url=self.application.ipython_app.base_kernel_url,
266 kill_kernel=False,
267 kill_kernel=False,
267 read_only=False,
268 read_only=False,
268 logged_in=self.logged_in,
269 logged_in=self.logged_in,
269 login_available=self.login_available,
270 login_available=self.login_available,
270 mathjax_url=self.application.ipython_app.mathjax_url,
271 mathjax_url=self.application.ipython_app.mathjax_url,
271 )
272 )
272
273
273
274
274 class NamedNotebookHandler(AuthenticatedHandler):
275 class NamedNotebookHandler(AuthenticatedHandler):
275
276
276 @authenticate_unless_readonly
277 @authenticate_unless_readonly
277 def get(self, notebook_id):
278 def get(self, notebook_id):
278 nbm = self.application.notebook_manager
279 nbm = self.application.notebook_manager
279 project = nbm.notebook_dir
280 project = nbm.notebook_dir
280 if not nbm.notebook_exists(notebook_id):
281 if not nbm.notebook_exists(notebook_id):
281 raise web.HTTPError(404, u'Notebook does not exist: %s' % notebook_id)
282 raise web.HTTPError(404, u'Notebook does not exist: %s' % notebook_id)
282
283
283 self.render(
284 self.render(
284 'notebook.html', project=project,
285 'notebook.html', project=project,
285 notebook_id=notebook_id,
286 notebook_id=notebook_id,
286 base_project_url=self.application.ipython_app.base_project_url,
287 base_project_url=self.application.ipython_app.base_project_url,
287 base_kernel_url=self.application.ipython_app.base_kernel_url,
288 base_kernel_url=self.application.ipython_app.base_kernel_url,
288 kill_kernel=False,
289 kill_kernel=False,
289 read_only=self.read_only,
290 read_only=self.read_only,
290 logged_in=self.logged_in,
291 logged_in=self.logged_in,
291 login_available=self.login_available,
292 login_available=self.login_available,
292 mathjax_url=self.application.ipython_app.mathjax_url,
293 mathjax_url=self.application.ipython_app.mathjax_url,
293 )
294 )
294
295
295
296
296 class PrintNotebookHandler(AuthenticatedHandler):
297 class PrintNotebookHandler(AuthenticatedHandler):
297
298
298 @authenticate_unless_readonly
299 @authenticate_unless_readonly
299 def get(self, notebook_id):
300 def get(self, notebook_id):
300 nbm = self.application.notebook_manager
301 nbm = self.application.notebook_manager
301 project = nbm.notebook_dir
302 project = nbm.notebook_dir
302 if not nbm.notebook_exists(notebook_id):
303 if not nbm.notebook_exists(notebook_id):
303 raise web.HTTPError(404, u'Notebook does not exist: %s' % notebook_id)
304 raise web.HTTPError(404, u'Notebook does not exist: %s' % notebook_id)
304
305
305 self.render(
306 self.render(
306 'printnotebook.html', project=project,
307 'printnotebook.html', project=project,
307 notebook_id=notebook_id,
308 notebook_id=notebook_id,
308 base_project_url=self.application.ipython_app.base_project_url,
309 base_project_url=self.application.ipython_app.base_project_url,
309 base_kernel_url=self.application.ipython_app.base_kernel_url,
310 base_kernel_url=self.application.ipython_app.base_kernel_url,
310 kill_kernel=False,
311 kill_kernel=False,
311 read_only=self.read_only,
312 read_only=self.read_only,
312 logged_in=self.logged_in,
313 logged_in=self.logged_in,
313 login_available=self.login_available,
314 login_available=self.login_available,
314 mathjax_url=self.application.ipython_app.mathjax_url,
315 mathjax_url=self.application.ipython_app.mathjax_url,
315 )
316 )
316
317
317 #-----------------------------------------------------------------------------
318 #-----------------------------------------------------------------------------
318 # Kernel handlers
319 # Kernel handlers
319 #-----------------------------------------------------------------------------
320 #-----------------------------------------------------------------------------
320
321
321
322
322 class MainKernelHandler(AuthenticatedHandler):
323 class MainKernelHandler(AuthenticatedHandler):
323
324
324 @web.authenticated
325 @web.authenticated
325 def get(self):
326 def get(self):
326 km = self.application.kernel_manager
327 km = self.application.kernel_manager
327 self.finish(jsonapi.dumps(km.kernel_ids))
328 self.finish(jsonapi.dumps(km.kernel_ids))
328
329
329 @web.authenticated
330 @web.authenticated
330 def post(self):
331 def post(self):
331 km = self.application.kernel_manager
332 km = self.application.kernel_manager
332 notebook_id = self.get_argument('notebook', default=None)
333 notebook_id = self.get_argument('notebook', default=None)
333 kernel_id = km.start_kernel(notebook_id)
334 kernel_id = km.start_kernel(notebook_id)
334 data = {'ws_url':self.ws_url,'kernel_id':kernel_id}
335 data = {'ws_url':self.ws_url,'kernel_id':kernel_id}
335 self.set_header('Location', '/'+kernel_id)
336 self.set_header('Location', '/'+kernel_id)
336 self.finish(jsonapi.dumps(data))
337 self.finish(jsonapi.dumps(data))
337
338
338
339
339 class KernelHandler(AuthenticatedHandler):
340 class KernelHandler(AuthenticatedHandler):
340
341
341 SUPPORTED_METHODS = ('DELETE')
342 SUPPORTED_METHODS = ('DELETE')
342
343
343 @web.authenticated
344 @web.authenticated
344 def delete(self, kernel_id):
345 def delete(self, kernel_id):
345 km = self.application.kernel_manager
346 km = self.application.kernel_manager
346 km.kill_kernel(kernel_id)
347 km.kill_kernel(kernel_id)
347 self.set_status(204)
348 self.set_status(204)
348 self.finish()
349 self.finish()
349
350
350
351
351 class KernelActionHandler(AuthenticatedHandler):
352 class KernelActionHandler(AuthenticatedHandler):
352
353
353 @web.authenticated
354 @web.authenticated
354 def post(self, kernel_id, action):
355 def post(self, kernel_id, action):
355 km = self.application.kernel_manager
356 km = self.application.kernel_manager
356 if action == 'interrupt':
357 if action == 'interrupt':
357 km.interrupt_kernel(kernel_id)
358 km.interrupt_kernel(kernel_id)
358 self.set_status(204)
359 self.set_status(204)
359 if action == 'restart':
360 if action == 'restart':
360 new_kernel_id = km.restart_kernel(kernel_id)
361 new_kernel_id = km.restart_kernel(kernel_id)
361 data = {'ws_url':self.ws_url,'kernel_id':new_kernel_id}
362 data = {'ws_url':self.ws_url,'kernel_id':new_kernel_id}
362 self.set_header('Location', '/'+new_kernel_id)
363 self.set_header('Location', '/'+new_kernel_id)
363 self.write(jsonapi.dumps(data))
364 self.write(jsonapi.dumps(data))
364 self.finish()
365 self.finish()
365
366
366
367
367 class ZMQStreamHandler(websocket.WebSocketHandler):
368 class ZMQStreamHandler(websocket.WebSocketHandler):
368
369
369 def _reserialize_reply(self, msg_list):
370 def _reserialize_reply(self, msg_list):
370 """Reserialize a reply message using JSON.
371 """Reserialize a reply message using JSON.
371
372
372 This takes the msg list from the ZMQ socket, unserializes it using
373 This takes the msg list from the ZMQ socket, unserializes it using
373 self.session and then serializes the result using JSON. This method
374 self.session and then serializes the result using JSON. This method
374 should be used by self._on_zmq_reply to build messages that can
375 should be used by self._on_zmq_reply to build messages that can
375 be sent back to the browser.
376 be sent back to the browser.
376 """
377 """
377 idents, msg_list = self.session.feed_identities(msg_list)
378 idents, msg_list = self.session.feed_identities(msg_list)
378 msg = self.session.unserialize(msg_list)
379 msg = self.session.unserialize(msg_list)
379 try:
380 try:
380 msg['header'].pop('date')
381 msg['header'].pop('date')
381 except KeyError:
382 except KeyError:
382 pass
383 pass
383 try:
384 try:
384 msg['parent_header'].pop('date')
385 msg['parent_header'].pop('date')
385 except KeyError:
386 except KeyError:
386 pass
387 pass
387 msg.pop('buffers')
388 msg.pop('buffers')
388 return jsonapi.dumps(msg)
389 return jsonapi.dumps(msg, default=date_default)
389
390
390 def _on_zmq_reply(self, msg_list):
391 def _on_zmq_reply(self, msg_list):
391 try:
392 try:
392 msg = self._reserialize_reply(msg_list)
393 msg = self._reserialize_reply(msg_list)
393 except:
394 except Exception:
394 self.application.log.critical("Malformed message: %r" % msg_list)
395 self.application.log.critical("Malformed message: %r" % msg_list, exc_info=True)
395 else:
396 else:
396 self.write_message(msg)
397 self.write_message(msg)
397
398
398 def allow_draft76(self):
399 def allow_draft76(self):
399 """Allow draft 76, until browsers such as Safari update to RFC 6455.
400 """Allow draft 76, until browsers such as Safari update to RFC 6455.
400
401
401 This has been disabled by default in tornado in release 2.2.0, and
402 This has been disabled by default in tornado in release 2.2.0, and
402 support will be removed in later versions.
403 support will be removed in later versions.
403 """
404 """
404 return True
405 return True
405
406
406
407
class AuthenticatedZMQStreamHandler(ZMQStreamHandler):
    """ZMQ stream handler that authenticates via the first WebSocket message.

    WebSocket connections cannot be authenticated through the normal
    tornado cookie mechanism at handshake time here, so the client sends
    the document cookie as its very first message.  ``open`` swaps
    ``on_message`` for ``on_first_message`` to intercept and validate
    that cookie before any real traffic is processed.
    """

    def open(self, kernel_id):
        """Accept the connection and defer authentication to the first message.

        Parameters
        ----------
        kernel_id : bytes
            Kernel id captured from the URL pattern; decoded to ascii here.
        """
        self.kernel_id = kernel_id.decode('ascii')
        try:
            cfg = self.application.ipython_app.config
        except AttributeError:
            # protect from the case where this is run from something other than
            # the notebook app:
            cfg = None
        self.session = Session(config=cfg)
        # Stash the real message handler and intercept the first message,
        # which carries the auth cookie rather than kernel traffic.
        self.save_on_message = self.on_message
        self.on_message = self.on_first_message

    def get_current_user(self):
        """Return the user id from the secure cookie.

        Returns 'anonymous' for an empty cookie, or for a missing cookie
        when no password protection is configured; None (i.e. a missing
        cookie on a password-protected server) means unauthenticated.
        """
        user_id = self.get_secure_cookie("username")
        if user_id == '' or (user_id is None and not self.application.password):
            user_id = 'anonymous'
        return user_id

    def _inject_cookie_message(self, msg):
        """Inject the first message, which is the document cookie,
        for authentication."""
        if isinstance(msg, unicode):
            # the Cookie constructor doesn't accept unicode strings
            msg = msg.encode('utf8', 'replace')
        try:
            self.request._cookies = Cookie.SimpleCookie(msg)
        except:
            # A malformed cookie just leaves the request unauthenticated;
            # log it for diagnosis but don't kill the connection here.
            logging.warn("couldn't parse cookie string: %s",msg, exc_info=True)

    def on_first_message(self, msg):
        """Authenticate using the cookie message, then restore the real
        ``on_message`` handler for subsequent traffic.

        Raises
        ------
        web.HTTPError
            403 when authentication fails.
        """
        self._inject_cookie_message(msg)
        if self.get_current_user() is None:
            logging.warn("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)
        self.on_message = self.save_on_message
444
445
445
446
class IOPubHandler(AuthenticatedZMQStreamHandler):
    """WebSocket handler relaying a kernel's IOPub traffic to the browser.

    Also owns the heartbeat for its kernel: it periodically pings the hb
    stream and declares the kernel dead (via ``kernel_died``) when a ping
    goes unanswered for a full ``time_to_dead`` period.
    """

    def initialize(self, *args, **kwargs):
        # Heartbeat state flags; the ZMQ streams are created lazily in
        # on_first_message and may still be None when on_close runs.
        self._kernel_alive = True
        self._beating = False
        self.iopub_stream = None
        self.hb_stream = None

    def on_first_message(self, msg):
        """Authenticate, then wire up the iopub and heartbeat streams."""
        try:
            super(IOPubHandler, self).on_first_message(msg)
        except web.HTTPError:
            self.close()
            return
        km = self.application.kernel_manager
        self.time_to_dead = km.time_to_dead
        self.first_beat = km.first_beat
        kernel_id = self.kernel_id
        try:
            self.iopub_stream = km.create_iopub_stream(kernel_id)
            self.hb_stream = km.create_hb_stream(kernel_id)
        except web.HTTPError:
            # WebSockets don't respond to traditional error codes so we
            # close the connection.
            if not self.stream.closed():
                self.stream.close()
            self.close()
        else:
            self.iopub_stream.on_recv(self._on_zmq_reply)
            self.start_hb(self.kernel_died)

    def on_message(self, msg):
        # The iopub channel is one-way (kernel -> browser); incoming
        # WebSocket messages are deliberately ignored.
        pass

    def on_close(self):
        # This method can be called twice, once by self.kernel_died and once
        # from the WebSocket close event. If the WebSocket connection is
        # closed before the ZMQ streams are setup, they could be None.
        self.stop_hb()
        if self.iopub_stream is not None and not self.iopub_stream.closed():
            self.iopub_stream.on_recv(None)
            self.iopub_stream.close()
        if self.hb_stream is not None and not self.hb_stream.closed():
            self.hb_stream.close()

    def start_hb(self, callback):
        """Start the heartbeating and call the callback if the kernel dies."""
        if not self._beating:
            self._kernel_alive = True

            def ping_or_dead():
                # Runs once per time_to_dead interval.  If no beat arrived
                # since the previous ping, the kernel is considered dead.
                self.hb_stream.flush()
                if self._kernel_alive:
                    self._kernel_alive = False
                    self.hb_stream.send(b'ping')
                    # flush stream to force immediate socket send
                    self.hb_stream.flush()
                else:
                    try:
                        callback()
                    except:
                        # the death callback must never break the handler
                        pass
                    finally:
                        self.stop_hb()

            def beat_received(msg):
                # Any reply on the hb stream marks the kernel alive again.
                self._kernel_alive = True

            self.hb_stream.on_recv(beat_received)
            loop = ioloop.IOLoop.instance()
            self._hb_periodic_callback = ioloop.PeriodicCallback(ping_or_dead, self.time_to_dead*1000, loop)
            # Give the kernel first_beat seconds to start before pinging.
            loop.add_timeout(time.time()+self.first_beat, self._really_start_hb)
            self._beating= True

    def _really_start_hb(self):
        """callback for delayed heartbeat start

        Only start the hb loop if we haven't been closed during the wait.
        """
        if self._beating and not self.hb_stream.closed():
            self._hb_periodic_callback.start()

    def stop_hb(self):
        """Stop the heartbeating and cancel all related callbacks."""
        if self._beating:
            self._beating = False
            self._hb_periodic_callback.stop()
            if not self.hb_stream.closed():
                self.hb_stream.on_recv(None)

    def kernel_died(self):
        """Heartbeat-death callback: unmap the kernel, notify the client
        with a 'dead' status message, and tear down this connection."""
        self.application.kernel_manager.delete_mapping_for_kernel(self.kernel_id)
        self.application.log.error("Kernel %s failed to respond to heartbeat", self.kernel_id)
        self.write_message(
            {'header': {'msg_type': 'status'},
             'parent_header': {},
             'content': {'execution_state':'dead'}
            }
        )
        self.on_close()
546
547
547
548
class ShellHandler(AuthenticatedZMQStreamHandler):
    """WebSocket handler bridging the browser and a kernel's shell channel."""

    def initialize(self, *args, **kwargs):
        # Created in on_first_message; may still be None when on_close runs.
        self.shell_stream = None

    def on_first_message(self, msg):
        """Authenticate, then connect this WebSocket to the kernel's shell stream."""
        try:
            super(ShellHandler, self).on_first_message(msg)
        except web.HTTPError:
            self.close()
            return
        km = self.application.kernel_manager
        self.max_msg_size = km.max_msg_size
        kernel_id = self.kernel_id
        try:
            self.shell_stream = km.create_shell_stream(kernel_id)
        except web.HTTPError:
            # WebSockets don't respond to traditional error codes so we
            # close the connection.
            if not self.stream.closed():
                self.stream.close()
            self.close()
        else:
            self.shell_stream.on_recv(self._on_zmq_reply)

    def on_message(self, msg):
        """Forward a JSON message from the browser to the kernel's shell socket.

        Messages at or above ``max_msg_size`` are rejected.
        """
        if len(msg) < self.max_msg_size:
            msg = jsonapi.loads(msg)
            self.session.send(self.shell_stream, msg)
        else:
            # Fix: oversized messages used to be dropped with no trace at
            # all; log so the failure is diagnosable from the server side.
            logging.warn("Dropping WebSocket message larger than max_msg_size (%i bytes)", len(msg))

    def on_close(self):
        # Make sure the stream exists and is not already closed.
        if self.shell_stream is not None and not self.shell_stream.closed():
            self.shell_stream.close()
582
583
583
584
584 #-----------------------------------------------------------------------------
585 #-----------------------------------------------------------------------------
585 # Notebook web service handlers
586 # Notebook web service handlers
586 #-----------------------------------------------------------------------------
587 #-----------------------------------------------------------------------------
587
588
class NotebookRootHandler(AuthenticatedHandler):
    """Collection-level notebook operations: list all, create new."""

    @authenticate_unless_readonly
    def get(self):
        """Reply with the JSON list of available notebooks."""
        manager = self.application.notebook_manager
        self.finish(jsonapi.dumps(manager.list_notebooks()))

    @web.authenticated
    def post(self):
        """Create a notebook: empty, or from an uploaded body.

        Replies with the new notebook id and a Location header for it.
        """
        manager = self.application.notebook_manager
        body = self.request.body.strip()
        format = self.get_argument('format', default='json')
        name = self.get_argument('name', default=None)
        if not body:
            notebook_id = manager.new_notebook()
        else:
            notebook_id = manager.save_new_notebook(body, name=name, format=format)
        self.set_header('Location', '/'+notebook_id)
        self.finish(jsonapi.dumps(notebook_id))
608
609
609
610
class NotebookHandler(AuthenticatedHandler):
    """CRUD operations on a single notebook, addressed by notebook_id."""

    SUPPORTED_METHODS = ('GET', 'PUT', 'DELETE')

    @authenticate_unless_readonly
    def get(self, notebook_id):
        """Download the notebook in the requested format ('json' or 'py')."""
        manager = self.application.notebook_manager
        format = self.get_argument('format', default='json')
        last_mod, name, data = manager.get_notebook(notebook_id, format)

        # Advertise the content type and a download filename per format.
        if format == u'json':
            self.set_header('Content-Type', 'application/json')
            self.set_header('Content-Disposition','attachment; filename="%s.ipynb"' % name)
        elif format == u'py':
            self.set_header('Content-Type', 'application/x-python')
            self.set_header('Content-Disposition','attachment; filename="%s.py"' % name)
        self.set_header('Last-Modified', last_mod)
        self.finish(data)

    @web.authenticated
    def put(self, notebook_id):
        """Overwrite the notebook with the request body; reply 204."""
        manager = self.application.notebook_manager
        name = self.get_argument('name', default=None)
        format = self.get_argument('format', default='json')
        manager.save_notebook(notebook_id, self.request.body, name=name, format=format)
        self.set_status(204)
        self.finish()

    @web.authenticated
    def delete(self, notebook_id):
        """Remove the notebook; reply 204."""
        self.application.notebook_manager.delete_notebook(notebook_id)
        self.set_status(204)
        self.finish()
644
645
645
646
class NotebookCopyHandler(AuthenticatedHandler):
    """Copy an existing notebook and open the copy in the editor page."""

    @web.authenticated
    def get(self, notebook_id):
        """Duplicate notebook_id, then render the editor for the new copy."""
        manager = self.application.notebook_manager
        app = self.application.ipython_app
        project = manager.notebook_dir
        new_id = manager.copy_notebook(notebook_id)
        self.render(
            'notebook.html', project=project,
            notebook_id=new_id,
            base_project_url=app.base_project_url,
            base_kernel_url=app.base_kernel_url,
            kill_kernel=False,
            read_only=False,
            logged_in=self.logged_in,
            login_available=self.login_available,
            mathjax_url=app.mathjax_url,
        )
664
665
665
666
666 #-----------------------------------------------------------------------------
667 #-----------------------------------------------------------------------------
667 # Cluster handlers
668 # Cluster handlers
668 #-----------------------------------------------------------------------------
669 #-----------------------------------------------------------------------------
669
670
670
671
class MainClusterHandler(AuthenticatedHandler):
    """Report the status of every known cluster profile as JSON."""

    @web.authenticated
    def get(self):
        profiles = self.application.cluster_manager.list_profiles()
        self.finish(jsonapi.dumps(profiles))
677
678
678
679
class ClusterProfileHandler(AuthenticatedHandler):
    """Report the status of a single cluster profile as JSON."""

    @web.authenticated
    def get(self, profile):
        info = self.application.cluster_manager.profile_info(profile)
        self.finish(jsonapi.dumps(info))
685
686
686
687
class ClusterActionHandler(AuthenticatedHandler):
    """Start or stop the cluster for a given profile.

    The URL route restricts ``action`` to 'start' or 'stop'.
    """

    @web.authenticated
    def post(self, profile, action):
        cm = self.application.cluster_manager
        if action == 'start':
            # Optional 'n' argument selects the number of engines.
            n = self.get_argument('n',default=None)
            data = cm.start_cluster(profile) if n is None else cm.start_cluster(profile,int(n))
        if action == 'stop':
            data = cm.stop_cluster(profile)
        self.finish(jsonapi.dumps(data))
701
702
702
703
703 #-----------------------------------------------------------------------------
704 #-----------------------------------------------------------------------------
704 # RST web service handlers
705 # RST web service handlers
705 #-----------------------------------------------------------------------------
706 #-----------------------------------------------------------------------------
706
707
707
708
class RSTHandler(AuthenticatedHandler):
    """Render POSTed reStructuredText source to an HTML fragment."""

    @web.authenticated
    def post(self):
        """Return the docutils HTML rendering of the request body.

        Raises
        ------
        web.HTTPError
            503 when docutils is not installed; 400 when the source
            cannot be rendered.
        """
        if publish_string is None:
            raise web.HTTPError(503, u'docutils not available')
        body = self.request.body.strip()
        source = body
        # template_path=os.path.join(os.path.dirname(__file__), u'templates', u'rst_template.html')
        defaults = {'file_insertion_enabled': 0,
                    'raw_enabled': 0,
                    '_disable_config': 1,
                    'stylesheet_path': 0
                    # 'template': template_path
                    }
        try:
            html = publish_string(source, writer_name='html',
                                  settings_overrides=defaults
                                  )
        except:
            # docutils raises a variety of exception types on bad input;
            # any failure maps to a client error.
            raise web.HTTPError(400, u'Invalid RST')
        # Fix: removed a leftover debug `print html` that dumped every
        # rendered document to the server's stdout.
        self.set_header('Content-Type', 'text/html')
        self.finish(html)
732
733
733
734
@@ -1,565 +1,565 b''
1 # coding: utf-8
1 # coding: utf-8
2 """A tornado based IPython notebook server.
2 """A tornado based IPython notebook server.
3
3
4 Authors:
4 Authors:
5
5
6 * Brian Granger
6 * Brian Granger
7 """
7 """
8 #-----------------------------------------------------------------------------
8 #-----------------------------------------------------------------------------
9 # Copyright (C) 2008-2011 The IPython Development Team
9 # Copyright (C) 2008-2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 # stdlib
19 # stdlib
20 import errno
20 import errno
21 import logging
21 import logging
22 import os
22 import os
23 import re
23 import re
24 import select
24 import select
25 import signal
25 import signal
26 import socket
26 import socket
27 import sys
27 import sys
28 import threading
28 import threading
29 import time
29 import time
30 import webbrowser
30 import webbrowser
31
31
32 # Third party
32 # Third party
33 import zmq
33 import zmq
34
34
35 # Install the pyzmq ioloop. This has to be done before anything else from
35 # Install the pyzmq ioloop. This has to be done before anything else from
36 # tornado is imported.
36 # tornado is imported.
37 from zmq.eventloop import ioloop
37 from zmq.eventloop import ioloop
38 ioloop.install()
38 ioloop.install()
39
39
40 from tornado import httpserver
40 from tornado import httpserver
41 from tornado import web
41 from tornado import web
42
42
43 # Our own libraries
43 # Our own libraries
44 from .kernelmanager import MappingKernelManager
44 from .kernelmanager import MappingKernelManager
45 from .handlers import (LoginHandler, LogoutHandler,
45 from .handlers import (LoginHandler, LogoutHandler,
46 ProjectDashboardHandler, NewHandler, NamedNotebookHandler,
46 ProjectDashboardHandler, NewHandler, NamedNotebookHandler,
47 MainKernelHandler, KernelHandler, KernelActionHandler, IOPubHandler,
47 MainKernelHandler, KernelHandler, KernelActionHandler, IOPubHandler,
48 ShellHandler, NotebookRootHandler, NotebookHandler, NotebookCopyHandler,
48 ShellHandler, NotebookRootHandler, NotebookHandler, NotebookCopyHandler,
49 RSTHandler, AuthenticatedFileHandler, PrintNotebookHandler,
49 RSTHandler, AuthenticatedFileHandler, PrintNotebookHandler,
50 MainClusterHandler, ClusterProfileHandler, ClusterActionHandler
50 MainClusterHandler, ClusterProfileHandler, ClusterActionHandler
51 )
51 )
52 from .notebookmanager import NotebookManager
52 from .notebookmanager import NotebookManager
53 from .clustermanager import ClusterManager
53 from .clustermanager import ClusterManager
54
54
55 from IPython.config.application import catch_config_error, boolean_flag
55 from IPython.config.application import catch_config_error, boolean_flag
56 from IPython.core.application import BaseIPythonApplication
56 from IPython.core.application import BaseIPythonApplication
57 from IPython.core.profiledir import ProfileDir
57 from IPython.core.profiledir import ProfileDir
58 from IPython.lib.kernel import swallow_argv
58 from IPython.lib.kernel import swallow_argv
59 from IPython.zmq.session import Session, default_secure
59 from IPython.zmq.session import Session, default_secure
60 from IPython.zmq.zmqshell import ZMQInteractiveShell
60 from IPython.zmq.zmqshell import ZMQInteractiveShell
61 from IPython.zmq.ipkernel import (
61 from IPython.zmq.ipkernel import (
62 flags as ipkernel_flags,
62 flags as ipkernel_flags,
63 aliases as ipkernel_aliases,
63 aliases as ipkernel_aliases,
64 IPKernelApp
64 IPKernelApp
65 )
65 )
66 from IPython.utils.traitlets import Dict, Unicode, Integer, List, Enum, Bool
66 from IPython.utils.traitlets import Dict, Unicode, Integer, List, Enum, Bool
67 from IPython.utils import py3compat
67 from IPython.utils import py3compat
68
68
69 #-----------------------------------------------------------------------------
69 #-----------------------------------------------------------------------------
70 # Module globals
70 # Module globals
71 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
72
72
# URL fragments with named groups, composed into the route patterns in
# NotebookWebApplication below.  Kernel and notebook ids are UUID-shaped
# (five \w+ runs joined by '-').
_kernel_id_regex = r"(?P<kernel_id>\w+-\w+-\w+-\w+-\w+)"
_kernel_action_regex = r"(?P<action>restart|interrupt)"
_notebook_id_regex = r"(?P<notebook_id>\w+-\w+-\w+-\w+-\w+)"
_profile_regex = r"(?P<profile>[a-zA-Z0-9]+)"
_cluster_action_regex = r"(?P<action>start|stop)"


# Default interface for the notebook server to bind to.
LOCALHOST = '127.0.0.1'

# Usage examples shown in the application's command-line help.
_examples = """
ipython notebook # start the notebook
ipython notebook --profile=sympy # use the sympy profile
ipython notebook --pylab=inline # pylab in inline plotting mode
ipython notebook --certfile=mycert.pem # use SSL/TLS certificate
ipython notebook --port=5555 --ip=* # Listen on port 5555, all interfaces
"""
89
89
90 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
91 # Helper functions
91 # Helper functions
92 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
93
93
def url_path_join(a,b):
    """Concatenate two URL path pieces, collapsing the '/' at the seam
    when both sides contribute one.  No separator is inserted when
    neither side has one."""
    seam_doubled = a.endswith('/') and b.startswith('/')
    if seam_doubled:
        return a[:-1] + b
    return a + b
99
99
100 #-----------------------------------------------------------------------------
100 #-----------------------------------------------------------------------------
101 # The Tornado web application
101 # The Tornado web application
102 #-----------------------------------------------------------------------------
102 #-----------------------------------------------------------------------------
103
103
class NotebookWebApplication(web.Application):
    """The tornado Application for the notebook server.

    Builds the route table, applies caller-supplied settings overrides,
    prefixes every URL pattern with ``base_project_url``, and exposes the
    managers and app object as attributes for the request handlers (which
    reach them via ``self.application``).
    """

    def __init__(self, ipython_app, kernel_manager, notebook_manager,
                 cluster_manager, log,
                 base_project_url, settings_overrides):
        # (pattern, handler[, kwargs]) tuples, before base_project_url
        # prefixing below.
        handlers = [
            (r"/", ProjectDashboardHandler),
            (r"/login", LoginHandler),
            (r"/logout", LogoutHandler),
            (r"/new", NewHandler),
            (r"/%s" % _notebook_id_regex, NamedNotebookHandler),
            (r"/%s/copy" % _notebook_id_regex, NotebookCopyHandler),
            (r"/%s/print" % _notebook_id_regex, PrintNotebookHandler),
            (r"/kernels", MainKernelHandler),
            (r"/kernels/%s" % _kernel_id_regex, KernelHandler),
            (r"/kernels/%s/%s" % (_kernel_id_regex, _kernel_action_regex), KernelActionHandler),
            (r"/kernels/%s/iopub" % _kernel_id_regex, IOPubHandler),
            (r"/kernels/%s/shell" % _kernel_id_regex, ShellHandler),
            (r"/notebooks", NotebookRootHandler),
            (r"/notebooks/%s" % _notebook_id_regex, NotebookHandler),
            (r"/rstservice/render", RSTHandler),
            (r"/files/(.*)", AuthenticatedFileHandler, {'path' : notebook_manager.notebook_dir}),
            (r"/clusters", MainClusterHandler),
            (r"/clusters/%s/%s" % (_profile_regex, _cluster_action_regex), ClusterActionHandler),
            (r"/clusters/%s" % _profile_regex, ClusterProfileHandler),
        ]
        settings = dict(
            template_path=os.path.join(os.path.dirname(__file__), "templates"),
            static_path=os.path.join(os.path.dirname(__file__), "static"),
            # Fresh random secret per server start: login cookies do not
            # survive a restart.
            cookie_secret=os.urandom(1024),
            login_url="/login",
        )

        # allow custom overrides for the tornado web app.
        settings.update(settings_overrides)

        # Python < 2.6.5 doesn't accept unicode keys in f(**kwargs), and
        # base_project_url will always be unicode, which will in turn
        # make the patterns unicode, and ultimately result in unicode
        # keys in kwargs to handler._execute(**kwargs) in tornado.
        # This enforces that base_project_url be ascii in that situation.
        #
        # Note that the URLs these patterns check against are escaped,
        # and thus guaranteed to be ASCII: 'héllo' is really 'h%C3%A9llo'.
        base_project_url = py3compat.unicode_to_str(base_project_url, 'ascii')

        # prepend base_project_url onto the patterns that we match
        new_handlers = []
        for handler in handlers:
            pattern = url_path_join(base_project_url, handler[0])
            new_handler = tuple([pattern]+list(handler[1:]))
            new_handlers.append( new_handler )

        super(NotebookWebApplication, self).__init__(new_handlers, **settings)

        # Collaborators the request handlers reach via self.application.
        self.kernel_manager = kernel_manager
        self.notebook_manager = notebook_manager
        self.cluster_manager = cluster_manager
        self.ipython_app = ipython_app
        self.read_only = self.ipython_app.read_only
        self.log = log
165
165
166
166
167 #-----------------------------------------------------------------------------
167 #-----------------------------------------------------------------------------
168 # Aliases and Flags
168 # Aliases and Flags
169 #-----------------------------------------------------------------------------
169 #-----------------------------------------------------------------------------
170
170
# Command-line flags: start from the IPKernelApp flags and add the
# notebook-server-specific ones.
flags = dict(ipkernel_flags)
flags['no-browser']=(
    {'NotebookApp' : {'open_browser' : False}},
    "Don't open the notebook in a browser after startup."
)
flags['no-mathjax']=(
    {'NotebookApp' : {'enable_mathjax' : False}},
    """Disable MathJax
    
    MathJax is the javascript library IPython uses to render math/LaTeX. It is
    very large, so you may want to disable it if you have a slow internet
    connection, or for offline use of the notebook.
    
    When disabled, equations etc. will appear as their untransformed TeX source.
    """
)
flags['read-only'] = (
    {'NotebookApp' : {'read_only' : True}},
    """Allow read-only access to notebooks.
    
    When using a password to protect the notebook server, this flag
    allows unauthenticated clients to view the notebook list, and
    individual notebooks, but not edit them, start kernels, or run
    code.
    
    If no password is set, the server will be entirely read-only.
    """
)
199
199
200 # Add notebook manager flags
200 # Add notebook manager flags
201 flags.update(boolean_flag('script', 'NotebookManager.save_script',
201 flags.update(boolean_flag('script', 'NotebookManager.save_script',
202 'Auto-save a .py script everytime the .ipynb notebook is saved',
202 'Auto-save a .py script everytime the .ipynb notebook is saved',
203 'Do not auto-save .py scripts for every notebook'))
203 'Do not auto-save .py scripts for every notebook'))
204
204
205 # the flags that are specific to the frontend
205 # the flags that are specific to the frontend
206 # these must be scrubbed before being passed to the kernel,
206 # these must be scrubbed before being passed to the kernel,
207 # or it will raise an error on unrecognized flags
207 # or it will raise an error on unrecognized flags
208 notebook_flags = ['no-browser', 'no-mathjax', 'read-only', 'script', 'no-script']
208 notebook_flags = ['no-browser', 'no-mathjax', 'read-only', 'script', 'no-script']
209
209
210 aliases = dict(ipkernel_aliases)
210 aliases = dict(ipkernel_aliases)
211
211
212 aliases.update({
212 aliases.update({
213 'ip': 'NotebookApp.ip',
213 'ip': 'NotebookApp.ip',
214 'port': 'NotebookApp.port',
214 'port': 'NotebookApp.port',
215 'keyfile': 'NotebookApp.keyfile',
215 'keyfile': 'NotebookApp.keyfile',
216 'certfile': 'NotebookApp.certfile',
216 'certfile': 'NotebookApp.certfile',
217 'notebook-dir': 'NotebookManager.notebook_dir',
217 'notebook-dir': 'NotebookManager.notebook_dir',
218 'browser': 'NotebookApp.browser',
218 'browser': 'NotebookApp.browser',
219 })
219 })
220
220
221 # remove ipkernel flags that are singletons, and don't make sense in
221 # remove ipkernel flags that are singletons, and don't make sense in
222 # multi-kernel evironment:
222 # multi-kernel evironment:
223 aliases.pop('f', None)
223 aliases.pop('f', None)
224
224
225 notebook_aliases = [u'port', u'ip', u'keyfile', u'certfile',
225 notebook_aliases = [u'port', u'ip', u'keyfile', u'certfile',
226 u'notebook-dir']
226 u'notebook-dir']
227
227
228 #-----------------------------------------------------------------------------
228 #-----------------------------------------------------------------------------
229 # NotebookApp
229 # NotebookApp
230 #-----------------------------------------------------------------------------
230 #-----------------------------------------------------------------------------
231
231
232 class NotebookApp(BaseIPythonApplication):
232 class NotebookApp(BaseIPythonApplication):
233
233
234 name = 'ipython-notebook'
234 name = 'ipython-notebook'
235 default_config_file_name='ipython_notebook_config.py'
235 default_config_file_name='ipython_notebook_config.py'
236
236
237 description = """
237 description = """
238 The IPython HTML Notebook.
238 The IPython HTML Notebook.
239
239
240 This launches a Tornado based HTML Notebook Server that serves up an
240 This launches a Tornado based HTML Notebook Server that serves up an
241 HTML5/Javascript Notebook client.
241 HTML5/Javascript Notebook client.
242 """
242 """
243 examples = _examples
243 examples = _examples
244
244
245 classes = [IPKernelApp, ZMQInteractiveShell, ProfileDir, Session,
245 classes = [IPKernelApp, ZMQInteractiveShell, ProfileDir, Session,
246 MappingKernelManager, NotebookManager]
246 MappingKernelManager, NotebookManager]
247 flags = Dict(flags)
247 flags = Dict(flags)
248 aliases = Dict(aliases)
248 aliases = Dict(aliases)
249
249
250 kernel_argv = List(Unicode)
250 kernel_argv = List(Unicode)
251
251
252 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
252 log_level = Enum((0,10,20,30,40,50,'DEBUG','INFO','WARN','ERROR','CRITICAL'),
253 default_value=logging.INFO,
253 default_value=logging.INFO,
254 config=True,
254 config=True,
255 help="Set the log level by value or name.")
255 help="Set the log level by value or name.")
256
256
257 # create requested profiles by default, if they don't exist:
257 # create requested profiles by default, if they don't exist:
258 auto_create = Bool(True)
258 auto_create = Bool(True)
259
259
260 # file to be opened in the notebook server
260 # file to be opened in the notebook server
261 file_to_run = Unicode('')
261 file_to_run = Unicode('')
262
262
263 # Network related information.
263 # Network related information.
264
264
265 ip = Unicode(LOCALHOST, config=True,
265 ip = Unicode(LOCALHOST, config=True,
266 help="The IP address the notebook server will listen on."
266 help="The IP address the notebook server will listen on."
267 )
267 )
268
268
269 def _ip_changed(self, name, old, new):
269 def _ip_changed(self, name, old, new):
270 if new == u'*': self.ip = u''
270 if new == u'*': self.ip = u''
271
271
272 port = Integer(8888, config=True,
272 port = Integer(8888, config=True,
273 help="The port the notebook server will listen on."
273 help="The port the notebook server will listen on."
274 )
274 )
275
275
276 certfile = Unicode(u'', config=True,
276 certfile = Unicode(u'', config=True,
277 help="""The full path to an SSL/TLS certificate file."""
277 help="""The full path to an SSL/TLS certificate file."""
278 )
278 )
279
279
280 keyfile = Unicode(u'', config=True,
280 keyfile = Unicode(u'', config=True,
281 help="""The full path to a private key file for usage with SSL/TLS."""
281 help="""The full path to a private key file for usage with SSL/TLS."""
282 )
282 )
283
283
284 password = Unicode(u'', config=True,
284 password = Unicode(u'', config=True,
285 help="""Hashed password to use for web authentication.
285 help="""Hashed password to use for web authentication.
286
286
287 To generate, type in a python/IPython shell:
287 To generate, type in a python/IPython shell:
288
288
289 from IPython.lib import passwd; passwd()
289 from IPython.lib import passwd; passwd()
290
290
291 The string should be of the form type:salt:hashed-password.
291 The string should be of the form type:salt:hashed-password.
292 """
292 """
293 )
293 )
294
294
295 open_browser = Bool(True, config=True,
295 open_browser = Bool(True, config=True,
296 help="""Whether to open in a browser after starting.
296 help="""Whether to open in a browser after starting.
297 The specific browser used is platform dependent and
297 The specific browser used is platform dependent and
298 determined by the python standard library `webbrowser`
298 determined by the python standard library `webbrowser`
299 module, unless it is overridden using the --browser
299 module, unless it is overridden using the --browser
300 (NotebookApp.browser) configuration option.
300 (NotebookApp.browser) configuration option.
301 """)
301 """)
302
302
303 browser = Unicode(u'', config=True,
303 browser = Unicode(u'', config=True,
304 help="""Specify what command to use to invoke a web
304 help="""Specify what command to use to invoke a web
305 browser when opening the notebook. If not specified, the
305 browser when opening the notebook. If not specified, the
306 default browser will be determined by the `webbrowser`
306 default browser will be determined by the `webbrowser`
307 standard library module, which allows setting of the
307 standard library module, which allows setting of the
308 BROWSER environment variable to override it.
308 BROWSER environment variable to override it.
309 """)
309 """)
310
310
311 read_only = Bool(False, config=True,
311 read_only = Bool(False, config=True,
312 help="Whether to prevent editing/execution of notebooks."
312 help="Whether to prevent editing/execution of notebooks."
313 )
313 )
314
314
315 webapp_settings = Dict(config=True,
315 webapp_settings = Dict(config=True,
316 help="Supply overrides for the tornado.web.Application that the "
316 help="Supply overrides for the tornado.web.Application that the "
317 "IPython notebook uses.")
317 "IPython notebook uses.")
318
318
319 enable_mathjax = Bool(True, config=True,
319 enable_mathjax = Bool(True, config=True,
320 help="""Whether to enable MathJax for typesetting math/TeX
320 help="""Whether to enable MathJax for typesetting math/TeX
321
321
322 MathJax is the javascript library IPython uses to render math/LaTeX. It is
322 MathJax is the javascript library IPython uses to render math/LaTeX. It is
323 very large, so you may want to disable it if you have a slow internet
323 very large, so you may want to disable it if you have a slow internet
324 connection, or for offline use of the notebook.
324 connection, or for offline use of the notebook.
325
325
326 When disabled, equations etc. will appear as their untransformed TeX source.
326 When disabled, equations etc. will appear as their untransformed TeX source.
327 """
327 """
328 )
328 )
329 def _enable_mathjax_changed(self, name, old, new):
329 def _enable_mathjax_changed(self, name, old, new):
330 """set mathjax url to empty if mathjax is disabled"""
330 """set mathjax url to empty if mathjax is disabled"""
331 if not new:
331 if not new:
332 self.mathjax_url = u''
332 self.mathjax_url = u''
333
333
334 base_project_url = Unicode('/', config=True,
334 base_project_url = Unicode('/', config=True,
335 help='''The base URL for the notebook server''')
335 help='''The base URL for the notebook server''')
336 base_kernel_url = Unicode('/', config=True,
336 base_kernel_url = Unicode('/', config=True,
337 help='''The base URL for the kernel server''')
337 help='''The base URL for the kernel server''')
338 websocket_host = Unicode("", config=True,
338 websocket_host = Unicode("", config=True,
339 help="""The hostname for the websocket server."""
339 help="""The hostname for the websocket server."""
340 )
340 )
341
341
342 mathjax_url = Unicode("", config=True,
342 mathjax_url = Unicode("", config=True,
343 help="""The url for MathJax.js."""
343 help="""The url for MathJax.js."""
344 )
344 )
345 def _mathjax_url_default(self):
345 def _mathjax_url_default(self):
346 if not self.enable_mathjax:
346 if not self.enable_mathjax:
347 return u''
347 return u''
348 static_path = self.webapp_settings.get("static_path", os.path.join(os.path.dirname(__file__), "static"))
348 static_path = self.webapp_settings.get("static_path", os.path.join(os.path.dirname(__file__), "static"))
349 static_url_prefix = self.webapp_settings.get("static_url_prefix",
349 static_url_prefix = self.webapp_settings.get("static_url_prefix",
350 "/static/")
350 "/static/")
351 if os.path.exists(os.path.join(static_path, 'mathjax', "MathJax.js")):
351 if os.path.exists(os.path.join(static_path, 'mathjax', "MathJax.js")):
352 self.log.info("Using local MathJax")
352 self.log.info("Using local MathJax")
353 return static_url_prefix+u"mathjax/MathJax.js"
353 return static_url_prefix+u"mathjax/MathJax.js"
354 else:
354 else:
355 if self.certfile:
355 if self.certfile:
356 # HTTPS: load from Rackspace CDN, because SSL certificate requires it
356 # HTTPS: load from Rackspace CDN, because SSL certificate requires it
357 base = u"https://c328740.ssl.cf1.rackcdn.com"
357 base = u"https://c328740.ssl.cf1.rackcdn.com"
358 else:
358 else:
359 base = u"http://cdn.mathjax.org"
359 base = u"http://cdn.mathjax.org"
360
360
361 url = base + u"/mathjax/latest/MathJax.js"
361 url = base + u"/mathjax/latest/MathJax.js"
362 self.log.info("Using MathJax from CDN: %s", url)
362 self.log.info("Using MathJax from CDN: %s", url)
363 return url
363 return url
364
364
365 def _mathjax_url_changed(self, name, old, new):
365 def _mathjax_url_changed(self, name, old, new):
366 if new and not self.enable_mathjax:
366 if new and not self.enable_mathjax:
367 # enable_mathjax=False overrides mathjax_url
367 # enable_mathjax=False overrides mathjax_url
368 self.mathjax_url = u''
368 self.mathjax_url = u''
369 else:
369 else:
370 self.log.info("Using MathJax: %s", new)
370 self.log.info("Using MathJax: %s", new)
371
371
372 def parse_command_line(self, argv=None):
372 def parse_command_line(self, argv=None):
373 super(NotebookApp, self).parse_command_line(argv)
373 super(NotebookApp, self).parse_command_line(argv)
374 if argv is None:
374 if argv is None:
375 argv = sys.argv[1:]
375 argv = sys.argv[1:]
376
376
377 # Scrub frontend-specific flags
377 # Scrub frontend-specific flags
378 self.kernel_argv = swallow_argv(argv, notebook_aliases, notebook_flags)
378 self.kernel_argv = swallow_argv(argv, notebook_aliases, notebook_flags)
379 # Kernel should inherit default config file from frontend
379 # Kernel should inherit default config file from frontend
380 self.kernel_argv.append("--KernelApp.parent_appname='%s'"%self.name)
380 self.kernel_argv.append("--KernelApp.parent_appname='%s'"%self.name)
381
381
382 if self.extra_args:
382 if self.extra_args:
383 self.file_to_run = os.path.abspath(self.extra_args[0])
383 self.file_to_run = os.path.abspath(self.extra_args[0])
384 self.config.NotebookManager.notebook_dir = os.path.dirname(self.file_to_run)
384 self.config.NotebookManager.notebook_dir = os.path.dirname(self.file_to_run)
385
385
386 def init_configurables(self):
386 def init_configurables(self):
387 # force Session default to be secure
387 # force Session default to be secure
388 default_secure(self.config)
388 default_secure(self.config)
389 # Create a KernelManager and start a kernel.
389 # Create a KernelManager and start a kernel.
390 self.kernel_manager = MappingKernelManager(
390 self.kernel_manager = MappingKernelManager(
391 config=self.config, log=self.log, kernel_argv=self.kernel_argv,
391 config=self.config, log=self.log, kernel_argv=self.kernel_argv,
392 connection_dir = self.profile_dir.security_dir,
392 connection_dir = self.profile_dir.security_dir,
393 )
393 )
394 self.notebook_manager = NotebookManager(config=self.config, log=self.log)
394 self.notebook_manager = NotebookManager(config=self.config, log=self.log)
395 self.notebook_manager.list_notebooks()
395 self.notebook_manager.list_notebooks()
396 self.cluster_manager = ClusterManager(config=self.config, log=self.log)
396 self.cluster_manager = ClusterManager(config=self.config, log=self.log)
397 self.cluster_manager.update_profiles()
397 self.cluster_manager.update_profiles()
398
398
399 def init_logging(self):
399 def init_logging(self):
400 super(NotebookApp, self).init_logging()
401 # This prevents double log messages because tornado use a root logger that
400 # This prevents double log messages because tornado use a root logger that
402 # self.log is a child of. The logging module dipatches log messages to a log
401 # self.log is a child of. The logging module dipatches log messages to a log
403 # and all of its ancenstors until propagate is set to False.
402 # and all of its ancenstors until propagate is set to False.
404 self.log.propagate = False
403 self.log.propagate = False
405
404
406 def init_webapp(self):
405 def init_webapp(self):
407 """initialize tornado webapp and httpserver"""
406 """initialize tornado webapp and httpserver"""
408 self.web_app = NotebookWebApplication(
407 self.web_app = NotebookWebApplication(
409 self, self.kernel_manager, self.notebook_manager,
408 self, self.kernel_manager, self.notebook_manager,
410 self.cluster_manager, self.log,
409 self.cluster_manager, self.log,
411 self.base_project_url, self.webapp_settings
410 self.base_project_url, self.webapp_settings
412 )
411 )
413 if self.certfile:
412 if self.certfile:
414 ssl_options = dict(certfile=self.certfile)
413 ssl_options = dict(certfile=self.certfile)
415 if self.keyfile:
414 if self.keyfile:
416 ssl_options['keyfile'] = self.keyfile
415 ssl_options['keyfile'] = self.keyfile
417 else:
416 else:
418 ssl_options = None
417 ssl_options = None
419 self.web_app.password = self.password
418 self.web_app.password = self.password
420 self.http_server = httpserver.HTTPServer(self.web_app, ssl_options=ssl_options)
419 self.http_server = httpserver.HTTPServer(self.web_app, ssl_options=ssl_options)
421 if ssl_options is None and not self.ip and not (self.read_only and not self.password):
420 if ssl_options is None and not self.ip and not (self.read_only and not self.password):
422 self.log.critical('WARNING: the notebook server is listening on all IP addresses '
421 self.log.critical('WARNING: the notebook server is listening on all IP addresses '
423 'but not using any encryption or authentication. This is highly '
422 'but not using any encryption or authentication. This is highly '
424 'insecure and not recommended.')
423 'insecure and not recommended.')
425
424
426 # Try random ports centered around the default.
425 # Try random ports centered around the default.
427 from random import randint
426 from random import randint
428 n = 50 # Max number of attempts, keep reasonably large.
427 n = 50 # Max number of attempts, keep reasonably large.
429 for port in range(self.port, self.port+5) + [self.port + randint(-2*n, 2*n) for i in range(n-5)]:
428 for port in range(self.port, self.port+5) + [self.port + randint(-2*n, 2*n) for i in range(n-5)]:
430 try:
429 try:
431 self.http_server.listen(port, self.ip)
430 self.http_server.listen(port, self.ip)
432 except socket.error, e:
431 except socket.error, e:
433 if e.errno != errno.EADDRINUSE:
432 if e.errno != errno.EADDRINUSE:
434 raise
433 raise
435 self.log.info('The port %i is already in use, trying another random port.' % port)
434 self.log.info('The port %i is already in use, trying another random port.' % port)
436 else:
435 else:
437 self.port = port
436 self.port = port
438 break
437 break
439
438
440 def init_signal(self):
439 def init_signal(self):
441 # FIXME: remove this check when pyzmq dependency is >= 2.1.11
440 # FIXME: remove this check when pyzmq dependency is >= 2.1.11
442 # safely extract zmq version info:
441 # safely extract zmq version info:
443 try:
442 try:
444 zmq_v = zmq.pyzmq_version_info()
443 zmq_v = zmq.pyzmq_version_info()
445 except AttributeError:
444 except AttributeError:
446 zmq_v = [ int(n) for n in re.findall(r'\d+', zmq.__version__) ]
445 zmq_v = [ int(n) for n in re.findall(r'\d+', zmq.__version__) ]
447 if 'dev' in zmq.__version__:
446 if 'dev' in zmq.__version__:
448 zmq_v.append(999)
447 zmq_v.append(999)
449 zmq_v = tuple(zmq_v)
448 zmq_v = tuple(zmq_v)
450 if zmq_v >= (2,1,9):
449 if zmq_v >= (2,1,9):
451 # This won't work with 2.1.7 and
450 # This won't work with 2.1.7 and
452 # 2.1.9-10 will log ugly 'Interrupted system call' messages,
451 # 2.1.9-10 will log ugly 'Interrupted system call' messages,
453 # but it will work
452 # but it will work
454 signal.signal(signal.SIGINT, self._handle_sigint)
453 signal.signal(signal.SIGINT, self._handle_sigint)
455 signal.signal(signal.SIGTERM, self._signal_stop)
454 signal.signal(signal.SIGTERM, self._signal_stop)
456
455
457 def _handle_sigint(self, sig, frame):
456 def _handle_sigint(self, sig, frame):
458 """SIGINT handler spawns confirmation dialog"""
457 """SIGINT handler spawns confirmation dialog"""
459 # register more forceful signal handler for ^C^C case
458 # register more forceful signal handler for ^C^C case
460 signal.signal(signal.SIGINT, self._signal_stop)
459 signal.signal(signal.SIGINT, self._signal_stop)
461 # request confirmation dialog in bg thread, to avoid
460 # request confirmation dialog in bg thread, to avoid
462 # blocking the App
461 # blocking the App
463 thread = threading.Thread(target=self._confirm_exit)
462 thread = threading.Thread(target=self._confirm_exit)
464 thread.daemon = True
463 thread.daemon = True
465 thread.start()
464 thread.start()
466
465
467 def _restore_sigint_handler(self):
466 def _restore_sigint_handler(self):
468 """callback for restoring original SIGINT handler"""
467 """callback for restoring original SIGINT handler"""
469 signal.signal(signal.SIGINT, self._handle_sigint)
468 signal.signal(signal.SIGINT, self._handle_sigint)
470
469
471 def _confirm_exit(self):
470 def _confirm_exit(self):
472 """confirm shutdown on ^C
471 """confirm shutdown on ^C
473
472
474 A second ^C, or answering 'y' within 5s will cause shutdown,
473 A second ^C, or answering 'y' within 5s will cause shutdown,
475 otherwise original SIGINT handler will be restored.
474 otherwise original SIGINT handler will be restored.
476 """
475 """
477 # FIXME: remove this delay when pyzmq dependency is >= 2.1.11
476 # FIXME: remove this delay when pyzmq dependency is >= 2.1.11
478 time.sleep(0.1)
477 time.sleep(0.1)
479 sys.stdout.write("Shutdown Notebook Server (y/[n])? ")
478 sys.stdout.write("Shutdown Notebook Server (y/[n])? ")
480 sys.stdout.flush()
479 sys.stdout.flush()
481 r,w,x = select.select([sys.stdin], [], [], 5)
480 r,w,x = select.select([sys.stdin], [], [], 5)
482 if r:
481 if r:
483 line = sys.stdin.readline()
482 line = sys.stdin.readline()
484 if line.lower().startswith('y'):
483 if line.lower().startswith('y'):
485 self.log.critical("Shutdown confirmed")
484 self.log.critical("Shutdown confirmed")
486 ioloop.IOLoop.instance().stop()
485 ioloop.IOLoop.instance().stop()
487 return
486 return
488 else:
487 else:
489 print "No answer for 5s:",
488 print "No answer for 5s:",
490 print "resuming operation..."
489 print "resuming operation..."
491 # no answer, or answer is no:
490 # no answer, or answer is no:
492 # set it back to original SIGINT handler
491 # set it back to original SIGINT handler
493 # use IOLoop.add_callback because signal.signal must be called
492 # use IOLoop.add_callback because signal.signal must be called
494 # from main thread
493 # from main thread
495 ioloop.IOLoop.instance().add_callback(self._restore_sigint_handler)
494 ioloop.IOLoop.instance().add_callback(self._restore_sigint_handler)
496
495
497 def _signal_stop(self, sig, frame):
496 def _signal_stop(self, sig, frame):
498 self.log.critical("received signal %s, stopping", sig)
497 self.log.critical("received signal %s, stopping", sig)
499 ioloop.IOLoop.instance().stop()
498 ioloop.IOLoop.instance().stop()
500
499
501 @catch_config_error
500 @catch_config_error
502 def initialize(self, argv=None):
501 def initialize(self, argv=None):
502 self.init_logging()
503 super(NotebookApp, self).initialize(argv)
503 super(NotebookApp, self).initialize(argv)
504 self.init_configurables()
504 self.init_configurables()
505 self.init_webapp()
505 self.init_webapp()
506 self.init_signal()
506 self.init_signal()
507
507
508 def cleanup_kernels(self):
508 def cleanup_kernels(self):
509 """shutdown all kernels
509 """shutdown all kernels
510
510
511 The kernels will shutdown themselves when this process no longer exists,
511 The kernels will shutdown themselves when this process no longer exists,
512 but explicit shutdown allows the KernelManagers to cleanup the connection files.
512 but explicit shutdown allows the KernelManagers to cleanup the connection files.
513 """
513 """
514 self.log.info('Shutting down kernels')
514 self.log.info('Shutting down kernels')
515 km = self.kernel_manager
515 km = self.kernel_manager
516 # copy list, since kill_kernel deletes keys
516 # copy list, since kill_kernel deletes keys
517 for kid in list(km.kernel_ids):
517 for kid in list(km.kernel_ids):
518 km.kill_kernel(kid)
518 km.kill_kernel(kid)
519
519
520 def start(self):
520 def start(self):
521 ip = self.ip if self.ip else '[all ip addresses on your system]'
521 ip = self.ip if self.ip else '[all ip addresses on your system]'
522 proto = 'https' if self.certfile else 'http'
522 proto = 'https' if self.certfile else 'http'
523 info = self.log.info
523 info = self.log.info
524 info("The IPython Notebook is running at: %s://%s:%i%s" %
524 info("The IPython Notebook is running at: %s://%s:%i%s" %
525 (proto, ip, self.port,self.base_project_url) )
525 (proto, ip, self.port,self.base_project_url) )
526 info("Use Control-C to stop this server and shut down all kernels.")
526 info("Use Control-C to stop this server and shut down all kernels.")
527
527
528 if self.open_browser:
528 if self.open_browser:
529 ip = self.ip or '127.0.0.1'
529 ip = self.ip or '127.0.0.1'
530 if self.browser:
530 if self.browser:
531 browser = webbrowser.get(self.browser)
531 browser = webbrowser.get(self.browser)
532 else:
532 else:
533 browser = webbrowser.get()
533 browser = webbrowser.get()
534
534
535 if self.file_to_run:
535 if self.file_to_run:
536 filename, _ = os.path.splitext(os.path.basename(self.file_to_run))
536 filename, _ = os.path.splitext(os.path.basename(self.file_to_run))
537 for nb in self.notebook_manager.list_notebooks():
537 for nb in self.notebook_manager.list_notebooks():
538 if filename == nb['name']:
538 if filename == nb['name']:
539 url = nb['notebook_id']
539 url = nb['notebook_id']
540 break
540 break
541 else:
541 else:
542 url = ''
542 url = ''
543 else:
543 else:
544 url = ''
544 url = ''
545 b = lambda : browser.open("%s://%s:%i%s%s" % (proto, ip,
545 b = lambda : browser.open("%s://%s:%i%s%s" % (proto, ip,
546 self.port, self.base_project_url, url),
546 self.port, self.base_project_url, url),
547 new=2)
547 new=2)
548 threading.Thread(target=b).start()
548 threading.Thread(target=b).start()
549 try:
549 try:
550 ioloop.IOLoop.instance().start()
550 ioloop.IOLoop.instance().start()
551 except KeyboardInterrupt:
551 except KeyboardInterrupt:
552 info("Interrupted...")
552 info("Interrupted...")
553 finally:
553 finally:
554 self.cleanup_kernels()
554 self.cleanup_kernels()
555
555
556
556
557 #-----------------------------------------------------------------------------
557 #-----------------------------------------------------------------------------
558 # Main entry point
558 # Main entry point
559 #-----------------------------------------------------------------------------
559 #-----------------------------------------------------------------------------
560
560
561 def launch_new_instance():
561 def launch_new_instance():
562 app = NotebookApp.instance()
562 app = NotebookApp.instance()
563 app.initialize()
563 app.initialize()
564 app.start()
564 app.start()
565
565
@@ -1,382 +1,372 b''
1 """ A minimal application using the Qt console-style IPython frontend.
1 """ A minimal application using the Qt console-style IPython frontend.
2
2
3 This is not a complete console app, as subprocess will not be able to receive
3 This is not a complete console app, as subprocess will not be able to receive
4 input, there is no real readline support, among other limitations.
4 input, there is no real readline support, among other limitations.
5
5
6 Authors:
6 Authors:
7
7
8 * Evan Patterson
8 * Evan Patterson
9 * Min RK
9 * Min RK
10 * Erik Tollerud
10 * Erik Tollerud
11 * Fernando Perez
11 * Fernando Perez
12 * Bussonnier Matthias
12 * Bussonnier Matthias
13 * Thomas Kluyver
13 * Thomas Kluyver
14 * Paul Ivanov
14 * Paul Ivanov
15
15
16 """
16 """
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 # stdlib imports
22 # stdlib imports
23 import json
23 import json
24 import os
24 import os
25 import signal
25 import signal
26 import sys
26 import sys
27 import uuid
27 import uuid
28
28
29 # If run on Windows, install an exception hook which pops up a
29 # If run on Windows, install an exception hook which pops up a
30 # message box. Pythonw.exe hides the console, so without this
30 # message box. Pythonw.exe hides the console, so without this
31 # the application silently fails to load.
31 # the application silently fails to load.
32 #
32 #
33 # We always install this handler, because the expectation is for
33 # We always install this handler, because the expectation is for
34 # qtconsole to bring up a GUI even if called from the console.
34 # qtconsole to bring up a GUI even if called from the console.
35 # The old handler is called, so the exception is printed as well.
35 # The old handler is called, so the exception is printed as well.
36 # If desired, check for pythonw with an additional condition
36 # If desired, check for pythonw with an additional condition
37 # (sys.executable.lower().find('pythonw.exe') >= 0).
37 # (sys.executable.lower().find('pythonw.exe') >= 0).
38 if os.name == 'nt':
38 if os.name == 'nt':
39 old_excepthook = sys.excepthook
39 old_excepthook = sys.excepthook
40
40
41 def gui_excepthook(exctype, value, tb):
41 def gui_excepthook(exctype, value, tb):
42 try:
42 try:
43 import ctypes, traceback
43 import ctypes, traceback
44 MB_ICONERROR = 0x00000010L
44 MB_ICONERROR = 0x00000010L
45 title = u'Error starting IPython QtConsole'
45 title = u'Error starting IPython QtConsole'
46 msg = u''.join(traceback.format_exception(exctype, value, tb))
46 msg = u''.join(traceback.format_exception(exctype, value, tb))
47 ctypes.windll.user32.MessageBoxW(0, msg, title, MB_ICONERROR)
47 ctypes.windll.user32.MessageBoxW(0, msg, title, MB_ICONERROR)
48 finally:
48 finally:
49 # Also call the old exception hook to let it do
49 # Also call the old exception hook to let it do
50 # its thing too.
50 # its thing too.
51 old_excepthook(exctype, value, tb)
51 old_excepthook(exctype, value, tb)
52
52
53 sys.excepthook = gui_excepthook
53 sys.excepthook = gui_excepthook
54
54
55 # System library imports
55 # System library imports
56 from IPython.external.qt import QtCore, QtGui
56 from IPython.external.qt import QtCore, QtGui
57
57
58 # Local imports
58 # Local imports
59 from IPython.config.application import boolean_flag, catch_config_error
59 from IPython.config.application import boolean_flag, catch_config_error
60 from IPython.core.application import BaseIPythonApplication
60 from IPython.core.application import BaseIPythonApplication
61 from IPython.core.profiledir import ProfileDir
61 from IPython.core.profiledir import ProfileDir
62 from IPython.lib.kernel import tunnel_to_kernel, find_connection_file
62 from IPython.lib.kernel import tunnel_to_kernel, find_connection_file
63 from IPython.frontend.qt.console.frontend_widget import FrontendWidget
63 from IPython.frontend.qt.console.frontend_widget import FrontendWidget
64 from IPython.frontend.qt.console.ipython_widget import IPythonWidget
64 from IPython.frontend.qt.console.ipython_widget import IPythonWidget
65 from IPython.frontend.qt.console.rich_ipython_widget import RichIPythonWidget
65 from IPython.frontend.qt.console.rich_ipython_widget import RichIPythonWidget
66 from IPython.frontend.qt.console import styles
66 from IPython.frontend.qt.console import styles
67 from IPython.frontend.qt.console.mainwindow import MainWindow
67 from IPython.frontend.qt.console.mainwindow import MainWindow
68 from IPython.frontend.qt.kernelmanager import QtKernelManager
68 from IPython.frontend.qt.kernelmanager import QtKernelManager
69 from IPython.utils.path import filefind
69 from IPython.utils.path import filefind
70 from IPython.utils.py3compat import str_to_bytes
70 from IPython.utils.py3compat import str_to_bytes
71 from IPython.utils.traitlets import (
71 from IPython.utils.traitlets import (
72 Dict, List, Unicode, Integer, CaselessStrEnum, CBool, Any
72 Dict, List, Unicode, Integer, CaselessStrEnum, CBool, Any
73 )
73 )
74 from IPython.zmq.ipkernel import IPKernelApp
74 from IPython.zmq.ipkernel import IPKernelApp
75 from IPython.zmq.session import Session, default_secure
75 from IPython.zmq.session import Session, default_secure
76 from IPython.zmq.zmqshell import ZMQInteractiveShell
76 from IPython.zmq.zmqshell import ZMQInteractiveShell
77
77
78 from IPython.frontend.consoleapp import (
78 from IPython.frontend.consoleapp import (
79 IPythonConsoleApp, app_aliases, app_flags, flags, aliases
79 IPythonConsoleApp, app_aliases, app_flags, flags, aliases
80 )
80 )
81
81
82 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
83 # Network Constants
83 # Network Constants
84 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
85
85
86 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
86 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
87
87
88 #-----------------------------------------------------------------------------
88 #-----------------------------------------------------------------------------
89 # Globals
89 # Globals
90 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
91
91
92 _examples = """
92 _examples = """
93 ipython qtconsole # start the qtconsole
93 ipython qtconsole # start the qtconsole
94 ipython qtconsole --pylab=inline # start with pylab in inline plotting mode
94 ipython qtconsole --pylab=inline # start with pylab in inline plotting mode
95 """
95 """
96
96
97 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
98 # Aliases and Flags
98 # Aliases and Flags
99 #-----------------------------------------------------------------------------
99 #-----------------------------------------------------------------------------
100
100
101 # start with copy of flags
101 # start with copy of flags
102 flags = dict(flags)
102 flags = dict(flags)
103 qt_flags = {
103 qt_flags = {
104 'pure' : ({'IPythonQtConsoleApp' : {'pure' : True}},
104 'plain' : ({'IPythonQtConsoleApp' : {'plain' : True}},
105 "Use a pure Python kernel instead of an IPython kernel."),
106 'plain' : ({'ConsoleWidget' : {'kind' : 'plain'}},
107 "Disable rich text support."),
105 "Disable rich text support."),
108 }
106 }
109 qt_flags.update(boolean_flag(
107 qt_flags.update(boolean_flag(
110 'gui-completion', 'ConsoleWidget.gui_completion',
108 'gui-completion', 'ConsoleWidget.gui_completion',
111 "use a GUI widget for tab completion",
109 "use a GUI widget for tab completion",
112 "use plaintext output for completion"
110 "use plaintext output for completion"
113 ))
111 ))
114 # and app_flags from the Console Mixin
112 # and app_flags from the Console Mixin
115 qt_flags.update(app_flags)
113 qt_flags.update(app_flags)
116 # add frontend flags to the full set
114 # add frontend flags to the full set
117 flags.update(qt_flags)
115 flags.update(qt_flags)
118
116
119 # start with copy of front&backend aliases list
117 # start with copy of front&backend aliases list
120 aliases = dict(aliases)
118 aliases = dict(aliases)
121 qt_aliases = dict(
119 qt_aliases = dict(
122
120
123 style = 'IPythonWidget.syntax_style',
121 style = 'IPythonWidget.syntax_style',
124 stylesheet = 'IPythonQtConsoleApp.stylesheet',
122 stylesheet = 'IPythonQtConsoleApp.stylesheet',
125 colors = 'ZMQInteractiveShell.colors',
123 colors = 'ZMQInteractiveShell.colors',
126
124
127 editor = 'IPythonWidget.editor',
125 editor = 'IPythonWidget.editor',
128 paging = 'ConsoleWidget.paging',
126 paging = 'ConsoleWidget.paging',
129 )
127 )
130 # and app_aliases from the Console Mixin
128 # and app_aliases from the Console Mixin
131 qt_aliases.update(app_aliases)
129 qt_aliases.update(app_aliases)
132 # add frontend aliases to the full set
130 # add frontend aliases to the full set
133 aliases.update(qt_aliases)
131 aliases.update(qt_aliases)
134
132
135 # get flags&aliases into sets, and remove a couple that
133 # get flags&aliases into sets, and remove a couple that
136 # shouldn't be scrubbed from backend flags:
134 # shouldn't be scrubbed from backend flags:
137 qt_aliases = set(qt_aliases.keys())
135 qt_aliases = set(qt_aliases.keys())
138 qt_aliases.remove('colors')
136 qt_aliases.remove('colors')
139 qt_flags = set(qt_flags.keys())
137 qt_flags = set(qt_flags.keys())
140
138
141 #-----------------------------------------------------------------------------
139 #-----------------------------------------------------------------------------
142 # Classes
140 # Classes
143 #-----------------------------------------------------------------------------
141 #-----------------------------------------------------------------------------
144
142
145 #-----------------------------------------------------------------------------
143 #-----------------------------------------------------------------------------
146 # IPythonQtConsole
144 # IPythonQtConsole
147 #-----------------------------------------------------------------------------
145 #-----------------------------------------------------------------------------
148
146
149
147
150 class IPythonQtConsoleApp(BaseIPythonApplication, IPythonConsoleApp):
148 class IPythonQtConsoleApp(BaseIPythonApplication, IPythonConsoleApp):
151 name = 'ipython-qtconsole'
149 name = 'ipython-qtconsole'
152
150
153 description = """
151 description = """
154 The IPython QtConsole.
152 The IPython QtConsole.
155
153
156 This launches a Console-style application using Qt. It is not a full
154 This launches a Console-style application using Qt. It is not a full
157 console, in that launched terminal subprocesses will not be able to accept
155 console, in that launched terminal subprocesses will not be able to accept
158 input.
156 input.
159
157
160 The QtConsole supports various extra features beyond the Terminal IPython
158 The QtConsole supports various extra features beyond the Terminal IPython
161 shell, such as inline plotting with matplotlib, via:
159 shell, such as inline plotting with matplotlib, via:
162
160
163 ipython qtconsole --pylab=inline
161 ipython qtconsole --pylab=inline
164
162
165 as well as saving your session as HTML, and printing the output.
163 as well as saving your session as HTML, and printing the output.
166
164
167 """
165 """
168 examples = _examples
166 examples = _examples
169
167
170 classes = [IPKernelApp, IPythonWidget, ZMQInteractiveShell, ProfileDir, Session]
168 classes = [IPKernelApp, IPythonWidget, ZMQInteractiveShell, ProfileDir, Session]
171 flags = Dict(flags)
169 flags = Dict(flags)
172 aliases = Dict(aliases)
170 aliases = Dict(aliases)
173 frontend_flags = Any(qt_flags)
171 frontend_flags = Any(qt_flags)
174 frontend_aliases = Any(qt_aliases)
172 frontend_aliases = Any(qt_aliases)
175 kernel_manager_class = QtKernelManager
173 kernel_manager_class = QtKernelManager
176
174
177 stylesheet = Unicode('', config=True,
175 stylesheet = Unicode('', config=True,
178 help="path to a custom CSS stylesheet")
176 help="path to a custom CSS stylesheet")
179
177
180 plain = CBool(False, config=True,
178 plain = CBool(False, config=True,
181 help="Use a plaintext widget instead of rich text (plain can't print/save).")
179 help="Use a plaintext widget instead of rich text (plain can't print/save).")
182
180
183 def _pure_changed(self, name, old, new):
181 def _plain_changed(self, name, old, new):
184 kind = 'plain' if self.plain else 'rich'
182 kind = 'plain' if new else 'rich'
185 self.config.ConsoleWidget.kind = kind
183 self.config.ConsoleWidget.kind = kind
186 if self.pure:
184 if new:
187 self.widget_factory = FrontendWidget
188 elif self.plain:
189 self.widget_factory = IPythonWidget
185 self.widget_factory = IPythonWidget
190 else:
186 else:
191 self.widget_factory = RichIPythonWidget
187 self.widget_factory = RichIPythonWidget
192
188
193 _plain_changed = _pure_changed
194
195 # the factory for creating a widget
189 # the factory for creating a widget
196 widget_factory = Any(RichIPythonWidget)
190 widget_factory = Any(RichIPythonWidget)
197
191
198 def parse_command_line(self, argv=None):
192 def parse_command_line(self, argv=None):
199 super(IPythonQtConsoleApp, self).parse_command_line(argv)
193 super(IPythonQtConsoleApp, self).parse_command_line(argv)
200 self.build_kernel_argv(argv)
194 self.build_kernel_argv(argv)
201
195
202
196
203 def new_frontend_master(self):
197 def new_frontend_master(self):
204 """ Create and return new frontend attached to new kernel, launched on localhost.
198 """ Create and return new frontend attached to new kernel, launched on localhost.
205 """
199 """
206 ip = self.ip if self.ip in LOCAL_IPS else LOCALHOST
200 ip = self.ip if self.ip in LOCAL_IPS else LOCALHOST
207 kernel_manager = self.kernel_manager_class(
201 kernel_manager = self.kernel_manager_class(
208 ip=ip,
202 ip=ip,
209 connection_file=self._new_connection_file(),
203 connection_file=self._new_connection_file(),
210 config=self.config,
204 config=self.config,
211 )
205 )
212 # start the kernel
206 # start the kernel
213 kwargs = dict(ipython=not self.pure)
207 kwargs = dict()
214 kwargs['extra_arguments'] = self.kernel_argv
208 kwargs['extra_arguments'] = self.kernel_argv
215 kernel_manager.start_kernel(**kwargs)
209 kernel_manager.start_kernel(**kwargs)
216 kernel_manager.start_channels()
210 kernel_manager.start_channels()
217 widget = self.widget_factory(config=self.config,
211 widget = self.widget_factory(config=self.config,
218 local_kernel=True)
212 local_kernel=True)
219 self.init_colors(widget)
213 self.init_colors(widget)
220 widget.kernel_manager = kernel_manager
214 widget.kernel_manager = kernel_manager
221 widget._existing = False
215 widget._existing = False
222 widget._may_close = True
216 widget._may_close = True
223 widget._confirm_exit = self.confirm_exit
217 widget._confirm_exit = self.confirm_exit
224 return widget
218 return widget
225
219
226 def new_frontend_slave(self, current_widget):
220 def new_frontend_slave(self, current_widget):
227 """Create and return a new frontend attached to an existing kernel.
221 """Create and return a new frontend attached to an existing kernel.
228
222
229 Parameters
223 Parameters
230 ----------
224 ----------
231 current_widget : IPythonWidget
225 current_widget : IPythonWidget
232 The IPythonWidget whose kernel this frontend is to share
226 The IPythonWidget whose kernel this frontend is to share
233 """
227 """
234 kernel_manager = self.kernel_manager_class(
228 kernel_manager = self.kernel_manager_class(
235 connection_file=current_widget.kernel_manager.connection_file,
229 connection_file=current_widget.kernel_manager.connection_file,
236 config = self.config,
230 config = self.config,
237 )
231 )
238 kernel_manager.load_connection_file()
232 kernel_manager.load_connection_file()
239 kernel_manager.start_channels()
233 kernel_manager.start_channels()
240 widget = self.widget_factory(config=self.config,
234 widget = self.widget_factory(config=self.config,
241 local_kernel=False)
235 local_kernel=False)
242 self.init_colors(widget)
236 self.init_colors(widget)
243 widget._existing = True
237 widget._existing = True
244 widget._may_close = False
238 widget._may_close = False
245 widget._confirm_exit = False
239 widget._confirm_exit = False
246 widget.kernel_manager = kernel_manager
240 widget.kernel_manager = kernel_manager
247 return widget
241 return widget
248
242
249 def init_qt_elements(self):
243 def init_qt_elements(self):
250 # Create the widget.
244 # Create the widget.
251 self.app = QtGui.QApplication([])
245 self.app = QtGui.QApplication([])
252
246
253 base_path = os.path.abspath(os.path.dirname(__file__))
247 base_path = os.path.abspath(os.path.dirname(__file__))
254 icon_path = os.path.join(base_path, 'resources', 'icon', 'IPythonConsole.svg')
248 icon_path = os.path.join(base_path, 'resources', 'icon', 'IPythonConsole.svg')
255 self.app.icon = QtGui.QIcon(icon_path)
249 self.app.icon = QtGui.QIcon(icon_path)
256 QtGui.QApplication.setWindowIcon(self.app.icon)
250 QtGui.QApplication.setWindowIcon(self.app.icon)
257
251
258 local_kernel = (not self.existing) or self.ip in LOCAL_IPS
252 local_kernel = (not self.existing) or self.ip in LOCAL_IPS
259 self.widget = self.widget_factory(config=self.config,
253 self.widget = self.widget_factory(config=self.config,
260 local_kernel=local_kernel)
254 local_kernel=local_kernel)
261 self.init_colors(self.widget)
255 self.init_colors(self.widget)
262 self.widget._existing = self.existing
256 self.widget._existing = self.existing
263 self.widget._may_close = not self.existing
257 self.widget._may_close = not self.existing
264 self.widget._confirm_exit = self.confirm_exit
258 self.widget._confirm_exit = self.confirm_exit
265
259
266 self.widget.kernel_manager = self.kernel_manager
260 self.widget.kernel_manager = self.kernel_manager
267 self.window = MainWindow(self.app,
261 self.window = MainWindow(self.app,
268 confirm_exit=self.confirm_exit,
262 confirm_exit=self.confirm_exit,
269 new_frontend_factory=self.new_frontend_master,
263 new_frontend_factory=self.new_frontend_master,
270 slave_frontend_factory=self.new_frontend_slave,
264 slave_frontend_factory=self.new_frontend_slave,
271 )
265 )
272 self.window.log = self.log
266 self.window.log = self.log
273 self.window.add_tab_with_frontend(self.widget)
267 self.window.add_tab_with_frontend(self.widget)
274 self.window.init_menu_bar()
268 self.window.init_menu_bar()
275
269
276 self.window.setWindowTitle('Python' if self.pure else 'IPython')
270 self.window.setWindowTitle('IPython')
277
271
278 def init_colors(self, widget):
272 def init_colors(self, widget):
279 """Configure the coloring of the widget"""
273 """Configure the coloring of the widget"""
280 # Note: This will be dramatically simplified when colors
274 # Note: This will be dramatically simplified when colors
281 # are removed from the backend.
275 # are removed from the backend.
282
276
283 if self.pure:
284 # only IPythonWidget supports styling
285 return
286
287 # parse the colors arg down to current known labels
277 # parse the colors arg down to current known labels
288 try:
278 try:
289 colors = self.config.ZMQInteractiveShell.colors
279 colors = self.config.ZMQInteractiveShell.colors
290 except AttributeError:
280 except AttributeError:
291 colors = None
281 colors = None
292 try:
282 try:
293 style = self.config.IPythonWidget.syntax_style
283 style = self.config.IPythonWidget.syntax_style
294 except AttributeError:
284 except AttributeError:
295 style = None
285 style = None
296 try:
286 try:
297 sheet = self.config.IPythonWidget.style_sheet
287 sheet = self.config.IPythonWidget.style_sheet
298 except AttributeError:
288 except AttributeError:
299 sheet = None
289 sheet = None
300
290
301 # find the value for colors:
291 # find the value for colors:
302 if colors:
292 if colors:
303 colors=colors.lower()
293 colors=colors.lower()
304 if colors in ('lightbg', 'light'):
294 if colors in ('lightbg', 'light'):
305 colors='lightbg'
295 colors='lightbg'
306 elif colors in ('dark', 'linux'):
296 elif colors in ('dark', 'linux'):
307 colors='linux'
297 colors='linux'
308 else:
298 else:
309 colors='nocolor'
299 colors='nocolor'
310 elif style:
300 elif style:
311 if style=='bw':
301 if style=='bw':
312 colors='nocolor'
302 colors='nocolor'
313 elif styles.dark_style(style):
303 elif styles.dark_style(style):
314 colors='linux'
304 colors='linux'
315 else:
305 else:
316 colors='lightbg'
306 colors='lightbg'
317 else:
307 else:
318 colors=None
308 colors=None
319
309
320 # Configure the style
310 # Configure the style
321 if style:
311 if style:
322 widget.style_sheet = styles.sheet_from_template(style, colors)
312 widget.style_sheet = styles.sheet_from_template(style, colors)
323 widget.syntax_style = style
313 widget.syntax_style = style
324 widget._syntax_style_changed()
314 widget._syntax_style_changed()
325 widget._style_sheet_changed()
315 widget._style_sheet_changed()
326 elif colors:
316 elif colors:
327 # use a default dark/light/bw style
317 # use a default dark/light/bw style
328 widget.set_default_style(colors=colors)
318 widget.set_default_style(colors=colors)
329
319
330 if self.stylesheet:
320 if self.stylesheet:
331 # we got an explicit stylesheet
321 # we got an explicit stylesheet
332 if os.path.isfile(self.stylesheet):
322 if os.path.isfile(self.stylesheet):
333 with open(self.stylesheet) as f:
323 with open(self.stylesheet) as f:
334 sheet = f.read()
324 sheet = f.read()
335 else:
325 else:
336 raise IOError("Stylesheet %r not found." % self.stylesheet)
326 raise IOError("Stylesheet %r not found." % self.stylesheet)
337 if sheet:
327 if sheet:
338 widget.style_sheet = sheet
328 widget.style_sheet = sheet
339 widget._style_sheet_changed()
329 widget._style_sheet_changed()
340
330
341
331
342 def init_signal(self):
332 def init_signal(self):
343 """allow clean shutdown on sigint"""
333 """allow clean shutdown on sigint"""
344 signal.signal(signal.SIGINT, lambda sig, frame: self.exit(-2))
334 signal.signal(signal.SIGINT, lambda sig, frame: self.exit(-2))
345 # need a timer, so that QApplication doesn't block until a real
335 # need a timer, so that QApplication doesn't block until a real
346 # Qt event fires (can require mouse movement)
336 # Qt event fires (can require mouse movement)
347 # timer trick from http://stackoverflow.com/q/4938723/938949
337 # timer trick from http://stackoverflow.com/q/4938723/938949
348 timer = QtCore.QTimer()
338 timer = QtCore.QTimer()
349 # Let the interpreter run each 200 ms:
339 # Let the interpreter run each 200 ms:
350 timer.timeout.connect(lambda: None)
340 timer.timeout.connect(lambda: None)
351 timer.start(200)
341 timer.start(200)
352 # hold onto ref, so the timer doesn't get cleaned up
342 # hold onto ref, so the timer doesn't get cleaned up
353 self._sigint_timer = timer
343 self._sigint_timer = timer
354
344
355 @catch_config_error
345 @catch_config_error
356 def initialize(self, argv=None):
346 def initialize(self, argv=None):
357 super(IPythonQtConsoleApp, self).initialize(argv)
347 super(IPythonQtConsoleApp, self).initialize(argv)
358 IPythonConsoleApp.initialize(self,argv)
348 IPythonConsoleApp.initialize(self,argv)
359 self.init_qt_elements()
349 self.init_qt_elements()
360 self.init_signal()
350 self.init_signal()
361
351
362 def start(self):
352 def start(self):
363
353
364 # draw the window
354 # draw the window
365 self.window.show()
355 self.window.show()
366 self.window.raise_()
356 self.window.raise_()
367
357
368 # Start the application main loop.
358 # Start the application main loop.
369 self.app.exec_()
359 self.app.exec_()
370
360
371 #-----------------------------------------------------------------------------
361 #-----------------------------------------------------------------------------
372 # Main entry point
362 # Main entry point
373 #-----------------------------------------------------------------------------
363 #-----------------------------------------------------------------------------
374
364
375 def main():
365 def main():
376 app = IPythonQtConsoleApp()
366 app = IPythonQtConsoleApp()
377 app.initialize()
367 app.initialize()
378 app.start()
368 app.start()
379
369
380
370
381 if __name__ == '__main__':
371 if __name__ == '__main__':
382 main()
372 main()
@@ -1,315 +1,315 b''
1 """Utilities for connecting to kernels
1 """Utilities for connecting to kernels
2
2
3 Authors:
3 Authors:
4
4
5 * Min Ragan-Kelley
5 * Min Ragan-Kelley
6
6
7 """
7 """
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2011 The IPython Development Team
10 # Copyright (C) 2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 import glob
20 import glob
21 import json
21 import json
22 import os
22 import os
23 import sys
23 import sys
24 from getpass import getpass
24 from getpass import getpass
25 from subprocess import Popen, PIPE
25 from subprocess import Popen, PIPE
26
26
27 # external imports
27 # external imports
28 from IPython.external.ssh import tunnel
28 from IPython.external.ssh import tunnel
29
29
30 # IPython imports
30 # IPython imports
31 from IPython.core.profiledir import ProfileDir
31 from IPython.core.profiledir import ProfileDir
32 from IPython.utils.path import filefind, get_ipython_dir
32 from IPython.utils.path import filefind, get_ipython_dir
33 from IPython.utils.py3compat import str_to_bytes
33 from IPython.utils.py3compat import str_to_bytes
34
34
35
35
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37 # Functions
37 # Functions
38 #-----------------------------------------------------------------------------
38 #-----------------------------------------------------------------------------
39
39
40 def get_connection_file(app=None):
40 def get_connection_file(app=None):
41 """Return the path to the connection file of an app
41 """Return the path to the connection file of an app
42
42
43 Parameters
43 Parameters
44 ----------
44 ----------
45 app : KernelApp instance [optional]
45 app : KernelApp instance [optional]
46 If unspecified, the currently running app will be used
46 If unspecified, the currently running app will be used
47 """
47 """
48 if app is None:
48 if app is None:
49 from IPython.zmq.kernelapp import KernelApp
49 from IPython.zmq.ipkernel import IPKernelApp
50 if not KernelApp.initialized():
50 if not IPKernelApp.initialized():
51 raise RuntimeError("app not specified, and not in a running Kernel")
51 raise RuntimeError("app not specified, and not in a running Kernel")
52
52
53 app = KernelApp.instance()
53 app = IPKernelApp.instance()
54 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
54 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
55
55
56 def find_connection_file(filename, profile=None):
56 def find_connection_file(filename, profile=None):
57 """find a connection file, and return its absolute path.
57 """find a connection file, and return its absolute path.
58
58
59 The current working directory and the profile's security
59 The current working directory and the profile's security
60 directory will be searched for the file if it is not given by
60 directory will be searched for the file if it is not given by
61 absolute path.
61 absolute path.
62
62
63 If profile is unspecified, then the current running application's
63 If profile is unspecified, then the current running application's
64 profile will be used, or 'default', if not run from IPython.
64 profile will be used, or 'default', if not run from IPython.
65
65
66 If the argument does not match an existing file, it will be interpreted as a
66 If the argument does not match an existing file, it will be interpreted as a
67 fileglob, and the matching file in the profile's security dir with
67 fileglob, and the matching file in the profile's security dir with
68 the latest access time will be used.
68 the latest access time will be used.
69
69
70 Parameters
70 Parameters
71 ----------
71 ----------
72 filename : str
72 filename : str
73 The connection file or fileglob to search for.
73 The connection file or fileglob to search for.
74 profile : str [optional]
74 profile : str [optional]
75 The name of the profile to use when searching for the connection file,
75 The name of the profile to use when searching for the connection file,
76 if different from the current IPython session or 'default'.
76 if different from the current IPython session or 'default'.
77
77
78 Returns
78 Returns
79 -------
79 -------
80 str : The absolute path of the connection file.
80 str : The absolute path of the connection file.
81 """
81 """
82 from IPython.core.application import BaseIPythonApplication as IPApp
82 from IPython.core.application import BaseIPythonApplication as IPApp
83 try:
83 try:
84 # quick check for absolute path, before going through logic
84 # quick check for absolute path, before going through logic
85 return filefind(filename)
85 return filefind(filename)
86 except IOError:
86 except IOError:
87 pass
87 pass
88
88
89 if profile is None:
89 if profile is None:
90 # profile unspecified, check if running from an IPython app
90 # profile unspecified, check if running from an IPython app
91 if IPApp.initialized():
91 if IPApp.initialized():
92 app = IPApp.instance()
92 app = IPApp.instance()
93 profile_dir = app.profile_dir
93 profile_dir = app.profile_dir
94 else:
94 else:
95 # not running in IPython, use default profile
95 # not running in IPython, use default profile
96 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
96 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
97 else:
97 else:
98 # find profiledir by profile name:
98 # find profiledir by profile name:
99 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
99 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
100 security_dir = profile_dir.security_dir
100 security_dir = profile_dir.security_dir
101
101
102 try:
102 try:
103 # first, try explicit name
103 # first, try explicit name
104 return filefind(filename, ['.', security_dir])
104 return filefind(filename, ['.', security_dir])
105 except IOError:
105 except IOError:
106 pass
106 pass
107
107
108 # not found by full name
108 # not found by full name
109
109
110 if '*' in filename:
110 if '*' in filename:
111 # given as a glob already
111 # given as a glob already
112 pat = filename
112 pat = filename
113 else:
113 else:
114 # accept any substring match
114 # accept any substring match
115 pat = '*%s*' % filename
115 pat = '*%s*' % filename
116 matches = glob.glob( os.path.join(security_dir, pat) )
116 matches = glob.glob( os.path.join(security_dir, pat) )
117 if not matches:
117 if not matches:
118 raise IOError("Could not find %r in %r" % (filename, security_dir))
118 raise IOError("Could not find %r in %r" % (filename, security_dir))
119 elif len(matches) == 1:
119 elif len(matches) == 1:
120 return matches[0]
120 return matches[0]
121 else:
121 else:
122 # get most recent match, by access time:
122 # get most recent match, by access time:
123 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
123 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
124
124
125 def get_connection_info(connection_file=None, unpack=False, profile=None):
125 def get_connection_info(connection_file=None, unpack=False, profile=None):
126 """Return the connection information for the current Kernel.
126 """Return the connection information for the current Kernel.
127
127
128 Parameters
128 Parameters
129 ----------
129 ----------
130 connection_file : str [optional]
130 connection_file : str [optional]
131 The connection file to be used. Can be given by absolute path, or
131 The connection file to be used. Can be given by absolute path, or
132 IPython will search in the security directory of a given profile.
132 IPython will search in the security directory of a given profile.
133 If run from IPython,
133 If run from IPython,
134
134
135 If unspecified, the connection file for the currently running
135 If unspecified, the connection file for the currently running
136 IPython Kernel will be used, which is only allowed from inside a kernel.
136 IPython Kernel will be used, which is only allowed from inside a kernel.
137 unpack : bool [default: False]
137 unpack : bool [default: False]
138 if True, return the unpacked dict, otherwise just the string contents
138 if True, return the unpacked dict, otherwise just the string contents
139 of the file.
139 of the file.
140 profile : str [optional]
140 profile : str [optional]
141 The name of the profile to use when searching for the connection file,
141 The name of the profile to use when searching for the connection file,
142 if different from the current IPython session or 'default'.
142 if different from the current IPython session or 'default'.
143
143
144
144
145 Returns
145 Returns
146 -------
146 -------
147 The connection dictionary of the current kernel, as string or dict,
147 The connection dictionary of the current kernel, as string or dict,
148 depending on `unpack`.
148 depending on `unpack`.
149 """
149 """
150 if connection_file is None:
150 if connection_file is None:
151 # get connection file from current kernel
151 # get connection file from current kernel
152 cf = get_connection_file()
152 cf = get_connection_file()
153 else:
153 else:
154 # connection file specified, allow shortnames:
154 # connection file specified, allow shortnames:
155 cf = find_connection_file(connection_file, profile=profile)
155 cf = find_connection_file(connection_file, profile=profile)
156
156
157 with open(cf) as f:
157 with open(cf) as f:
158 info = f.read()
158 info = f.read()
159
159
160 if unpack:
160 if unpack:
161 info = json.loads(info)
161 info = json.loads(info)
162 # ensure key is bytes:
162 # ensure key is bytes:
163 info['key'] = str_to_bytes(info.get('key', ''))
163 info['key'] = str_to_bytes(info.get('key', ''))
164 return info
164 return info
165
165
166 def connect_qtconsole(connection_file=None, argv=None, profile=None):
166 def connect_qtconsole(connection_file=None, argv=None, profile=None):
167 """Connect a qtconsole to the current kernel.
167 """Connect a qtconsole to the current kernel.
168
168
169 This is useful for connecting a second qtconsole to a kernel, or to a
169 This is useful for connecting a second qtconsole to a kernel, or to a
170 local notebook.
170 local notebook.
171
171
172 Parameters
172 Parameters
173 ----------
173 ----------
174 connection_file : str [optional]
174 connection_file : str [optional]
175 The connection file to be used. Can be given by absolute path, or
175 The connection file to be used. Can be given by absolute path, or
176 IPython will search in the security directory of a given profile.
176 IPython will search in the security directory of a given profile.
177 If run from IPython,
177 If run from IPython,
178
178
179 If unspecified, the connection file for the currently running
179 If unspecified, the connection file for the currently running
180 IPython Kernel will be used, which is only allowed from inside a kernel.
180 IPython Kernel will be used, which is only allowed from inside a kernel.
181 argv : list [optional]
181 argv : list [optional]
182 Any extra args to be passed to the console.
182 Any extra args to be passed to the console.
183 profile : str [optional]
183 profile : str [optional]
184 The name of the profile to use when searching for the connection file,
184 The name of the profile to use when searching for the connection file,
185 if different from the current IPython session or 'default'.
185 if different from the current IPython session or 'default'.
186
186
187
187
188 Returns
188 Returns
189 -------
189 -------
190 subprocess.Popen instance running the qtconsole frontend
190 subprocess.Popen instance running the qtconsole frontend
191 """
191 """
192 argv = [] if argv is None else argv
192 argv = [] if argv is None else argv
193
193
194 if connection_file is None:
194 if connection_file is None:
195 # get connection file from current kernel
195 # get connection file from current kernel
196 cf = get_connection_file()
196 cf = get_connection_file()
197 else:
197 else:
198 cf = find_connection_file(connection_file, profile=profile)
198 cf = find_connection_file(connection_file, profile=profile)
199
199
200 cmd = ';'.join([
200 cmd = ';'.join([
201 "from IPython.frontend.qt.console import qtconsoleapp",
201 "from IPython.frontend.qt.console import qtconsoleapp",
202 "qtconsoleapp.main()"
202 "qtconsoleapp.main()"
203 ])
203 ])
204
204
205 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv, stdout=PIPE, stderr=PIPE)
205 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv, stdout=PIPE, stderr=PIPE)
206
206
207 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
207 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
208 """tunnel connections to a kernel via ssh
208 """tunnel connections to a kernel via ssh
209
209
210 This will open four SSH tunnels from localhost on this machine to the
210 This will open four SSH tunnels from localhost on this machine to the
211 ports associated with the kernel. They can be either direct
211 ports associated with the kernel. They can be either direct
212 localhost-localhost tunnels, or if an intermediate server is necessary,
212 localhost-localhost tunnels, or if an intermediate server is necessary,
213 the kernel must be listening on a public IP.
213 the kernel must be listening on a public IP.
214
214
215 Parameters
215 Parameters
216 ----------
216 ----------
217 connection_info : dict or str (path)
217 connection_info : dict or str (path)
218 Either a connection dict, or the path to a JSON connection file
218 Either a connection dict, or the path to a JSON connection file
219 sshserver : str
219 sshserver : str
220 The ssh sever to use to tunnel to the kernel. Can be a full
220 The ssh sever to use to tunnel to the kernel. Can be a full
221 `user@server:port` string. ssh config aliases are respected.
221 `user@server:port` string. ssh config aliases are respected.
222 sshkey : str [optional]
222 sshkey : str [optional]
223 Path to file containing ssh key to use for authentication.
223 Path to file containing ssh key to use for authentication.
224 Only necessary if your ssh config does not already associate
224 Only necessary if your ssh config does not already associate
225 a keyfile with the host.
225 a keyfile with the host.
226
226
227 Returns
227 Returns
228 -------
228 -------
229
229
230 (shell, iopub, stdin, hb) : ints
230 (shell, iopub, stdin, hb) : ints
231 The four ports on localhost that have been forwarded to the kernel.
231 The four ports on localhost that have been forwarded to the kernel.
232 """
232 """
233 if isinstance(connection_info, basestring):
233 if isinstance(connection_info, basestring):
234 # it's a path, unpack it
234 # it's a path, unpack it
235 with open(connection_info) as f:
235 with open(connection_info) as f:
236 connection_info = json.loads(f.read())
236 connection_info = json.loads(f.read())
237
237
238 cf = connection_info
238 cf = connection_info
239
239
240 lports = tunnel.select_random_ports(4)
240 lports = tunnel.select_random_ports(4)
241 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
241 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
242
242
243 remote_ip = cf['ip']
243 remote_ip = cf['ip']
244
244
245 if tunnel.try_passwordless_ssh(sshserver, sshkey):
245 if tunnel.try_passwordless_ssh(sshserver, sshkey):
246 password=False
246 password=False
247 else:
247 else:
248 password = getpass("SSH Password for %s: "%sshserver)
248 password = getpass("SSH Password for %s: "%sshserver)
249
249
250 for lp,rp in zip(lports, rports):
250 for lp,rp in zip(lports, rports):
251 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
251 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
252
252
253 return tuple(lports)
253 return tuple(lports)
254
254
255
255
256 def swallow_argv(argv, aliases=None, flags=None):
256 def swallow_argv(argv, aliases=None, flags=None):
257 """strip frontend-specific aliases and flags from an argument list
257 """strip frontend-specific aliases and flags from an argument list
258
258
259 For use primarily in frontend apps that want to pass a subset of command-line
259 For use primarily in frontend apps that want to pass a subset of command-line
260 arguments through to a subprocess, where frontend-specific flags and aliases
260 arguments through to a subprocess, where frontend-specific flags and aliases
261 should be removed from the list.
261 should be removed from the list.
262
262
263 Parameters
263 Parameters
264 ----------
264 ----------
265
265
266 argv : list(str)
266 argv : list(str)
267 The starting argv, to be filtered
267 The starting argv, to be filtered
268 aliases : container of aliases (dict, list, set, etc.)
268 aliases : container of aliases (dict, list, set, etc.)
269 The frontend-specific aliases to be removed
269 The frontend-specific aliases to be removed
270 flags : container of flags (dict, list, set, etc.)
270 flags : container of flags (dict, list, set, etc.)
271 The frontend-specific flags to be removed
271 The frontend-specific flags to be removed
272
272
273 Returns
273 Returns
274 -------
274 -------
275
275
276 argv : list(str)
276 argv : list(str)
277 The argv list, excluding flags and aliases that have been stripped
277 The argv list, excluding flags and aliases that have been stripped
278 """
278 """
279
279
280 if aliases is None:
280 if aliases is None:
281 aliases = set()
281 aliases = set()
282 if flags is None:
282 if flags is None:
283 flags = set()
283 flags = set()
284
284
285 stripped = list(argv) # copy
285 stripped = list(argv) # copy
286
286
287 swallow_next = False
287 swallow_next = False
288 was_flag = False
288 was_flag = False
289 for a in argv:
289 for a in argv:
290 if swallow_next:
290 if swallow_next:
291 swallow_next = False
291 swallow_next = False
292 # last arg was an alias, remove the next one
292 # last arg was an alias, remove the next one
293 # *unless* the last alias has a no-arg flag version, in which
293 # *unless* the last alias has a no-arg flag version, in which
294 # case, don't swallow the next arg if it's also a flag:
294 # case, don't swallow the next arg if it's also a flag:
295 if not (was_flag and a.startswith('-')):
295 if not (was_flag and a.startswith('-')):
296 stripped.remove(a)
296 stripped.remove(a)
297 continue
297 continue
298 if a.startswith('-'):
298 if a.startswith('-'):
299 split = a.lstrip('-').split('=')
299 split = a.lstrip('-').split('=')
300 alias = split[0]
300 alias = split[0]
301 if alias in aliases:
301 if alias in aliases:
302 stripped.remove(a)
302 stripped.remove(a)
303 if len(split) == 1:
303 if len(split) == 1:
304 # alias passed with arg via space
304 # alias passed with arg via space
305 swallow_next = True
305 swallow_next = True
306 # could have been a flag that matches an alias, e.g. `existing`
306 # could have been a flag that matches an alias, e.g. `existing`
307 # in which case, we might not swallow the next arg
307 # in which case, we might not swallow the next arg
308 was_flag = alias in flags
308 was_flag = alias in flags
309 elif alias in flags and len(split) == 1:
309 elif alias in flags and len(split) == 1:
310 # strip flag, but don't swallow next, as flags don't take args
310 # strip flag, but don't swallow next, as flags don't take args
311 stripped.remove(a)
311 stripped.remove(a)
312
312
313 # return shortened list
313 # return shortened list
314 return stripped
314 return stripped
315
315
@@ -1,40 +1,62 b''
1 """The IPython ZMQ-based parallel computing interface.
1 """The IPython ZMQ-based parallel computing interface.
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import warnings
19 import warnings
20
20
21 import zmq
21 import zmq
22
22
23 from IPython.zmq import check_for_zmq
23 from IPython.zmq import check_for_zmq
24
24
25 if os.name == 'nt':
25 if os.name == 'nt':
26 min_pyzmq = '2.1.7'
26 min_pyzmq = '2.1.7'
27 else:
27 else:
28 min_pyzmq = '2.1.4'
28 min_pyzmq = '2.1.4'
29
29
30 check_for_zmq(min_pyzmq, 'IPython.parallel')
30 check_for_zmq(min_pyzmq, 'IPython.parallel')
31
31
32 from IPython.utils.pickleutil import Reference
32 from IPython.utils.pickleutil import Reference
33
33
34 from .client.asyncresult import *
34 from .client.asyncresult import *
35 from .client.client import Client
35 from .client.client import Client
36 from .client.remotefunction import *
36 from .client.remotefunction import *
37 from .client.view import *
37 from .client.view import *
38 from .util import interactive
38 from .controller.dependency import *
39 from .controller.dependency import *
39
40
41 #-----------------------------------------------------------------------------
42 # Functions
43 #-----------------------------------------------------------------------------
44
45
46 def bind_kernel(**kwargs):
47 """Bind an Engine's Kernel to be used as a full IPython kernel.
48
49 This allows a running Engine to be used simultaneously as a full IPython kernel
50 with the QtConsole or other frontends.
51
52 This function returns immediately.
53 """
54 from IPython.parallel.apps.ipengineapp import IPEngineApp
55 if IPEngineApp.initialized():
56 app = IPEngineApp.instance()
57 else:
58 raise RuntimeError("Must be called from an IPEngineApp instance")
59
60 return app.bind_kernel(**kwargs)
61
40
62
@@ -1,265 +1,272 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 The Base Application class for IPython.parallel apps
3 The Base Application class for IPython.parallel apps
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * Min RK
8 * Min RK
9
9
10 """
10 """
11
11
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (C) 2008-2011 The IPython Development Team
13 # Copyright (C) 2008-2011 The IPython Development Team
14 #
14 #
15 # Distributed under the terms of the BSD License. The full license is in
15 # Distributed under the terms of the BSD License. The full license is in
16 # the file COPYING, distributed as part of this software.
16 # the file COPYING, distributed as part of this software.
17 #-----------------------------------------------------------------------------
17 #-----------------------------------------------------------------------------
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 from __future__ import with_statement
23 from __future__ import with_statement
24
24
25 import os
25 import os
26 import logging
26 import logging
27 import re
27 import re
28 import sys
28 import sys
29
29
30 from subprocess import Popen, PIPE
30 from subprocess import Popen, PIPE
31
31
32 from IPython.config.application import catch_config_error
32 from IPython.config.application import catch_config_error
33 from IPython.core import release
33 from IPython.core import release
34 from IPython.core.crashhandler import CrashHandler
34 from IPython.core.crashhandler import CrashHandler
35 from IPython.core.application import (
35 from IPython.core.application import (
36 BaseIPythonApplication,
36 BaseIPythonApplication,
37 base_aliases as base_ip_aliases,
37 base_aliases as base_ip_aliases,
38 base_flags as base_ip_flags
38 base_flags as base_ip_flags
39 )
39 )
40 from IPython.utils.path import expand_path
40 from IPython.utils.path import expand_path
41
41
42 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
42 from IPython.utils.traitlets import Unicode, Bool, Instance, Dict, List
43
43
44 #-----------------------------------------------------------------------------
44 #-----------------------------------------------------------------------------
45 # Module errors
45 # Module errors
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47
47
48 class PIDFileError(Exception):
48 class PIDFileError(Exception):
49 pass
49 pass
50
50
51
51
52 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
53 # Crash handler for this application
53 # Crash handler for this application
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55
55
56 class ParallelCrashHandler(CrashHandler):
56 class ParallelCrashHandler(CrashHandler):
57 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
57 """sys.excepthook for IPython itself, leaves a detailed report on disk."""
58
58
59 def __init__(self, app):
59 def __init__(self, app):
60 contact_name = release.authors['Min'][0]
60 contact_name = release.authors['Min'][0]
61 contact_email = release.author_email
61 contact_email = release.author_email
62 bug_tracker = 'https://github.com/ipython/ipython/issues'
62 bug_tracker = 'https://github.com/ipython/ipython/issues'
63 super(ParallelCrashHandler,self).__init__(
63 super(ParallelCrashHandler,self).__init__(
64 app, contact_name, contact_email, bug_tracker
64 app, contact_name, contact_email, bug_tracker
65 )
65 )
66
66
67
67
68 #-----------------------------------------------------------------------------
68 #-----------------------------------------------------------------------------
69 # Main application
69 # Main application
70 #-----------------------------------------------------------------------------
70 #-----------------------------------------------------------------------------
71 base_aliases = {}
71 base_aliases = {}
72 base_aliases.update(base_ip_aliases)
72 base_aliases.update(base_ip_aliases)
73 base_aliases.update({
73 base_aliases.update({
74 'profile-dir' : 'ProfileDir.location',
74 'profile-dir' : 'ProfileDir.location',
75 'work-dir' : 'BaseParallelApplication.work_dir',
75 'work-dir' : 'BaseParallelApplication.work_dir',
76 'log-to-file' : 'BaseParallelApplication.log_to_file',
76 'log-to-file' : 'BaseParallelApplication.log_to_file',
77 'clean-logs' : 'BaseParallelApplication.clean_logs',
77 'clean-logs' : 'BaseParallelApplication.clean_logs',
78 'log-url' : 'BaseParallelApplication.log_url',
78 'log-url' : 'BaseParallelApplication.log_url',
79 'cluster-id' : 'BaseParallelApplication.cluster_id',
79 'cluster-id' : 'BaseParallelApplication.cluster_id',
80 })
80 })
81
81
82 base_flags = {
82 base_flags = {
83 'log-to-file' : (
83 'log-to-file' : (
84 {'BaseParallelApplication' : {'log_to_file' : True}},
84 {'BaseParallelApplication' : {'log_to_file' : True}},
85 "send log output to a file"
85 "send log output to a file"
86 )
86 )
87 }
87 }
88 base_flags.update(base_ip_flags)
88 base_flags.update(base_ip_flags)
89
89
90 class BaseParallelApplication(BaseIPythonApplication):
90 class BaseParallelApplication(BaseIPythonApplication):
91 """The base Application for IPython.parallel apps
91 """The base Application for IPython.parallel apps
92
92
93 Principle extensions to BaseIPyythonApplication:
93 Principle extensions to BaseIPyythonApplication:
94
94
95 * work_dir
95 * work_dir
96 * remote logging via pyzmq
96 * remote logging via pyzmq
97 * IOLoop instance
97 * IOLoop instance
98 """
98 """
99
99
100 crash_handler_class = ParallelCrashHandler
100 crash_handler_class = ParallelCrashHandler
101
101
102 def _log_level_default(self):
102 def _log_level_default(self):
103 # temporarily override default_log_level to INFO
103 # temporarily override default_log_level to INFO
104 return logging.INFO
104 return logging.INFO
105
106 def _log_format_default(self):
107 """override default log format to include time"""
108 return u"%(asctime)s.%(msecs).03d [%(name)s] %(message)s"
105
109
106 work_dir = Unicode(os.getcwdu(), config=True,
110 work_dir = Unicode(os.getcwdu(), config=True,
107 help='Set the working dir for the process.'
111 help='Set the working dir for the process.'
108 )
112 )
109 def _work_dir_changed(self, name, old, new):
113 def _work_dir_changed(self, name, old, new):
110 self.work_dir = unicode(expand_path(new))
114 self.work_dir = unicode(expand_path(new))
111
115
112 log_to_file = Bool(config=True,
116 log_to_file = Bool(config=True,
113 help="whether to log to a file")
117 help="whether to log to a file")
114
118
115 clean_logs = Bool(False, config=True,
119 clean_logs = Bool(False, config=True,
116 help="whether to cleanup old logfiles before starting")
120 help="whether to cleanup old logfiles before starting")
117
121
118 log_url = Unicode('', config=True,
122 log_url = Unicode('', config=True,
119 help="The ZMQ URL of the iplogger to aggregate logging.")
123 help="The ZMQ URL of the iplogger to aggregate logging.")
120
124
121 cluster_id = Unicode('', config=True,
125 cluster_id = Unicode('', config=True,
122 help="""String id to add to runtime files, to prevent name collisions when
126 help="""String id to add to runtime files, to prevent name collisions when
123 using multiple clusters with a single profile simultaneously.
127 using multiple clusters with a single profile simultaneously.
124
128
125 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
129 When set, files will be named like: 'ipcontroller-<cluster_id>-engine.json'
126
130
127 Since this is text inserted into filenames, typical recommendations apply:
131 Since this is text inserted into filenames, typical recommendations apply:
128 Simple character strings are ideal, and spaces are not recommended (but should
132 Simple character strings are ideal, and spaces are not recommended (but should
129 generally work).
133 generally work).
130 """
134 """
131 )
135 )
132 def _cluster_id_changed(self, name, old, new):
136 def _cluster_id_changed(self, name, old, new):
133 self.name = self.__class__.name
137 self.name = self.__class__.name
134 if new:
138 if new:
135 self.name += '-%s'%new
139 self.name += '-%s'%new
136
140
137 def _config_files_default(self):
141 def _config_files_default(self):
138 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
142 return ['ipcontroller_config.py', 'ipengine_config.py', 'ipcluster_config.py']
139
143
140 loop = Instance('zmq.eventloop.ioloop.IOLoop')
144 loop = Instance('zmq.eventloop.ioloop.IOLoop')
141 def _loop_default(self):
145 def _loop_default(self):
142 from zmq.eventloop.ioloop import IOLoop
146 from zmq.eventloop.ioloop import IOLoop
143 return IOLoop.instance()
147 return IOLoop.instance()
144
148
145 aliases = Dict(base_aliases)
149 aliases = Dict(base_aliases)
146 flags = Dict(base_flags)
150 flags = Dict(base_flags)
147
151
148 @catch_config_error
152 @catch_config_error
149 def initialize(self, argv=None):
153 def initialize(self, argv=None):
150 """initialize the app"""
154 """initialize the app"""
151 super(BaseParallelApplication, self).initialize(argv)
155 super(BaseParallelApplication, self).initialize(argv)
152 self.to_work_dir()
156 self.to_work_dir()
153 self.reinit_logging()
157 self.reinit_logging()
154
158
155 def to_work_dir(self):
159 def to_work_dir(self):
156 wd = self.work_dir
160 wd = self.work_dir
157 if unicode(wd) != os.getcwdu():
161 if unicode(wd) != os.getcwdu():
158 os.chdir(wd)
162 os.chdir(wd)
159 self.log.info("Changing to working dir: %s" % wd)
163 self.log.info("Changing to working dir: %s" % wd)
160 # This is the working dir by now.
164 # This is the working dir by now.
161 sys.path.insert(0, '')
165 sys.path.insert(0, '')
162
166
163 def reinit_logging(self):
167 def reinit_logging(self):
164 # Remove old log files
168 # Remove old log files
165 log_dir = self.profile_dir.log_dir
169 log_dir = self.profile_dir.log_dir
166 if self.clean_logs:
170 if self.clean_logs:
167 for f in os.listdir(log_dir):
171 for f in os.listdir(log_dir):
168 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
172 if re.match(r'%s-\d+\.(log|err|out)'%self.name,f):
169 os.remove(os.path.join(log_dir, f))
173 os.remove(os.path.join(log_dir, f))
170 if self.log_to_file:
174 if self.log_to_file:
171 # Start logging to the new log file
175 # Start logging to the new log file
172 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
176 log_filename = self.name + u'-' + str(os.getpid()) + u'.log'
173 logfile = os.path.join(log_dir, log_filename)
177 logfile = os.path.join(log_dir, log_filename)
174 open_log_file = open(logfile, 'w')
178 open_log_file = open(logfile, 'w')
175 else:
179 else:
176 open_log_file = None
180 open_log_file = None
177 if open_log_file is not None:
181 if open_log_file is not None:
178 self.log.removeHandler(self._log_handler)
182 while self.log.handlers:
183 self.log.removeHandler(self.log.handlers[0])
179 self._log_handler = logging.StreamHandler(open_log_file)
184 self._log_handler = logging.StreamHandler(open_log_file)
180 self.log.addHandler(self._log_handler)
185 self.log.addHandler(self._log_handler)
186 else:
187 self._log_handler = self.log.handlers[0]
181 # Add timestamps to log format:
188 # Add timestamps to log format:
182 self._log_formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
189 self._log_formatter = logging.Formatter(self.log_format,
183 datefmt="%Y-%m-%d %H:%M:%S")
190 datefmt="%Y-%m-%d %H:%M:%S")
184 self._log_handler.setFormatter(self._log_formatter)
191 self._log_handler.setFormatter(self._log_formatter)
185 # do not propagate log messages to root logger
192 # do not propagate log messages to root logger
186 # ipcluster app will sometimes print duplicate messages during shutdown
193 # ipcluster app will sometimes print duplicate messages during shutdown
187 # if this is 1 (default):
194 # if this is 1 (default):
188 self.log.propagate = False
195 self.log.propagate = False
189
196
190 def write_pid_file(self, overwrite=False):
197 def write_pid_file(self, overwrite=False):
191 """Create a .pid file in the pid_dir with my pid.
198 """Create a .pid file in the pid_dir with my pid.
192
199
193 This must be called after pre_construct, which sets `self.pid_dir`.
200 This must be called after pre_construct, which sets `self.pid_dir`.
194 This raises :exc:`PIDFileError` if the pid file exists already.
201 This raises :exc:`PIDFileError` if the pid file exists already.
195 """
202 """
196 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
203 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
197 if os.path.isfile(pid_file):
204 if os.path.isfile(pid_file):
198 pid = self.get_pid_from_file()
205 pid = self.get_pid_from_file()
199 if not overwrite:
206 if not overwrite:
200 raise PIDFileError(
207 raise PIDFileError(
201 'The pid file [%s] already exists. \nThis could mean that this '
208 'The pid file [%s] already exists. \nThis could mean that this '
202 'server is already running with [pid=%s].' % (pid_file, pid)
209 'server is already running with [pid=%s].' % (pid_file, pid)
203 )
210 )
204 with open(pid_file, 'w') as f:
211 with open(pid_file, 'w') as f:
205 self.log.info("Creating pid file: %s" % pid_file)
212 self.log.info("Creating pid file: %s" % pid_file)
206 f.write(repr(os.getpid())+'\n')
213 f.write(repr(os.getpid())+'\n')
207
214
208 def remove_pid_file(self):
215 def remove_pid_file(self):
209 """Remove the pid file.
216 """Remove the pid file.
210
217
211 This should be called at shutdown by registering a callback with
218 This should be called at shutdown by registering a callback with
212 :func:`reactor.addSystemEventTrigger`. This needs to return
219 :func:`reactor.addSystemEventTrigger`. This needs to return
213 ``None``.
220 ``None``.
214 """
221 """
215 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
222 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
216 if os.path.isfile(pid_file):
223 if os.path.isfile(pid_file):
217 try:
224 try:
218 self.log.info("Removing pid file: %s" % pid_file)
225 self.log.info("Removing pid file: %s" % pid_file)
219 os.remove(pid_file)
226 os.remove(pid_file)
220 except:
227 except:
221 self.log.warn("Error removing the pid file: %s" % pid_file)
228 self.log.warn("Error removing the pid file: %s" % pid_file)
222
229
223 def get_pid_from_file(self):
230 def get_pid_from_file(self):
224 """Get the pid from the pid file.
231 """Get the pid from the pid file.
225
232
226 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
233 If the pid file doesn't exist a :exc:`PIDFileError` is raised.
227 """
234 """
228 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
235 pid_file = os.path.join(self.profile_dir.pid_dir, self.name + u'.pid')
229 if os.path.isfile(pid_file):
236 if os.path.isfile(pid_file):
230 with open(pid_file, 'r') as f:
237 with open(pid_file, 'r') as f:
231 s = f.read().strip()
238 s = f.read().strip()
232 try:
239 try:
233 pid = int(s)
240 pid = int(s)
234 except:
241 except:
235 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
242 raise PIDFileError("invalid pid file: %s (contents: %r)"%(pid_file, s))
236 return pid
243 return pid
237 else:
244 else:
238 raise PIDFileError('pid file not found: %s' % pid_file)
245 raise PIDFileError('pid file not found: %s' % pid_file)
239
246
240 def check_pid(self, pid):
247 def check_pid(self, pid):
241 if os.name == 'nt':
248 if os.name == 'nt':
242 try:
249 try:
243 import ctypes
250 import ctypes
244 # returns 0 if no such process (of ours) exists
251 # returns 0 if no such process (of ours) exists
245 # positive int otherwise
252 # positive int otherwise
246 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
253 p = ctypes.windll.kernel32.OpenProcess(1,0,pid)
247 except Exception:
254 except Exception:
248 self.log.warn(
255 self.log.warn(
249 "Could not determine whether pid %i is running via `OpenProcess`. "
256 "Could not determine whether pid %i is running via `OpenProcess`. "
250 " Making the likely assumption that it is."%pid
257 " Making the likely assumption that it is."%pid
251 )
258 )
252 return True
259 return True
253 return bool(p)
260 return bool(p)
254 else:
261 else:
255 try:
262 try:
256 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
263 p = Popen(['ps','x'], stdout=PIPE, stderr=PIPE)
257 output,_ = p.communicate()
264 output,_ = p.communicate()
258 except OSError:
265 except OSError:
259 self.log.warn(
266 self.log.warn(
260 "Could not determine whether pid %i is running via `ps x`. "
267 "Could not determine whether pid %i is running via `ps x`. "
261 " Making the likely assumption that it is."%pid
268 " Making the likely assumption that it is."%pid
262 )
269 )
263 return True
270 return True
264 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
271 pids = map(int, re.findall(r'^\W*\d+', output, re.MULTILINE))
265 return pid in pids
272 return pid in pids
@@ -1,493 +1,491 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython controller application.
4 The IPython controller application.
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * MinRK
9 * MinRK
10
10
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 from __future__ import with_statement
24 from __future__ import with_statement
25
25
26 import json
26 import json
27 import os
27 import os
28 import socket
28 import socket
29 import stat
29 import stat
30 import sys
30 import sys
31
31
32 from multiprocessing import Process
32 from multiprocessing import Process
33 from signal import signal, SIGINT, SIGABRT, SIGTERM
33 from signal import signal, SIGINT, SIGABRT, SIGTERM
34
34
35 import zmq
35 import zmq
36 from zmq.devices import ProcessMonitoredQueue
36 from zmq.devices import ProcessMonitoredQueue
37 from zmq.log.handlers import PUBHandler
37 from zmq.log.handlers import PUBHandler
38
38
39 from IPython.core.profiledir import ProfileDir
39 from IPython.core.profiledir import ProfileDir
40
40
41 from IPython.parallel.apps.baseapp import (
41 from IPython.parallel.apps.baseapp import (
42 BaseParallelApplication,
42 BaseParallelApplication,
43 base_aliases,
43 base_aliases,
44 base_flags,
44 base_flags,
45 catch_config_error,
45 catch_config_error,
46 )
46 )
47 from IPython.utils.importstring import import_item
47 from IPython.utils.importstring import import_item
48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
48 from IPython.utils.traitlets import Instance, Unicode, Bool, List, Dict, TraitError
49
49
50 from IPython.zmq.session import (
50 from IPython.zmq.session import (
51 Session, session_aliases, session_flags, default_secure
51 Session, session_aliases, session_flags, default_secure
52 )
52 )
53
53
54 from IPython.parallel.controller.heartmonitor import HeartMonitor
54 from IPython.parallel.controller.heartmonitor import HeartMonitor
55 from IPython.parallel.controller.hub import HubFactory
55 from IPython.parallel.controller.hub import HubFactory
56 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
56 from IPython.parallel.controller.scheduler import TaskScheduler,launch_scheduler
57 from IPython.parallel.controller.sqlitedb import SQLiteDB
57 from IPython.parallel.controller.sqlitedb import SQLiteDB
58
58
59 from IPython.parallel.util import split_url, disambiguate_url
59 from IPython.parallel.util import split_url, disambiguate_url
60
60
61 # conditional import of MongoDB backend class
61 # conditional import of MongoDB backend class
62
62
63 try:
63 try:
64 from IPython.parallel.controller.mongodb import MongoDB
64 from IPython.parallel.controller.mongodb import MongoDB
65 except ImportError:
65 except ImportError:
66 maybe_mongo = []
66 maybe_mongo = []
67 else:
67 else:
68 maybe_mongo = [MongoDB]
68 maybe_mongo = [MongoDB]
69
69
70
70
71 #-----------------------------------------------------------------------------
71 #-----------------------------------------------------------------------------
72 # Module level variables
72 # Module level variables
73 #-----------------------------------------------------------------------------
73 #-----------------------------------------------------------------------------
74
74
75
75
76 #: The default config file name for this application
76 #: The default config file name for this application
77 default_config_file_name = u'ipcontroller_config.py'
77 default_config_file_name = u'ipcontroller_config.py'
78
78
79
79
80 _description = """Start the IPython controller for parallel computing.
80 _description = """Start the IPython controller for parallel computing.
81
81
82 The IPython controller provides a gateway between the IPython engines and
82 The IPython controller provides a gateway between the IPython engines and
83 clients. The controller needs to be started before the engines and can be
83 clients. The controller needs to be started before the engines and can be
84 configured using command line options or using a cluster directory. Cluster
84 configured using command line options or using a cluster directory. Cluster
85 directories contain config, log and security files and are usually located in
85 directories contain config, log and security files and are usually located in
86 your ipython directory and named as "profile_name". See the `profile`
86 your ipython directory and named as "profile_name". See the `profile`
87 and `profile-dir` options for details.
87 and `profile-dir` options for details.
88 """
88 """
89
89
90 _examples = """
90 _examples = """
91 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
91 ipcontroller --ip=192.168.0.1 --port=1000 # listen on ip, port for engines
92 ipcontroller --scheme=pure # use the pure zeromq scheduler
92 ipcontroller --scheme=pure # use the pure zeromq scheduler
93 """
93 """
94
94
95
95
96 #-----------------------------------------------------------------------------
96 #-----------------------------------------------------------------------------
97 # The main application
97 # The main application
98 #-----------------------------------------------------------------------------
98 #-----------------------------------------------------------------------------
99 flags = {}
99 flags = {}
100 flags.update(base_flags)
100 flags.update(base_flags)
101 flags.update({
101 flags.update({
102 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
102 'usethreads' : ( {'IPControllerApp' : {'use_threads' : True}},
103 'Use threads instead of processes for the schedulers'),
103 'Use threads instead of processes for the schedulers'),
104 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
104 'sqlitedb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.sqlitedb.SQLiteDB'}},
105 'use the SQLiteDB backend'),
105 'use the SQLiteDB backend'),
106 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
106 'mongodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.mongodb.MongoDB'}},
107 'use the MongoDB backend'),
107 'use the MongoDB backend'),
108 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
108 'dictdb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.DictDB'}},
109 'use the in-memory DictDB backend'),
109 'use the in-memory DictDB backend'),
110 'nodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.NoDB'}},
110 'nodb' : ({'HubFactory' : {'db_class' : 'IPython.parallel.controller.dictdb.NoDB'}},
111 """use dummy DB backend, which doesn't store any information.
111 """use dummy DB backend, which doesn't store any information.
112
112
113 This can be used to prevent growth of the memory footprint of the Hub
113 This can be used to prevent growth of the memory footprint of the Hub
114 in cases where its record-keeping is not required. Requesting results
114 in cases where its record-keeping is not required. Requesting results
115 of tasks submitted by other clients, db_queries, and task resubmission
115 of tasks submitted by other clients, db_queries, and task resubmission
116 will not be available."""),
116 will not be available."""),
117 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
117 'reuse' : ({'IPControllerApp' : {'reuse_files' : True}},
118 'reuse existing json connection files')
118 'reuse existing json connection files')
119 })
119 })
120
120
121 flags.update(session_flags)
121 flags.update(session_flags)
122
122
123 aliases = dict(
123 aliases = dict(
124 ssh = 'IPControllerApp.ssh_server',
124 ssh = 'IPControllerApp.ssh_server',
125 enginessh = 'IPControllerApp.engine_ssh_server',
125 enginessh = 'IPControllerApp.engine_ssh_server',
126 location = 'IPControllerApp.location',
126 location = 'IPControllerApp.location',
127
127
128 url = 'HubFactory.url',
128 url = 'HubFactory.url',
129 ip = 'HubFactory.ip',
129 ip = 'HubFactory.ip',
130 transport = 'HubFactory.transport',
130 transport = 'HubFactory.transport',
131 port = 'HubFactory.regport',
131 port = 'HubFactory.regport',
132
132
133 ping = 'HeartMonitor.period',
133 ping = 'HeartMonitor.period',
134
134
135 scheme = 'TaskScheduler.scheme_name',
135 scheme = 'TaskScheduler.scheme_name',
136 hwm = 'TaskScheduler.hwm',
136 hwm = 'TaskScheduler.hwm',
137 )
137 )
138 aliases.update(base_aliases)
138 aliases.update(base_aliases)
139 aliases.update(session_aliases)
139 aliases.update(session_aliases)
140
140
141
141
142 class IPControllerApp(BaseParallelApplication):
142 class IPControllerApp(BaseParallelApplication):
143
143
144 name = u'ipcontroller'
144 name = u'ipcontroller'
145 description = _description
145 description = _description
146 examples = _examples
146 examples = _examples
147 config_file_name = Unicode(default_config_file_name)
147 config_file_name = Unicode(default_config_file_name)
148 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
148 classes = [ProfileDir, Session, HubFactory, TaskScheduler, HeartMonitor, SQLiteDB] + maybe_mongo
149
149
150 # change default to True
150 # change default to True
151 auto_create = Bool(True, config=True,
151 auto_create = Bool(True, config=True,
152 help="""Whether to create profile dir if it doesn't exist.""")
152 help="""Whether to create profile dir if it doesn't exist.""")
153
153
154 reuse_files = Bool(False, config=True,
154 reuse_files = Bool(False, config=True,
155 help="""Whether to reuse existing json connection files.
155 help="""Whether to reuse existing json connection files.
156 If False, connection files will be removed on a clean exit.
156 If False, connection files will be removed on a clean exit.
157 """
157 """
158 )
158 )
159 ssh_server = Unicode(u'', config=True,
159 ssh_server = Unicode(u'', config=True,
160 help="""ssh url for clients to use when connecting to the Controller
160 help="""ssh url for clients to use when connecting to the Controller
161 processes. It should be of the form: [user@]server[:port]. The
161 processes. It should be of the form: [user@]server[:port]. The
162 Controller's listening addresses must be accessible from the ssh server""",
162 Controller's listening addresses must be accessible from the ssh server""",
163 )
163 )
164 engine_ssh_server = Unicode(u'', config=True,
164 engine_ssh_server = Unicode(u'', config=True,
165 help="""ssh url for engines to use when connecting to the Controller
165 help="""ssh url for engines to use when connecting to the Controller
166 processes. It should be of the form: [user@]server[:port]. The
166 processes. It should be of the form: [user@]server[:port]. The
167 Controller's listening addresses must be accessible from the ssh server""",
167 Controller's listening addresses must be accessible from the ssh server""",
168 )
168 )
169 location = Unicode(u'', config=True,
169 location = Unicode(u'', config=True,
170 help="""The external IP or domain name of the Controller, used for disambiguating
170 help="""The external IP or domain name of the Controller, used for disambiguating
171 engine and client connections.""",
171 engine and client connections.""",
172 )
172 )
173 import_statements = List([], config=True,
173 import_statements = List([], config=True,
174 help="import statements to be run at startup. Necessary in some environments"
174 help="import statements to be run at startup. Necessary in some environments"
175 )
175 )
176
176
177 use_threads = Bool(False, config=True,
177 use_threads = Bool(False, config=True,
178 help='Use threads instead of processes for the schedulers',
178 help='Use threads instead of processes for the schedulers',
179 )
179 )
180
180
181 engine_json_file = Unicode('ipcontroller-engine.json', config=True,
181 engine_json_file = Unicode('ipcontroller-engine.json', config=True,
182 help="JSON filename where engine connection info will be stored.")
182 help="JSON filename where engine connection info will be stored.")
183 client_json_file = Unicode('ipcontroller-client.json', config=True,
183 client_json_file = Unicode('ipcontroller-client.json', config=True,
184 help="JSON filename where client connection info will be stored.")
184 help="JSON filename where client connection info will be stored.")
185
185
186 def _cluster_id_changed(self, name, old, new):
186 def _cluster_id_changed(self, name, old, new):
187 super(IPControllerApp, self)._cluster_id_changed(name, old, new)
187 super(IPControllerApp, self)._cluster_id_changed(name, old, new)
188 self.engine_json_file = "%s-engine.json" % self.name
188 self.engine_json_file = "%s-engine.json" % self.name
189 self.client_json_file = "%s-client.json" % self.name
189 self.client_json_file = "%s-client.json" % self.name
190
190
191
191
192 # internal
192 # internal
193 children = List()
193 children = List()
194 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
194 mq_class = Unicode('zmq.devices.ProcessMonitoredQueue')
195
195
196 def _use_threads_changed(self, name, old, new):
196 def _use_threads_changed(self, name, old, new):
197 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
197 self.mq_class = 'zmq.devices.%sMonitoredQueue'%('Thread' if new else 'Process')
198
198
199 write_connection_files = Bool(True,
199 write_connection_files = Bool(True,
200 help="""Whether to write connection files to disk.
200 help="""Whether to write connection files to disk.
201 True in all cases other than runs with `reuse_files=True` *after the first*
201 True in all cases other than runs with `reuse_files=True` *after the first*
202 """
202 """
203 )
203 )
204
204
205 aliases = Dict(aliases)
205 aliases = Dict(aliases)
206 flags = Dict(flags)
206 flags = Dict(flags)
207
207
208
208
209 def save_connection_dict(self, fname, cdict):
209 def save_connection_dict(self, fname, cdict):
210 """save a connection dict to json file."""
210 """save a connection dict to json file."""
211 c = self.config
211 c = self.config
212 url = cdict['url']
212 url = cdict['url']
213 location = cdict['location']
213 location = cdict['location']
214 if not location:
214 if not location:
215 try:
215 try:
216 proto,ip,port = split_url(url)
216 proto,ip,port = split_url(url)
217 except AssertionError:
217 except AssertionError:
218 pass
218 pass
219 else:
219 else:
220 try:
220 try:
221 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
221 location = socket.gethostbyname_ex(socket.gethostname())[2][-1]
222 except (socket.gaierror, IndexError):
222 except (socket.gaierror, IndexError):
223 self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
223 self.log.warn("Could not identify this machine's IP, assuming 127.0.0.1."
224 " You may need to specify '--location=<external_ip_address>' to help"
224 " You may need to specify '--location=<external_ip_address>' to help"
225 " IPython decide when to connect via loopback.")
225 " IPython decide when to connect via loopback.")
226 location = '127.0.0.1'
226 location = '127.0.0.1'
227 cdict['location'] = location
227 cdict['location'] = location
228 fname = os.path.join(self.profile_dir.security_dir, fname)
228 fname = os.path.join(self.profile_dir.security_dir, fname)
229 self.log.info("writing connection info to %s", fname)
229 self.log.info("writing connection info to %s", fname)
230 with open(fname, 'w') as f:
230 with open(fname, 'w') as f:
231 f.write(json.dumps(cdict, indent=2))
231 f.write(json.dumps(cdict, indent=2))
232 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
232 os.chmod(fname, stat.S_IRUSR|stat.S_IWUSR)
233
233
234 def load_config_from_json(self):
234 def load_config_from_json(self):
235 """load config from existing json connector files."""
235 """load config from existing json connector files."""
236 c = self.config
236 c = self.config
237 self.log.debug("loading config from JSON")
237 self.log.debug("loading config from JSON")
238 # load from engine config
238 # load from engine config
239 fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
239 fname = os.path.join(self.profile_dir.security_dir, self.engine_json_file)
240 self.log.info("loading connection info from %s", fname)
240 self.log.info("loading connection info from %s", fname)
241 with open(fname) as f:
241 with open(fname) as f:
242 cfg = json.loads(f.read())
242 cfg = json.loads(f.read())
243 key = cfg['exec_key']
243 key = cfg['exec_key']
244 # json gives unicode, Session.key wants bytes
244 # json gives unicode, Session.key wants bytes
245 c.Session.key = key.encode('ascii')
245 c.Session.key = key.encode('ascii')
246 xport,addr = cfg['url'].split('://')
246 xport,addr = cfg['url'].split('://')
247 c.HubFactory.engine_transport = xport
247 c.HubFactory.engine_transport = xport
248 ip,ports = addr.split(':')
248 ip,ports = addr.split(':')
249 c.HubFactory.engine_ip = ip
249 c.HubFactory.engine_ip = ip
250 c.HubFactory.regport = int(ports)
250 c.HubFactory.regport = int(ports)
251 self.location = cfg['location']
251 self.location = cfg['location']
252 if not self.engine_ssh_server:
252 if not self.engine_ssh_server:
253 self.engine_ssh_server = cfg['ssh']
253 self.engine_ssh_server = cfg['ssh']
254 # load client config
254 # load client config
255 fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
255 fname = os.path.join(self.profile_dir.security_dir, self.client_json_file)
256 self.log.info("loading connection info from %s", fname)
256 self.log.info("loading connection info from %s", fname)
257 with open(fname) as f:
257 with open(fname) as f:
258 cfg = json.loads(f.read())
258 cfg = json.loads(f.read())
259 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
259 assert key == cfg['exec_key'], "exec_key mismatch between engine and client keys"
260 xport,addr = cfg['url'].split('://')
260 xport,addr = cfg['url'].split('://')
261 c.HubFactory.client_transport = xport
261 c.HubFactory.client_transport = xport
262 ip,ports = addr.split(':')
262 ip,ports = addr.split(':')
263 c.HubFactory.client_ip = ip
263 c.HubFactory.client_ip = ip
264 if not self.ssh_server:
264 if not self.ssh_server:
265 self.ssh_server = cfg['ssh']
265 self.ssh_server = cfg['ssh']
266 assert int(ports) == c.HubFactory.regport, "regport mismatch"
266 assert int(ports) == c.HubFactory.regport, "regport mismatch"
267
267
268 def cleanup_connection_files(self):
268 def cleanup_connection_files(self):
269 if self.reuse_files:
269 if self.reuse_files:
270 self.log.debug("leaving JSON connection files for reuse")
270 self.log.debug("leaving JSON connection files for reuse")
271 return
271 return
272 self.log.debug("cleaning up JSON connection files")
272 self.log.debug("cleaning up JSON connection files")
273 for f in (self.client_json_file, self.engine_json_file):
273 for f in (self.client_json_file, self.engine_json_file):
274 f = os.path.join(self.profile_dir.security_dir, f)
274 f = os.path.join(self.profile_dir.security_dir, f)
275 try:
275 try:
276 os.remove(f)
276 os.remove(f)
277 except Exception as e:
277 except Exception as e:
278 self.log.error("Failed to cleanup connection file: %s", e)
278 self.log.error("Failed to cleanup connection file: %s", e)
279 else:
279 else:
280 self.log.debug(u"removed %s", f)
280 self.log.debug(u"removed %s", f)
281
281
282 def load_secondary_config(self):
282 def load_secondary_config(self):
283 """secondary config, loading from JSON and setting defaults"""
283 """secondary config, loading from JSON and setting defaults"""
284 if self.reuse_files:
284 if self.reuse_files:
285 try:
285 try:
286 self.load_config_from_json()
286 self.load_config_from_json()
287 except (AssertionError,IOError) as e:
287 except (AssertionError,IOError) as e:
288 self.log.error("Could not load config from JSON: %s" % e)
288 self.log.error("Could not load config from JSON: %s" % e)
289 else:
289 else:
290 # successfully loaded config from JSON, and reuse=True
290 # successfully loaded config from JSON, and reuse=True
291 # no need to wite back the same file
291 # no need to wite back the same file
292 self.write_connection_files = False
292 self.write_connection_files = False
293
293
294 # switch Session.key default to secure
294 # switch Session.key default to secure
295 default_secure(self.config)
295 default_secure(self.config)
296 self.log.debug("Config changed")
296 self.log.debug("Config changed")
297 self.log.debug(repr(self.config))
297 self.log.debug(repr(self.config))
298
298
299 def init_hub(self):
299 def init_hub(self):
300 c = self.config
300 c = self.config
301
301
302 self.do_import_statements()
302 self.do_import_statements()
303
303
304 try:
304 try:
305 self.factory = HubFactory(config=c, log=self.log)
305 self.factory = HubFactory(config=c, log=self.log)
306 # self.start_logging()
306 # self.start_logging()
307 self.factory.init_hub()
307 self.factory.init_hub()
308 except TraitError:
308 except TraitError:
309 raise
309 raise
310 except Exception:
310 except Exception:
311 self.log.error("Couldn't construct the Controller", exc_info=True)
311 self.log.error("Couldn't construct the Controller", exc_info=True)
312 self.exit(1)
312 self.exit(1)
313
313
314 if self.write_connection_files:
314 if self.write_connection_files:
315 # save to new json config files
315 # save to new json config files
316 f = self.factory
316 f = self.factory
317 cdict = {'exec_key' : f.session.key.decode('ascii'),
317 cdict = {'exec_key' : f.session.key.decode('ascii'),
318 'ssh' : self.ssh_server,
318 'ssh' : self.ssh_server,
319 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
319 'url' : "%s://%s:%s"%(f.client_transport, f.client_ip, f.regport),
320 'location' : self.location
320 'location' : self.location
321 }
321 }
322 self.save_connection_dict(self.client_json_file, cdict)
322 self.save_connection_dict(self.client_json_file, cdict)
323 edict = cdict
323 edict = cdict
324 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
324 edict['url']="%s://%s:%s"%((f.client_transport, f.client_ip, f.regport))
325 edict['ssh'] = self.engine_ssh_server
325 edict['ssh'] = self.engine_ssh_server
326 self.save_connection_dict(self.engine_json_file, edict)
326 self.save_connection_dict(self.engine_json_file, edict)
327
327
328 def init_schedulers(self):
328 def init_schedulers(self):
329 children = self.children
329 children = self.children
330 mq = import_item(str(self.mq_class))
330 mq = import_item(str(self.mq_class))
331
331
332 hub = self.factory
332 hub = self.factory
333 # disambiguate url, in case of *
333 # disambiguate url, in case of *
334 monitor_url = disambiguate_url(hub.monitor_url)
334 monitor_url = disambiguate_url(hub.monitor_url)
335 # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
335 # maybe_inproc = 'inproc://monitor' if self.use_threads else monitor_url
336 # IOPub relay (in a Process)
336 # IOPub relay (in a Process)
337 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
337 q = mq(zmq.PUB, zmq.SUB, zmq.PUB, b'N/A',b'iopub')
338 q.bind_in(hub.client_info['iopub'])
338 q.bind_in(hub.client_info['iopub'])
339 q.bind_out(hub.engine_info['iopub'])
339 q.bind_out(hub.engine_info['iopub'])
340 q.setsockopt_out(zmq.SUBSCRIBE, b'')
340 q.setsockopt_out(zmq.SUBSCRIBE, b'')
341 q.connect_mon(monitor_url)
341 q.connect_mon(monitor_url)
342 q.daemon=True
342 q.daemon=True
343 children.append(q)
343 children.append(q)
344
344
345 # Multiplexer Queue (in a Process)
345 # Multiplexer Queue (in a Process)
346 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
346 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'in', b'out')
347 q.bind_in(hub.client_info['mux'])
347 q.bind_in(hub.client_info['mux'])
348 q.setsockopt_in(zmq.IDENTITY, b'mux')
348 q.setsockopt_in(zmq.IDENTITY, b'mux')
349 q.bind_out(hub.engine_info['mux'])
349 q.bind_out(hub.engine_info['mux'])
350 q.connect_mon(monitor_url)
350 q.connect_mon(monitor_url)
351 q.daemon=True
351 q.daemon=True
352 children.append(q)
352 children.append(q)
353
353
354 # Control Queue (in a Process)
354 # Control Queue (in a Process)
355 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
355 q = mq(zmq.ROUTER, zmq.ROUTER, zmq.PUB, b'incontrol', b'outcontrol')
356 q.bind_in(hub.client_info['control'])
356 q.bind_in(hub.client_info['control'])
357 q.setsockopt_in(zmq.IDENTITY, b'control')
357 q.setsockopt_in(zmq.IDENTITY, b'control')
358 q.bind_out(hub.engine_info['control'])
358 q.bind_out(hub.engine_info['control'])
359 q.connect_mon(monitor_url)
359 q.connect_mon(monitor_url)
360 q.daemon=True
360 q.daemon=True
361 children.append(q)
361 children.append(q)
362 try:
362 try:
363 scheme = self.config.TaskScheduler.scheme_name
363 scheme = self.config.TaskScheduler.scheme_name
364 except AttributeError:
364 except AttributeError:
365 scheme = TaskScheduler.scheme_name.get_default_value()
365 scheme = TaskScheduler.scheme_name.get_default_value()
366 # Task Queue (in a Process)
366 # Task Queue (in a Process)
367 if scheme == 'pure':
367 if scheme == 'pure':
368 self.log.warn("task::using pure XREQ Task scheduler")
368 self.log.warn("task::using pure XREQ Task scheduler")
369 q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
369 q = mq(zmq.ROUTER, zmq.DEALER, zmq.PUB, b'intask', b'outtask')
370 # q.setsockopt_out(zmq.HWM, hub.hwm)
370 # q.setsockopt_out(zmq.HWM, hub.hwm)
371 q.bind_in(hub.client_info['task'][1])
371 q.bind_in(hub.client_info['task'][1])
372 q.setsockopt_in(zmq.IDENTITY, b'task')
372 q.setsockopt_in(zmq.IDENTITY, b'task')
373 q.bind_out(hub.engine_info['task'])
373 q.bind_out(hub.engine_info['task'])
374 q.connect_mon(monitor_url)
374 q.connect_mon(monitor_url)
375 q.daemon=True
375 q.daemon=True
376 children.append(q)
376 children.append(q)
377 elif scheme == 'none':
377 elif scheme == 'none':
378 self.log.warn("task::using no Task scheduler")
378 self.log.warn("task::using no Task scheduler")
379
379
380 else:
380 else:
381 self.log.info("task::using Python %s Task scheduler"%scheme)
381 self.log.info("task::using Python %s Task scheduler"%scheme)
382 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
382 sargs = (hub.client_info['task'][1], hub.engine_info['task'],
383 monitor_url, disambiguate_url(hub.client_info['notification']))
383 monitor_url, disambiguate_url(hub.client_info['notification']))
384 kwargs = dict(logname='scheduler', loglevel=self.log_level,
384 kwargs = dict(logname='scheduler', loglevel=self.log_level,
385 log_url = self.log_url, config=dict(self.config))
385 log_url = self.log_url, config=dict(self.config))
386 if 'Process' in self.mq_class:
386 if 'Process' in self.mq_class:
387 # run the Python scheduler in a Process
387 # run the Python scheduler in a Process
388 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
388 q = Process(target=launch_scheduler, args=sargs, kwargs=kwargs)
389 q.daemon=True
389 q.daemon=True
390 children.append(q)
390 children.append(q)
391 else:
391 else:
392 # single-threaded Controller
392 # single-threaded Controller
393 kwargs['in_thread'] = True
393 kwargs['in_thread'] = True
394 launch_scheduler(*sargs, **kwargs)
394 launch_scheduler(*sargs, **kwargs)
395
395
396 def terminate_children(self):
396 def terminate_children(self):
397 child_procs = []
397 child_procs = []
398 for child in self.children:
398 for child in self.children:
399 if isinstance(child, ProcessMonitoredQueue):
399 if isinstance(child, ProcessMonitoredQueue):
400 child_procs.append(child.launcher)
400 child_procs.append(child.launcher)
401 elif isinstance(child, Process):
401 elif isinstance(child, Process):
402 child_procs.append(child)
402 child_procs.append(child)
403 if child_procs:
403 if child_procs:
404 self.log.critical("terminating children...")
404 self.log.critical("terminating children...")
405 for child in child_procs:
405 for child in child_procs:
406 try:
406 try:
407 child.terminate()
407 child.terminate()
408 except OSError:
408 except OSError:
409 # already dead
409 # already dead
410 pass
410 pass
411
411
412 def handle_signal(self, sig, frame):
412 def handle_signal(self, sig, frame):
413 self.log.critical("Received signal %i, shutting down", sig)
413 self.log.critical("Received signal %i, shutting down", sig)
414 self.terminate_children()
414 self.terminate_children()
415 self.loop.stop()
415 self.loop.stop()
416
416
417 def init_signal(self):
417 def init_signal(self):
418 for sig in (SIGINT, SIGABRT, SIGTERM):
418 for sig in (SIGINT, SIGABRT, SIGTERM):
419 signal(sig, self.handle_signal)
419 signal(sig, self.handle_signal)
420
420
421 def do_import_statements(self):
421 def do_import_statements(self):
422 statements = self.import_statements
422 statements = self.import_statements
423 for s in statements:
423 for s in statements:
424 try:
424 try:
425 self.log.msg("Executing statement: '%s'" % s)
425 self.log.msg("Executing statement: '%s'" % s)
426 exec s in globals(), locals()
426 exec s in globals(), locals()
427 except:
427 except:
428 self.log.msg("Error running statement: %s" % s)
428 self.log.msg("Error running statement: %s" % s)
429
429
430 def forward_logging(self):
430 def forward_logging(self):
431 if self.log_url:
431 if self.log_url:
432 self.log.info("Forwarding logging to %s"%self.log_url)
432 self.log.info("Forwarding logging to %s"%self.log_url)
433 context = zmq.Context.instance()
433 context = zmq.Context.instance()
434 lsock = context.socket(zmq.PUB)
434 lsock = context.socket(zmq.PUB)
435 lsock.connect(self.log_url)
435 lsock.connect(self.log_url)
436 handler = PUBHandler(lsock)
436 handler = PUBHandler(lsock)
437 self.log.removeHandler(self._log_handler)
438 handler.root_topic = 'controller'
437 handler.root_topic = 'controller'
439 handler.setLevel(self.log_level)
438 handler.setLevel(self.log_level)
440 self.log.addHandler(handler)
439 self.log.addHandler(handler)
441 self._log_handler = handler
442
440
443 @catch_config_error
441 @catch_config_error
444 def initialize(self, argv=None):
442 def initialize(self, argv=None):
445 super(IPControllerApp, self).initialize(argv)
443 super(IPControllerApp, self).initialize(argv)
446 self.forward_logging()
444 self.forward_logging()
447 self.load_secondary_config()
445 self.load_secondary_config()
448 self.init_hub()
446 self.init_hub()
449 self.init_schedulers()
447 self.init_schedulers()
450
448
451 def start(self):
449 def start(self):
452 # Start the subprocesses:
450 # Start the subprocesses:
453 self.factory.start()
451 self.factory.start()
454 # children must be started before signals are setup,
452 # children must be started before signals are setup,
455 # otherwise signal-handling will fire multiple times
453 # otherwise signal-handling will fire multiple times
456 for child in self.children:
454 for child in self.children:
457 child.start()
455 child.start()
458 self.init_signal()
456 self.init_signal()
459
457
460 self.write_pid_file(overwrite=True)
458 self.write_pid_file(overwrite=True)
461
459
462 try:
460 try:
463 self.factory.loop.start()
461 self.factory.loop.start()
464 except KeyboardInterrupt:
462 except KeyboardInterrupt:
465 self.log.critical("Interrupted, Exiting...\n")
463 self.log.critical("Interrupted, Exiting...\n")
466 finally:
464 finally:
467 self.cleanup_connection_files()
465 self.cleanup_connection_files()
468
466
469
467
470
468
471 def launch_new_instance():
469 def launch_new_instance():
472 """Create and run the IPython controller"""
470 """Create and run the IPython controller"""
473 if sys.platform == 'win32':
471 if sys.platform == 'win32':
474 # make sure we don't get called from a multiprocessing subprocess
472 # make sure we don't get called from a multiprocessing subprocess
475 # this can result in infinite Controllers being started on Windows
473 # this can result in infinite Controllers being started on Windows
476 # which doesn't have a proper fork, so multiprocessing is wonky
474 # which doesn't have a proper fork, so multiprocessing is wonky
477
475
478 # this only comes up when IPython has been installed using vanilla
476 # this only comes up when IPython has been installed using vanilla
479 # setuptools, and *not* distribute.
477 # setuptools, and *not* distribute.
480 import multiprocessing
478 import multiprocessing
481 p = multiprocessing.current_process()
479 p = multiprocessing.current_process()
482 # the main process has name 'MainProcess'
480 # the main process has name 'MainProcess'
483 # subprocesses will have names like 'Process-1'
481 # subprocesses will have names like 'Process-1'
484 if p.name != 'MainProcess':
482 if p.name != 'MainProcess':
485 # we are a subprocess, don't start another Controller!
483 # we are a subprocess, don't start another Controller!
486 return
484 return
487 app = IPControllerApp.instance()
485 app = IPControllerApp.instance()
488 app.initialize()
486 app.initialize()
489 app.start()
487 app.start()
490
488
491
489
492 if __name__ == '__main__':
490 if __name__ == '__main__':
493 launch_new_instance()
491 launch_new_instance()
@@ -1,330 +1,377 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 # encoding: utf-8
2 # encoding: utf-8
3 """
3 """
4 The IPython engine application
4 The IPython engine application
5
5
6 Authors:
6 Authors:
7
7
8 * Brian Granger
8 * Brian Granger
9 * MinRK
9 * MinRK
10
10
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Copyright (C) 2008-2011 The IPython Development Team
14 # Copyright (C) 2008-2011 The IPython Development Team
15 #
15 #
16 # Distributed under the terms of the BSD License. The full license is in
16 # Distributed under the terms of the BSD License. The full license is in
17 # the file COPYING, distributed as part of this software.
17 # the file COPYING, distributed as part of this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 import json
24 import json
25 import os
25 import os
26 import sys
26 import sys
27 import time
27 import time
28
28
29 import zmq
29 import zmq
30 from zmq.eventloop import ioloop
30 from zmq.eventloop import ioloop
31
31
32 from IPython.core.profiledir import ProfileDir
32 from IPython.core.profiledir import ProfileDir
33 from IPython.parallel.apps.baseapp import (
33 from IPython.parallel.apps.baseapp import (
34 BaseParallelApplication,
34 BaseParallelApplication,
35 base_aliases,
35 base_aliases,
36 base_flags,
36 base_flags,
37 catch_config_error,
37 catch_config_error,
38 )
38 )
39 from IPython.zmq.log import EnginePUBHandler
39 from IPython.zmq.log import EnginePUBHandler
40 from IPython.zmq.ipkernel import Kernel, IPKernelApp
40 from IPython.zmq.session import (
41 from IPython.zmq.session import (
41 Session, session_aliases, session_flags
42 Session, session_aliases, session_flags
42 )
43 )
43
44
44 from IPython.config.configurable import Configurable
45 from IPython.config.configurable import Configurable
45
46
46 from IPython.parallel.engine.engine import EngineFactory
47 from IPython.parallel.engine.engine import EngineFactory
47 from IPython.parallel.engine.streamkernel import Kernel
48 from IPython.parallel.util import disambiguate_url
48 from IPython.parallel.util import disambiguate_url, asbytes
49
49
50 from IPython.utils.importstring import import_item
50 from IPython.utils.importstring import import_item
51 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float
51 from IPython.utils.py3compat import cast_bytes
52 from IPython.utils.traitlets import Bool, Unicode, Dict, List, Float, Instance
52
53
53
54
54 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
55 # Module level variables
56 # Module level variables
56 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
57
58
58 #: The default config file name for this application
59 #: The default config file name for this application
59 default_config_file_name = u'ipengine_config.py'
60 default_config_file_name = u'ipengine_config.py'
60
61
61 _description = """Start an IPython engine for parallel computing.
62 _description = """Start an IPython engine for parallel computing.
62
63
63 IPython engines run in parallel and perform computations on behalf of a client
64 IPython engines run in parallel and perform computations on behalf of a client
64 and controller. A controller needs to be started before the engines. The
65 and controller. A controller needs to be started before the engines. The
65 engine can be configured using command line options or using a cluster
66 engine can be configured using command line options or using a cluster
66 directory. Cluster directories contain config, log and security files and are
67 directory. Cluster directories contain config, log and security files and are
67 usually located in your ipython directory and named as "profile_name".
68 usually located in your ipython directory and named as "profile_name".
68 See the `profile` and `profile-dir` options for details.
69 See the `profile` and `profile-dir` options for details.
69 """
70 """
70
71
71 _examples = """
72 _examples = """
72 ipengine --ip=192.168.0.1 --port=1000 # connect to hub at ip and port
73 ipengine --ip=192.168.0.1 --port=1000 # connect to hub at ip and port
73 ipengine --log-to-file --log-level=DEBUG # log to a file with DEBUG verbosity
74 ipengine --log-to-file --log-level=DEBUG # log to a file with DEBUG verbosity
74 """
75 """
75
76
76 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
77 # MPI configuration
78 # MPI configuration
78 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
79
80
80 mpi4py_init = """from mpi4py import MPI as mpi
81 mpi4py_init = """from mpi4py import MPI as mpi
81 mpi.size = mpi.COMM_WORLD.Get_size()
82 mpi.size = mpi.COMM_WORLD.Get_size()
82 mpi.rank = mpi.COMM_WORLD.Get_rank()
83 mpi.rank = mpi.COMM_WORLD.Get_rank()
83 """
84 """
84
85
85
86
86 pytrilinos_init = """from PyTrilinos import Epetra
87 pytrilinos_init = """from PyTrilinos import Epetra
87 class SimpleStruct:
88 class SimpleStruct:
88 pass
89 pass
89 mpi = SimpleStruct()
90 mpi = SimpleStruct()
90 mpi.rank = 0
91 mpi.rank = 0
91 mpi.size = 0
92 mpi.size = 0
92 """
93 """
93
94
94 class MPI(Configurable):
95 class MPI(Configurable):
95 """Configurable for MPI initialization"""
96 """Configurable for MPI initialization"""
96 use = Unicode('', config=True,
97 use = Unicode('', config=True,
97 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
98 help='How to enable MPI (mpi4py, pytrilinos, or empty string to disable).'
98 )
99 )
99
100
100 def _use_changed(self, name, old, new):
101 def _use_changed(self, name, old, new):
101 # load default init script if it's not set
102 # load default init script if it's not set
102 if not self.init_script:
103 if not self.init_script:
103 self.init_script = self.default_inits.get(new, '')
104 self.init_script = self.default_inits.get(new, '')
104
105
105 init_script = Unicode('', config=True,
106 init_script = Unicode('', config=True,
106 help="Initialization code for MPI")
107 help="Initialization code for MPI")
107
108
108 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
109 default_inits = Dict({'mpi4py' : mpi4py_init, 'pytrilinos':pytrilinos_init},
109 config=True)
110 config=True)
110
111
111
112
112 #-----------------------------------------------------------------------------
113 #-----------------------------------------------------------------------------
113 # Main application
114 # Main application
114 #-----------------------------------------------------------------------------
115 #-----------------------------------------------------------------------------
115 aliases = dict(
116 aliases = dict(
116 file = 'IPEngineApp.url_file',
117 file = 'IPEngineApp.url_file',
117 c = 'IPEngineApp.startup_command',
118 c = 'IPEngineApp.startup_command',
118 s = 'IPEngineApp.startup_script',
119 s = 'IPEngineApp.startup_script',
119
120
120 url = 'EngineFactory.url',
121 url = 'EngineFactory.url',
121 ssh = 'EngineFactory.sshserver',
122 ssh = 'EngineFactory.sshserver',
122 sshkey = 'EngineFactory.sshkey',
123 sshkey = 'EngineFactory.sshkey',
123 ip = 'EngineFactory.ip',
124 ip = 'EngineFactory.ip',
124 transport = 'EngineFactory.transport',
125 transport = 'EngineFactory.transport',
125 port = 'EngineFactory.regport',
126 port = 'EngineFactory.regport',
126 location = 'EngineFactory.location',
127 location = 'EngineFactory.location',
127
128
128 timeout = 'EngineFactory.timeout',
129 timeout = 'EngineFactory.timeout',
129
130
130 mpi = 'MPI.use',
131 mpi = 'MPI.use',
131
132
132 )
133 )
133 aliases.update(base_aliases)
134 aliases.update(base_aliases)
134 aliases.update(session_aliases)
135 aliases.update(session_aliases)
135 flags = {}
136 flags = {}
136 flags.update(base_flags)
137 flags.update(base_flags)
137 flags.update(session_flags)
138 flags.update(session_flags)
138
139
139 class IPEngineApp(BaseParallelApplication):
140 class IPEngineApp(BaseParallelApplication):
140
141
141 name = 'ipengine'
142 name = 'ipengine'
142 description = _description
143 description = _description
143 examples = _examples
144 examples = _examples
144 config_file_name = Unicode(default_config_file_name)
145 config_file_name = Unicode(default_config_file_name)
145 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
146 classes = List([ProfileDir, Session, EngineFactory, Kernel, MPI])
146
147
147 startup_script = Unicode(u'', config=True,
148 startup_script = Unicode(u'', config=True,
148 help='specify a script to be run at startup')
149 help='specify a script to be run at startup')
149 startup_command = Unicode('', config=True,
150 startup_command = Unicode('', config=True,
150 help='specify a command to be run at startup')
151 help='specify a command to be run at startup')
151
152
152 url_file = Unicode(u'', config=True,
153 url_file = Unicode(u'', config=True,
153 help="""The full location of the file containing the connection information for
154 help="""The full location of the file containing the connection information for
154 the controller. If this is not given, the file must be in the
155 the controller. If this is not given, the file must be in the
155 security directory of the cluster directory. This location is
156 security directory of the cluster directory. This location is
156 resolved using the `profile` or `profile_dir` options.""",
157 resolved using the `profile` or `profile_dir` options.""",
157 )
158 )
158 wait_for_url_file = Float(5, config=True,
159 wait_for_url_file = Float(5, config=True,
159 help="""The maximum number of seconds to wait for url_file to exist.
160 help="""The maximum number of seconds to wait for url_file to exist.
160 This is useful for batch-systems and shared-filesystems where the
161 This is useful for batch-systems and shared-filesystems where the
161 controller and engine are started at the same time and it
162 controller and engine are started at the same time and it
162 may take a moment for the controller to write the connector files.""")
163 may take a moment for the controller to write the connector files.""")
163
164
164 url_file_name = Unicode(u'ipcontroller-engine.json', config=True)
165 url_file_name = Unicode(u'ipcontroller-engine.json', config=True)
165
166
166 def _cluster_id_changed(self, name, old, new):
167 def _cluster_id_changed(self, name, old, new):
167 if new:
168 if new:
168 base = 'ipcontroller-%s' % new
169 base = 'ipcontroller-%s' % new
169 else:
170 else:
170 base = 'ipcontroller'
171 base = 'ipcontroller'
171 self.url_file_name = "%s-engine.json" % base
172 self.url_file_name = "%s-engine.json" % base
172
173
173 log_url = Unicode('', config=True,
174 log_url = Unicode('', config=True,
174 help="""The URL for the iploggerapp instance, for forwarding
175 help="""The URL for the iploggerapp instance, for forwarding
175 logging to a central location.""")
176 logging to a central location.""")
177
178 # an IPKernelApp instance, used to setup listening for shell frontends
179 kernel_app = Instance(IPKernelApp)
176
180
177 aliases = Dict(aliases)
181 aliases = Dict(aliases)
178 flags = Dict(flags)
182 flags = Dict(flags)
183
184 @property
185 def kernel(self):
186 """allow access to the Kernel object, so I look like IPKernelApp"""
187 return self.engine.kernel
179
188
180 def find_url_file(self):
189 def find_url_file(self):
181 """Set the url file.
190 """Set the url file.
182
191
183 Here we don't try to actually see if it exists for is valid as that
192 Here we don't try to actually see if it exists for is valid as that
184 is hadled by the connection logic.
193 is hadled by the connection logic.
185 """
194 """
186 config = self.config
195 config = self.config
187 # Find the actual controller key file
196 # Find the actual controller key file
188 if not self.url_file:
197 if not self.url_file:
189 self.url_file = os.path.join(
198 self.url_file = os.path.join(
190 self.profile_dir.security_dir,
199 self.profile_dir.security_dir,
191 self.url_file_name
200 self.url_file_name
192 )
201 )
193
202
194 def load_connector_file(self):
203 def load_connector_file(self):
195 """load config from a JSON connector file,
204 """load config from a JSON connector file,
196 at a *lower* priority than command-line/config files.
205 at a *lower* priority than command-line/config files.
197 """
206 """
198
207
199 self.log.info("Loading url_file %r", self.url_file)
208 self.log.info("Loading url_file %r", self.url_file)
200 config = self.config
209 config = self.config
201
210
202 with open(self.url_file) as f:
211 with open(self.url_file) as f:
203 d = json.loads(f.read())
212 d = json.loads(f.read())
204
213
205 if 'exec_key' in d:
214 if 'exec_key' in d:
206 config.Session.key = asbytes(d['exec_key'])
215 config.Session.key = cast_bytes(d['exec_key'])
207
216
208 try:
217 try:
209 config.EngineFactory.location
218 config.EngineFactory.location
210 except AttributeError:
219 except AttributeError:
211 config.EngineFactory.location = d['location']
220 config.EngineFactory.location = d['location']
212
221
213 d['url'] = disambiguate_url(d['url'], config.EngineFactory.location)
222 d['url'] = disambiguate_url(d['url'], config.EngineFactory.location)
214 try:
223 try:
215 config.EngineFactory.url
224 config.EngineFactory.url
216 except AttributeError:
225 except AttributeError:
217 config.EngineFactory.url = d['url']
226 config.EngineFactory.url = d['url']
218
227
219 try:
228 try:
220 config.EngineFactory.sshserver
229 config.EngineFactory.sshserver
221 except AttributeError:
230 except AttributeError:
222 config.EngineFactory.sshserver = d['ssh']
231 config.EngineFactory.sshserver = d['ssh']
232
233 def bind_kernel(self, **kwargs):
234 """Promote engine to listening kernel, accessible to frontends."""
235 if self.kernel_app is not None:
236 return
237
238 self.log.info("Opening ports for direct connections as an IPython kernel")
239
240 kernel = self.kernel
241
242 kwargs.setdefault('config', self.config)
243 kwargs.setdefault('log', self.log)
244 kwargs.setdefault('profile_dir', self.profile_dir)
245 kwargs.setdefault('session', self.engine.session)
246
247 app = self.kernel_app = IPKernelApp(**kwargs)
248
249 # allow IPKernelApp.instance():
250 IPKernelApp._instance = app
223
251
252 app.init_connection_file()
253 # relevant contents of init_sockets:
254
255 app.shell_port = app._bind_socket(kernel.shell_streams[0], app.shell_port)
256 app.log.debug("shell ROUTER Channel on port: %i", app.shell_port)
257
258 app.iopub_port = app._bind_socket(kernel.iopub_socket, app.iopub_port)
259 app.log.debug("iopub PUB Channel on port: %i", app.iopub_port)
260
261 kernel.stdin_socket = self.engine.context.socket(zmq.ROUTER)
262 app.stdin_port = app._bind_socket(kernel.stdin_socket, app.stdin_port)
263 app.log.debug("stdin ROUTER Channel on port: %i", app.stdin_port)
264
265 # start the heartbeat, and log connection info:
266
267 app.init_heartbeat()
268
269 app.log_connection_info()
270 app.write_connection_file()
271
272
224 def init_engine(self):
273 def init_engine(self):
225 # This is the working dir by now.
274 # This is the working dir by now.
226 sys.path.insert(0, '')
275 sys.path.insert(0, '')
227 config = self.config
276 config = self.config
228 # print config
277 # print config
229 self.find_url_file()
278 self.find_url_file()
230
279
231 # was the url manually specified?
280 # was the url manually specified?
232 keys = set(self.config.EngineFactory.keys())
281 keys = set(self.config.EngineFactory.keys())
233 keys = keys.union(set(self.config.RegistrationFactory.keys()))
282 keys = keys.union(set(self.config.RegistrationFactory.keys()))
234
283
235 if keys.intersection(set(['ip', 'url', 'port'])):
284 if keys.intersection(set(['ip', 'url', 'port'])):
236 # Connection info was specified, don't wait for the file
285 # Connection info was specified, don't wait for the file
237 url_specified = True
286 url_specified = True
238 self.wait_for_url_file = 0
287 self.wait_for_url_file = 0
239 else:
288 else:
240 url_specified = False
289 url_specified = False
241
290
242 if self.wait_for_url_file and not os.path.exists(self.url_file):
291 if self.wait_for_url_file and not os.path.exists(self.url_file):
243 self.log.warn("url_file %r not found", self.url_file)
292 self.log.warn("url_file %r not found", self.url_file)
244 self.log.warn("Waiting up to %.1f seconds for it to arrive.", self.wait_for_url_file)
293 self.log.warn("Waiting up to %.1f seconds for it to arrive.", self.wait_for_url_file)
245 tic = time.time()
294 tic = time.time()
246 while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
295 while not os.path.exists(self.url_file) and (time.time()-tic < self.wait_for_url_file):
247 # wait for url_file to exist, or until time limit
296 # wait for url_file to exist, or until time limit
248 time.sleep(0.1)
297 time.sleep(0.1)
249
298
250 if os.path.exists(self.url_file):
299 if os.path.exists(self.url_file):
251 self.load_connector_file()
300 self.load_connector_file()
252 elif not url_specified:
301 elif not url_specified:
253 self.log.fatal("Fatal: url file never arrived: %s", self.url_file)
302 self.log.fatal("Fatal: url file never arrived: %s", self.url_file)
254 self.exit(1)
303 self.exit(1)
255
304
256
305
257 try:
306 try:
258 exec_lines = config.Kernel.exec_lines
307 exec_lines = config.Kernel.exec_lines
259 except AttributeError:
308 except AttributeError:
260 config.Kernel.exec_lines = []
309 config.Kernel.exec_lines = []
261 exec_lines = config.Kernel.exec_lines
310 exec_lines = config.Kernel.exec_lines
262
311
263 if self.startup_script:
312 if self.startup_script:
264 enc = sys.getfilesystemencoding() or 'utf8'
313 enc = sys.getfilesystemencoding() or 'utf8'
265 cmd="execfile(%r)" % self.startup_script.encode(enc)
314 cmd="execfile(%r)" % self.startup_script.encode(enc)
266 exec_lines.append(cmd)
315 exec_lines.append(cmd)
267 if self.startup_command:
316 if self.startup_command:
268 exec_lines.append(self.startup_command)
317 exec_lines.append(self.startup_command)
269
318
270 # Create the underlying shell class and Engine
319 # Create the underlying shell class and Engine
271 # shell_class = import_item(self.master_config.Global.shell_class)
320 # shell_class = import_item(self.master_config.Global.shell_class)
272 # print self.config
321 # print self.config
273 try:
322 try:
274 self.engine = EngineFactory(config=config, log=self.log)
323 self.engine = EngineFactory(config=config, log=self.log)
275 except:
324 except:
276 self.log.error("Couldn't start the Engine", exc_info=True)
325 self.log.error("Couldn't start the Engine", exc_info=True)
277 self.exit(1)
326 self.exit(1)
278
327
279 def forward_logging(self):
328 def forward_logging(self):
280 if self.log_url:
329 if self.log_url:
281 self.log.info("Forwarding logging to %s", self.log_url)
330 self.log.info("Forwarding logging to %s", self.log_url)
282 context = self.engine.context
331 context = self.engine.context
283 lsock = context.socket(zmq.PUB)
332 lsock = context.socket(zmq.PUB)
284 lsock.connect(self.log_url)
333 lsock.connect(self.log_url)
285 self.log.removeHandler(self._log_handler)
286 handler = EnginePUBHandler(self.engine, lsock)
334 handler = EnginePUBHandler(self.engine, lsock)
287 handler.setLevel(self.log_level)
335 handler.setLevel(self.log_level)
288 self.log.addHandler(handler)
336 self.log.addHandler(handler)
289 self._log_handler = handler
290
337
291 def init_mpi(self):
338 def init_mpi(self):
292 global mpi
339 global mpi
293 self.mpi = MPI(config=self.config)
340 self.mpi = MPI(config=self.config)
294
341
295 mpi_import_statement = self.mpi.init_script
342 mpi_import_statement = self.mpi.init_script
296 if mpi_import_statement:
343 if mpi_import_statement:
297 try:
344 try:
298 self.log.info("Initializing MPI:")
345 self.log.info("Initializing MPI:")
299 self.log.info(mpi_import_statement)
346 self.log.info(mpi_import_statement)
300 exec mpi_import_statement in globals()
347 exec mpi_import_statement in globals()
301 except:
348 except:
302 mpi = None
349 mpi = None
303 else:
350 else:
304 mpi = None
351 mpi = None
305
352
306 @catch_config_error
353 @catch_config_error
307 def initialize(self, argv=None):
354 def initialize(self, argv=None):
308 super(IPEngineApp, self).initialize(argv)
355 super(IPEngineApp, self).initialize(argv)
309 self.init_mpi()
356 self.init_mpi()
310 self.init_engine()
357 self.init_engine()
311 self.forward_logging()
358 self.forward_logging()
312
359
313 def start(self):
360 def start(self):
314 self.engine.start()
361 self.engine.start()
315 try:
362 try:
316 self.engine.loop.start()
363 self.engine.loop.start()
317 except KeyboardInterrupt:
364 except KeyboardInterrupt:
318 self.log.critical("Engine Interrupted, shutting down...\n")
365 self.log.critical("Engine Interrupted, shutting down...\n")
319
366
320
367
321 def launch_new_instance():
368 def launch_new_instance():
322 """Create and run the IPython engine"""
369 """Create and run the IPython engine"""
323 app = IPEngineApp.instance()
370 app = IPEngineApp.instance()
324 app.initialize()
371 app.initialize()
325 app.start()
372 app.start()
326
373
327
374
328 if __name__ == '__main__':
375 if __name__ == '__main__':
329 launch_new_instance()
376 launch_new_instance()
330
377
@@ -1,1345 +1,1341 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """
2 """
3 Facilities for launching IPython processes asynchronously.
3 Facilities for launching IPython processes asynchronously.
4
4
5 Authors:
5 Authors:
6
6
7 * Brian Granger
7 * Brian Granger
8 * MinRK
8 * MinRK
9 """
9 """
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2008-2011 The IPython Development Team
12 # Copyright (C) 2008-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 # Imports
19 # Imports
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21
21
22 import copy
22 import copy
23 import logging
23 import logging
24 import os
24 import os
25 import stat
25 import stat
26 import sys
26 import time
27 import time
27
28
28 # signal imports, handling various platforms, versions
29 # signal imports, handling various platforms, versions
29
30
30 from signal import SIGINT, SIGTERM
31 from signal import SIGINT, SIGTERM
31 try:
32 try:
32 from signal import SIGKILL
33 from signal import SIGKILL
33 except ImportError:
34 except ImportError:
34 # Windows
35 # Windows
35 SIGKILL=SIGTERM
36 SIGKILL=SIGTERM
36
37
37 try:
38 try:
38 # Windows >= 2.7, 3.2
39 # Windows >= 2.7, 3.2
39 from signal import CTRL_C_EVENT as SIGINT
40 from signal import CTRL_C_EVENT as SIGINT
40 except ImportError:
41 except ImportError:
41 pass
42 pass
42
43
43 from subprocess import Popen, PIPE, STDOUT
44 from subprocess import Popen, PIPE, STDOUT
44 try:
45 try:
45 from subprocess import check_output
46 from subprocess import check_output
46 except ImportError:
47 except ImportError:
47 # pre-2.7, define check_output with Popen
48 # pre-2.7, define check_output with Popen
48 def check_output(*args, **kwargs):
49 def check_output(*args, **kwargs):
49 kwargs.update(dict(stdout=PIPE))
50 kwargs.update(dict(stdout=PIPE))
50 p = Popen(*args, **kwargs)
51 p = Popen(*args, **kwargs)
51 out,err = p.communicate()
52 out,err = p.communicate()
52 return out
53 return out
53
54
54 from zmq.eventloop import ioloop
55 from zmq.eventloop import ioloop
55
56
56 from IPython.config.application import Application
57 from IPython.config.application import Application
57 from IPython.config.configurable import LoggingConfigurable
58 from IPython.config.configurable import LoggingConfigurable
58 from IPython.utils.text import EvalFormatter
59 from IPython.utils.text import EvalFormatter
59 from IPython.utils.traitlets import (
60 from IPython.utils.traitlets import (
60 Any, Integer, CFloat, List, Unicode, Dict, Instance, HasTraits, CRegExp
61 Any, Integer, CFloat, List, Unicode, Dict, Instance, HasTraits, CRegExp
61 )
62 )
62 from IPython.utils.path import get_ipython_module_path, get_home_dir
63 from IPython.utils.path import get_home_dir
63 from IPython.utils.process import find_cmd, pycmd2argv, FindCmdError
64 from IPython.utils.process import find_cmd, FindCmdError
64
65
65 from .win32support import forward_read_events
66 from .win32support import forward_read_events
66
67
67 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
68 from .winhpcjob import IPControllerTask, IPEngineTask, IPControllerJob, IPEngineSetJob
68
69
69 WINDOWS = os.name == 'nt'
70 WINDOWS = os.name == 'nt'
70
71
71 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
72 # Paths to the kernel apps
73 # Paths to the kernel apps
73 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
74
75
76 cmd = "from IPython.parallel.apps.%s import launch_new_instance; launch_new_instance()"
75
77
76 ipcluster_cmd_argv = pycmd2argv(get_ipython_module_path(
78 ipcluster_cmd_argv = [sys.executable, "-c", cmd % "ipclusterapp"]
77 'IPython.parallel.apps.ipclusterapp'
78 ))
79
79
80 ipengine_cmd_argv = pycmd2argv(get_ipython_module_path(
80 ipengine_cmd_argv = [sys.executable, "-c", cmd % "ipengineapp"]
81 'IPython.parallel.apps.ipengineapp'
82 ))
83
81
84 ipcontroller_cmd_argv = pycmd2argv(get_ipython_module_path(
82 ipcontroller_cmd_argv = [sys.executable, "-c", cmd % "ipcontrollerapp"]
85 'IPython.parallel.apps.ipcontrollerapp'
86 ))
87
83
88 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
89 # Base launchers and errors
85 # Base launchers and errors
90 #-----------------------------------------------------------------------------
86 #-----------------------------------------------------------------------------
91
87
92
88
93 class LauncherError(Exception):
89 class LauncherError(Exception):
94 pass
90 pass
95
91
96
92
97 class ProcessStateError(LauncherError):
93 class ProcessStateError(LauncherError):
98 pass
94 pass
99
95
100
96
101 class UnknownStatus(LauncherError):
97 class UnknownStatus(LauncherError):
102 pass
98 pass
103
99
104
100
105 class BaseLauncher(LoggingConfigurable):
101 class BaseLauncher(LoggingConfigurable):
106 """An asbtraction for starting, stopping and signaling a process."""
102 """An asbtraction for starting, stopping and signaling a process."""
107
103
108 # In all of the launchers, the work_dir is where child processes will be
104 # In all of the launchers, the work_dir is where child processes will be
109 # run. This will usually be the profile_dir, but may not be. any work_dir
105 # run. This will usually be the profile_dir, but may not be. any work_dir
110 # passed into the __init__ method will override the config value.
106 # passed into the __init__ method will override the config value.
111 # This should not be used to set the work_dir for the actual engine
107 # This should not be used to set the work_dir for the actual engine
112 # and controller. Instead, use their own config files or the
108 # and controller. Instead, use their own config files or the
113 # controller_args, engine_args attributes of the launchers to add
109 # controller_args, engine_args attributes of the launchers to add
114 # the work_dir option.
110 # the work_dir option.
115 work_dir = Unicode(u'.')
111 work_dir = Unicode(u'.')
116 loop = Instance('zmq.eventloop.ioloop.IOLoop')
112 loop = Instance('zmq.eventloop.ioloop.IOLoop')
117
113
118 start_data = Any()
114 start_data = Any()
119 stop_data = Any()
115 stop_data = Any()
120
116
121 def _loop_default(self):
117 def _loop_default(self):
122 return ioloop.IOLoop.instance()
118 return ioloop.IOLoop.instance()
123
119
124 def __init__(self, work_dir=u'.', config=None, **kwargs):
120 def __init__(self, work_dir=u'.', config=None, **kwargs):
125 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
121 super(BaseLauncher, self).__init__(work_dir=work_dir, config=config, **kwargs)
126 self.state = 'before' # can be before, running, after
122 self.state = 'before' # can be before, running, after
127 self.stop_callbacks = []
123 self.stop_callbacks = []
128 self.start_data = None
124 self.start_data = None
129 self.stop_data = None
125 self.stop_data = None
130
126
131 @property
127 @property
132 def args(self):
128 def args(self):
133 """A list of cmd and args that will be used to start the process.
129 """A list of cmd and args that will be used to start the process.
134
130
135 This is what is passed to :func:`spawnProcess` and the first element
131 This is what is passed to :func:`spawnProcess` and the first element
136 will be the process name.
132 will be the process name.
137 """
133 """
138 return self.find_args()
134 return self.find_args()
139
135
140 def find_args(self):
136 def find_args(self):
141 """The ``.args`` property calls this to find the args list.
137 """The ``.args`` property calls this to find the args list.
142
138
143 Subcommand should implement this to construct the cmd and args.
139 Subcommand should implement this to construct the cmd and args.
144 """
140 """
145 raise NotImplementedError('find_args must be implemented in a subclass')
141 raise NotImplementedError('find_args must be implemented in a subclass')
146
142
147 @property
143 @property
148 def arg_str(self):
144 def arg_str(self):
149 """The string form of the program arguments."""
145 """The string form of the program arguments."""
150 return ' '.join(self.args)
146 return ' '.join(self.args)
151
147
152 @property
148 @property
153 def running(self):
149 def running(self):
154 """Am I running."""
150 """Am I running."""
155 if self.state == 'running':
151 if self.state == 'running':
156 return True
152 return True
157 else:
153 else:
158 return False
154 return False
159
155
160 def start(self):
156 def start(self):
161 """Start the process."""
157 """Start the process."""
162 raise NotImplementedError('start must be implemented in a subclass')
158 raise NotImplementedError('start must be implemented in a subclass')
163
159
def stop(self):
    """Ask the process to stop and notify observers.

    Returns None immediately; the actual stop is observed via
    :meth:`on_stop` callbacks.

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError('stop must be implemented in a subclass')
171
167
def on_stop(self, f):
    """Register *f* to receive this launcher's stop_data.

    If the process has already stopped, *f* is invoked immediately with
    the recorded stop_data and its result is returned; otherwise *f* is
    queued and will be fired later by :meth:`notify_stop`.
    """
    if self.state != 'after':
        # Not stopped yet: defer until notify_stop fires.
        self.stop_callbacks.append(f)
        return None
    return f(self.stop_data)
180
176
def notify_start(self, data):
    """Record a successful launch and flip state to 'running'.

    Stores *data* as ``start_data``, logs the event, and returns *data*
    unchanged so this can sit in a callback chain.
    """
    self.start_data = data
    self.state = 'running'
    self.log.debug('Process %r started: %r', self.args[0], data)
    return data
192
188
def notify_stop(self, data):
    """Record process exit, flip state to 'after', and fire callbacks.

    Every callback registered via :meth:`on_stop` is popped (most
    recently registered first) and invoked with *data*; the list is
    left empty.  Returns *data* unchanged.
    """
    self.log.debug('Process %r stopped: %r', self.args[0], data)
    self.stop_data = data
    self.state = 'after'
    # Drain the callback list; pop() gives the same LIFO order as the
    # original index-counting loop.
    while self.stop_callbacks:
        callback = self.stop_callbacks.pop()
        callback(data)
    return data
206
202
def signal(self, sig):
    """Deliver a signal to the process; concrete launchers must override.

    Parameters
    ----------
    sig : str or int
        'KILL', 'INT', etc., or any signal number

    Raises
    ------
    NotImplementedError
        Always, in this base class.
    """
    raise NotImplementedError('signal must be implemented in a subclass')
216
212
class ClusterAppMixin(HasTraits):
    """MixIn for cluster args as traits."""
    # Identity of the cluster this launcher belongs to.
    profile_dir = Unicode('')
    cluster_id = Unicode('')

    @property
    def cluster_args(self):
        """Common CLI args identifying the profile dir and cluster id."""
        return ['--profile-dir', self.profile_dir, '--cluster-id', self.cluster_id]
225
221
class ControllerMixin(ClusterAppMixin):
    """MixIn holding the configurable ipcontroller command and args."""
    controller_cmd = List(ipcontroller_cmd_argv, config=True,
        help="""Popen command to launch ipcontroller.""")
    # Command line arguments to ipcontroller.
    controller_args = List(['--log-to-file', '--log-level=%i' % logging.INFO],
        config=True,
        help="""command-line args to pass to ipcontroller""")
232
228
class EngineMixin(ClusterAppMixin):
    """MixIn holding the configurable ipengine command and args."""
    engine_cmd = List(ipengine_cmd_argv, config=True,
        help="""command to launch the Engine.""")
    # Command line arguments for ipengine.
    engine_args = List(['--log-to-file', '--log-level=%i' % logging.INFO],
        config=True,
        help="command-line arguments to pass to ipengine",
    )
240
236
241
237
242 #-----------------------------------------------------------------------------
238 #-----------------------------------------------------------------------------
243 # Local process launchers
239 # Local process launchers
244 #-----------------------------------------------------------------------------
240 #-----------------------------------------------------------------------------
245
241
246
242
class LocalProcessLauncher(BaseLauncher):
    """Start and stop an external process in an asynchronous manner.

    This will launch the external process with a working directory of
    ``self.work_dir``.
    """

    # This is used to to construct self.args, which is passed to
    # spawnProcess.
    cmd_and_args = List([])
    poll_frequency = Integer(100) # in ms

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalProcessLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        # Popen handle and periodic exit-poller; created in start().
        self.process = None
        self.poller = None

    def find_args(self):
        """The launch command line is exactly the configured cmd_and_args."""
        return self.cmd_and_args

    def start(self):
        """Spawn the process and wire its stdout/stderr into the ioloop.

        May only be called while state is 'before'; a second call raises
        ProcessStateError.  On success, notify_start() is called with the
        new pid and a periodic poll() watches for exit.
        """
        self.log.debug("Starting %s: %r", self.__class__.__name__, self.args)
        if self.state == 'before':
            self.process = Popen(self.args,
                stdout=PIPE,stderr=PIPE,stdin=PIPE,
                env=os.environ,
                cwd=self.work_dir
            )
            if WINDOWS:
                # Windows pipes are not selectable; forward_read_events
                # bridges them into something the ioloop can watch.
                self.stdout = forward_read_events(self.process.stdout)
                self.stderr = forward_read_events(self.process.stderr)
            else:
                self.stdout = self.process.stdout.fileno()
                self.stderr = self.process.stderr.fileno()
            self.loop.add_handler(self.stdout, self.handle_stdout, self.loop.READ)
            self.loop.add_handler(self.stderr, self.handle_stderr, self.loop.READ)
            # Poll for exit every poll_frequency ms in addition to the
            # EOF-triggered poll() in the output handlers.
            self.poller = ioloop.PeriodicCallback(self.poll, self.poll_frequency, self.loop)
            self.poller.start()
            self.notify_start(self.process.pid)
        else:
            s = 'The process was already started and has state: %r' % self.state
            raise ProcessStateError(s)

    def stop(self):
        """Stop by interrupt-then-kill; returns immediately."""
        return self.interrupt_then_kill()

    def signal(self, sig):
        """Send *sig* to the process (no-op unless state is 'running')."""
        if self.state == 'running':
            if WINDOWS and sig != SIGINT:
                # use Windows tree-kill for better child cleanup
                check_output(['taskkill', '-pid', str(self.process.pid), '-t', '-f'])
            else:
                self.process.send_signal(sig)

    def interrupt_then_kill(self, delay=2.0):
        """Send INT, wait a delay and then send KILL."""
        try:
            self.signal(SIGINT)
        except Exception:
            # best-effort: a dead process may reject the signal
            self.log.debug("interrupt failed")
            pass
        # Schedule the follow-up KILL; keep a reference so the callback
        # is not garbage-collected before it fires.
        self.killer = ioloop.DelayedCallback(lambda : self.signal(SIGKILL), delay*1000, self.loop)
        self.killer.start()

    # callbacks, etc:

    def handle_stdout(self, fd, events):
        """ioloop read handler: relay one stdout line to the debug log."""
        if WINDOWS:
            line = self.stdout.recv()
        else:
            line = self.process.stdout.readline()
        # a stopped process will be readable but return empty strings
        if line:
            self.log.debug(line[:-1])
        else:
            self.poll()

    def handle_stderr(self, fd, events):
        """ioloop read handler: relay one stderr line to the debug log."""
        if WINDOWS:
            line = self.stderr.recv()
        else:
            line = self.process.stderr.readline()
        # a stopped process will be readable but return empty strings
        if line:
            self.log.debug(line[:-1])
        else:
            self.poll()

    def poll(self):
        """Check for process exit; on exit, tear down handlers and notify.

        Returns the exit code, or None while still running.
        """
        status = self.process.poll()
        if status is not None:
            self.poller.stop()
            self.loop.remove_handler(self.stdout)
            self.loop.remove_handler(self.stderr)
            self.notify_stop(dict(exit_code=status, pid=self.process.pid))
        return status
345
341
class LocalControllerLauncher(LocalProcessLauncher, ControllerMixin):
    """Launch a controller as a regular external process."""

    def find_args(self):
        """Assemble the full ipcontroller command line."""
        cmdline = list(self.controller_cmd)
        cmdline += self.cluster_args
        cmdline += self.controller_args
        return cmdline

    def start(self):
        """Start the controller by profile_dir."""
        return super(LocalControllerLauncher, self).start()
355
351
356
352
class LocalEngineLauncher(LocalProcessLauncher, EngineMixin):
    """Launch a single engine as a regular external process."""

    def find_args(self):
        """Assemble the full ipengine command line."""
        cmdline = list(self.engine_cmd)
        cmdline += self.cluster_args
        cmdline += self.engine_args
        return cmdline
362
358
363
359
class LocalEngineSetLauncher(LocalEngineLauncher):
    """Launch a set of engines as regular external processes."""

    delay = CFloat(0.1, config=True,
        help="""delay (in seconds) between starting each engine after the first.
        This can help force the engines to get their ids in order, or limit
        process flood when starting many engines."""
    )

    # launcher class
    launcher_class = LocalEngineLauncher

    # per-engine launchers keyed by start index, and their stop_data.
    launchers = Dict()
    stop_data = Dict()

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(LocalEngineSetLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        self.stop_data = {}

    def start(self, n):
        """Start n engines by profile or profile_dir."""
        dlist = []
        for i in range(n):
            if i > 0:
                # stagger launches; see `delay` help text
                time.sleep(self.delay)
            el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log,
                                    profile_dir=self.profile_dir, cluster_id=self.cluster_id,
            )

            # Copy the engine args over to each engine launcher.
            el.engine_cmd = copy.deepcopy(self.engine_cmd)
            el.engine_args = copy.deepcopy(self.engine_args)
            el.on_stop(self._notice_engine_stopped)
            d = el.start()
            self.launchers[i] = el
            dlist.append(d)
        self.notify_start(dlist)
        return dlist

    def find_args(self):
        """Placeholder args; this launcher represents a set, not one cmd."""
        return ['engine set']

    def signal(self, sig):
        """Forward *sig* to every engine launcher; returns their results."""
        dlist = []
        for el in self.launchers.itervalues():
            d = el.signal(sig)
            dlist.append(d)
        return dlist

    def interrupt_then_kill(self, delay=1.0):
        """INT-then-KILL every engine launcher; returns their results."""
        dlist = []
        for el in self.launchers.itervalues():
            d = el.interrupt_then_kill(delay)
            dlist.append(d)
        return dlist

    def stop(self):
        """Stop all engines via interrupt_then_kill."""
        return self.interrupt_then_kill()

    def _notice_engine_stopped(self, data):
        """Bookkeeping callback: drop the stopped engine's launcher.

        When the last engine is gone, notify_stop fires with the
        accumulated per-engine stop_data.
        """
        pid = data['pid']
        # NOTE(review): if no launcher matches `pid`, this pops whichever
        # index the loop ended on (or raises NameError when empty) —
        # presumably every callback's pid is always present; confirm.
        for idx,el in self.launchers.iteritems():
            if el.process.pid == pid:
                break
        self.launchers.pop(idx)
        self.stop_data[idx] = data
        if not self.launchers:
            self.notify_stop(self.stop_data)
434
430
435
431
436 #-----------------------------------------------------------------------------
432 #-----------------------------------------------------------------------------
437 # MPI launchers
433 # MPI launchers
438 #-----------------------------------------------------------------------------
434 #-----------------------------------------------------------------------------
439
435
440
436
class MPILauncher(LocalProcessLauncher):
    """Launch an external process using mpiexec."""

    mpi_cmd = List(['mpiexec'], config=True,
        help="The mpiexec command to use in starting the process."
    )
    mpi_args = List([], config=True,
        help="The command line arguments to pass to mpiexec."
    )
    program = List(['date'],
        help="The program to start via mpiexec.")
    program_args = List([],
        help="The command line argument to the program."
    )
    # number of MPI processes; set by start(n).
    n = Integer(1)

    def __init__(self, *args, **kwargs):
        # deprecation for old MPIExec names:
        # migrate any config under the old MPIExec* section names into
        # the new MPI* sections before traits are initialized.
        config = kwargs.get('config', {})
        for oldname in ('MPIExecLauncher', 'MPIExecControllerLauncher', 'MPIExecEngineSetLauncher'):
            deprecated = config.get(oldname)
            if deprecated:
                newname = oldname.replace('MPIExec', 'MPI')
                config[newname].update(deprecated)
                # NOTE(review): self.log is used before super().__init__
                # runs — appears to rely on the trait default; confirm.
                self.log.warn("WARNING: %s name has been deprecated, use %s", oldname, newname)

        super(MPILauncher, self).__init__(*args, **kwargs)

    def find_args(self):
        """Build self.args using all the fields."""
        return self.mpi_cmd + ['-n', str(self.n)] + self.mpi_args + \
               self.program + self.program_args

    def start(self, n):
        """Start n instances of the program using mpiexec."""
        self.n = n
        return super(MPILauncher, self).start()
478
474
479
475
class MPIControllerLauncher(MPILauncher, ControllerMixin):
    """Launch a controller using mpiexec."""

    # program/program_args aliases keep all Controller/EngineSetLaunchers
    # uniform, instead of some exposing `program_args` and others
    # `controller_args`; they are intentionally *non-configurable*.
    @property
    def program(self):
        """The controller command under the generic `program` name."""
        return self.controller_cmd

    @property
    def program_args(self):
        """Cluster-identifying args followed by the controller args."""
        return self.cluster_args + self.controller_args

    def start(self):
        """Start a single controller (n=1) by profile_dir."""
        return super(MPIControllerLauncher, self).start(1)
497
493
498
494
class MPIEngineSetLauncher(MPILauncher, EngineMixin):
    """Launch engines using mpiexec."""

    # program/program_args aliases keep all Controller/EngineSetLaunchers
    # uniform, instead of some exposing `program_args` and others
    # `engine_args`; they are intentionally *non-configurable*.
    @property
    def program(self):
        """The engine command under the generic `program` name."""
        return self.engine_cmd

    @property
    def program_args(self):
        """Cluster-identifying args followed by the engine args."""
        return self.cluster_args + self.engine_args

    def start(self, n):
        """Start n engines by profile or profile_dir."""
        self.n = n
        return super(MPIEngineSetLauncher, self).start(n)
517
513
518 # deprecated MPIExec names
514 # deprecated MPIExec names
class DeprecatedMPILauncher(object):
    """Mixin that logs a deprecation warning for old MPIExec* names."""

    def warn(self):
        """Log that this class's MPIExec-style name is deprecated."""
        old = self.__class__.__name__
        self.log.warn("WARNING: %s name is deprecated, use %s",
                      old, old.replace('MPIExec', 'MPI'))
524
520
class MPIExecLauncher(MPILauncher, DeprecatedMPILauncher):
    """Deprecated, use MPILauncher."""

    def __init__(self, *args, **kwargs):
        # Behave exactly like MPILauncher, then announce the rename.
        super(MPIExecLauncher, self).__init__(*args, **kwargs)
        self.warn()
530
526
class MPIExecControllerLauncher(MPIControllerLauncher, DeprecatedMPILauncher):
    """Deprecated, use MPIControllerLauncher."""

    def __init__(self, *args, **kwargs):
        # Behave exactly like MPIControllerLauncher, then announce the rename.
        super(MPIExecControllerLauncher, self).__init__(*args, **kwargs)
        self.warn()
536
532
class MPIExecEngineSetLauncher(MPIEngineSetLauncher, DeprecatedMPILauncher):
    """Deprecated, use MPIEngineSetLauncher."""

    def __init__(self, *args, **kwargs):
        # Behave exactly like MPIEngineSetLauncher, then announce the rename.
        super(MPIExecEngineSetLauncher, self).__init__(*args, **kwargs)
        self.warn()
542
538
543
539
544 #-----------------------------------------------------------------------------
540 #-----------------------------------------------------------------------------
545 # SSH launchers
541 # SSH launchers
546 #-----------------------------------------------------------------------------
542 #-----------------------------------------------------------------------------
547
543
548 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
544 # TODO: Get SSH Launcher back to level of sshx in 0.10.2
549
545
class SSHLauncher(LocalProcessLauncher):
    """A minimal launcher for ssh.

    To be useful this will probably have to be extended to use the ``sshx``
    idea for environment variables.  There could be other things this needs
    as well.
    """

    ssh_cmd = List(['ssh'], config=True,
        help="command for starting ssh")
    ssh_args = List(['-tt'], config=True,
        help="args to pass to ssh")
    scp_cmd = List(['scp'], config=True,
        help="command for sending files")
    program = List(['date'],
        help="Program to launch via ssh")
    program_args = List([],
        help="args to pass to remote program")
    hostname = Unicode('', config=True,
        help="hostname on which to launch the program")
    user = Unicode('', config=True,
        help="username for ssh")
    location = Unicode('', config=True,
        help="user@hostname location for ssh in one setting")
    to_fetch = List([], config=True,
        help="List of (remote, local) files to fetch after starting")
    to_send = List([], config=True,
        help="List of (local, remote) files to send before starting")

    def _hostname_changed(self, name, old, new):
        # keep `location` in sync with hostname/user traits
        if self.user:
            self.location = u'%s@%s' % (self.user, new)
        else:
            self.location = new

    def _user_changed(self, name, old, new):
        # keep `location` in sync with hostname/user traits
        self.location = u'%s@%s' % (new, self.hostname)

    def find_args(self):
        """The launch command: ssh [args] location program [program_args]."""
        return self.ssh_cmd + self.ssh_args + [self.location] + \
               self.program + self.program_args

    def _send_file(self, local, remote):
        """send a single file"""
        remote = "%s:%s" % (self.location, remote)
        # wait up to 10s for the local file to appear before copying
        for i in range(10):
            if not os.path.exists(local):
                self.log.debug("waiting for %s" % local)
                time.sleep(1)
            else:
                break
        self.log.info("sending %s to %s", local, remote)
        check_output(self.scp_cmd + [local, remote])

    def send_files(self):
        """send our files (called before start)"""
        if not self.to_send:
            return
        for local_file, remote_file in self.to_send:
            self._send_file(local_file, remote_file)

    def _fetch_file(self, remote, local):
        """fetch a single file"""
        full_remote = "%s:%s" % (self.location, remote)
        self.log.info("fetching %s from %s", local, full_remote)
        for i in range(10):
            # wait up to 10s for remote file to exist
            check = check_output(self.ssh_cmd + self.ssh_args + \
                [self.location, 'test -e', remote, "&& echo 'yes' || echo 'no'"])
            # NOTE(review): check_output returns bytes on Python 3; the
            # comparisons below assume Python 2 str — confirm target version.
            check = check.strip()
            if check == 'no':
                time.sleep(1)
            elif check == 'yes':
                break
        check_output(self.scp_cmd + [full_remote, local])

    def fetch_files(self):
        """fetch remote files (called after start)"""
        if not self.to_fetch:
            return
        for remote_file, local_file in self.to_fetch:
            self._fetch_file(remote_file, local_file)

    def start(self, hostname=None, user=None):
        """Send files, launch the remote program via ssh, then fetch files.

        Optional *hostname*/*user* override the corresponding traits
        (which also updates `location` via the _changed handlers).
        """
        if hostname is not None:
            self.hostname = hostname
        if user is not None:
            self.user = user

        self.send_files()
        super(SSHLauncher, self).start()
        self.fetch_files()

    def signal(self, sig):
        """'Signal' the remote process by closing the ssh connection.

        Any *sig* value triggers the same action; actual signal delivery
        is not possible over plain ssh here.
        """
        if self.state == 'running':
            # send escaped ssh connection-closer
            self.process.stdin.write('~.')
            self.process.stdin.flush()
648
644
class SSHClusterLauncher(SSHLauncher):
    """SSHLauncher that knows the remote profile directory."""

    remote_profile_dir = Unicode('', config=True,
        help="""The remote profile_dir to use.

        If not specified, use calling profile, stripping out possible leading homedir.
        """)

    def _remote_profile_dir_default(self):
        """turns /home/you/.ipython/profile_foo into .ipython/profile_foo

        BUG FIX: this dynamic-default method was misspelled
        ``_remote_profie_dir_default``, so the traits machinery never
        called it and ``remote_profile_dir`` silently stayed ''.
        """
        home = get_home_dir()
        if not home.endswith('/'):
            home = home+'/'

        if self.profile_dir.startswith(home):
            return self.profile_dir[len(home):]
        else:
            return self.profile_dir

    def _cluster_id_changed(self, name, old, new):
        # SSH launchers identify clusters by profile dir only.
        if new:
            raise ValueError("cluster id not supported by SSH launchers")

    @property
    def cluster_args(self):
        """Arguments identifying the remote cluster (profile dir only)."""
        return ['--profile-dir', self.remote_profile_dir]
676
672
class SSHControllerLauncher(SSHClusterLauncher, ControllerMixin):
    """Launch ipcontroller on a remote host over ssh."""

    # alias back to *non-configurable* program[_args] for use in find_args()
    # this way all Controller/EngineSetLaunchers have the same form, rather
    # than *some* having `program_args` and others `controller_args`

    def _controller_cmd_default(self):
        return ['ipcontroller']

    @property
    def program(self):
        return self.controller_cmd

    @property
    def program_args(self):
        return self.cluster_args + self.controller_args

    def _to_fetch_default(self):
        """By default, fetch the two connection files from the remote security dir."""
        pairs = []
        for cf in ('ipcontroller-client.json', 'ipcontroller-engine.json'):
            remote = os.path.join(self.remote_profile_dir, 'security', cf)
            local = os.path.join(self.profile_dir, 'security', cf)
            pairs.append((remote, local))
        return pairs
700
696
class SSHEngineLauncher(SSHClusterLauncher, EngineMixin):
    """Launch a single ipengine on a remote host over ssh."""

    # alias back to *non-configurable* program[_args] for use in find_args()
    # this way all Controller/EngineSetLaunchers have the same form, rather
    # than *some* having `program_args` and others `controller_args`

    def _engine_cmd_default(self):
        return ['ipengine']

    @property
    def program(self):
        return self.engine_cmd

    @property
    def program_args(self):
        return self.cluster_args + self.engine_args

    def _to_send_default(self):
        """By default, push the two connection files into the remote security dir."""
        pairs = []
        for cf in ('ipcontroller-client.json', 'ipcontroller-engine.json'):
            local = os.path.join(self.profile_dir, 'security', cf)
            remote = os.path.join(self.remote_profile_dir, 'security', cf)
            pairs.append((local, remote))
        return pairs
724
720
725
721
class SSHEngineSetLauncher(LocalEngineSetLauncher):
    """Launch a set of engines across one or more ssh hosts.

    The hosts and per-host engine counts come from the ``engines`` config
    dict, not from the ``n`` passed to :meth:`start`.
    """
    # one SSHEngineLauncher per engine process
    launcher_class = SSHEngineLauncher
    engines = Dict(config=True,
        help="""dict of engines to launch. This is a dict by hostname of ints,
        corresponding to the number of engines to start on that host.""")

    @property
    def engine_count(self):
        """determine engine count from `engines` dict"""
        count = 0
        for n in self.engines.itervalues():
            if isinstance(n, (tuple,list)):
                # value may be (count, args); only the count contributes
                n,args = n
            count += n
        return count

    def start(self, n):
        """Start engines by profile or profile_dir.
        `n` is ignored, and the `engines` config property is used instead.
        """

        dlist = []
        for host, n in self.engines.iteritems():
            # each value is either a plain count or a (count, args) pair
            if isinstance(n, (tuple, list)):
                n, args = n
            else:
                # deepcopy so per-launcher arg lists are never shared
                args = copy.deepcopy(self.engine_args)

            # keys may be given as 'user@host'
            if '@' in host:
                user,host = host.split('@',1)
            else:
                user=None
            for i in range(n):
                if i > 0:
                    # stagger engine starts on the same host
                    time.sleep(self.delay)
                el = self.launcher_class(work_dir=self.work_dir, config=self.config, log=self.log,
                                        profile_dir=self.profile_dir, cluster_id=self.cluster_id,
                )
                if i > 0:
                    # only send files for the first engine on each host
                    el.to_send = []

                # Copy the engine args over to each engine launcher.
                el.engine_cmd = self.engine_cmd
                el.engine_args = args
                el.on_stop(self._notice_engine_stopped)
                d = el.start(user=user, hostname=host)
                # key launchers by "host/index" so names are unique per host
                self.launchers[ "%s/%i" % (host,i) ] = el
                dlist.append(d)
        self.notify_start(dlist)
        return dlist
777
773
778
774
class SSHProxyEngineSetLauncher(SSHClusterLauncher):
    """Launcher for calling
    `ipcluster engines` on a remote machine.

    Requires that remote profile is already configured.
    """

    # remembered engine count, filled in by start()
    n = Integer()
    ipcluster_cmd = List(['ipcluster'], config=True)

    @property
    def program(self):
        return self.ipcluster_cmd + ['engines']

    @property
    def program_args(self):
        return ['-n', str(self.n), '--profile-dir', self.remote_profile_dir]

    def _to_send_default(self):
        """Push the connection files to the remote security dir by default."""
        pairs = []
        for cf in ('ipcontroller-client.json', 'ipcontroller-engine.json'):
            pairs.append((
                os.path.join(self.profile_dir, 'security', cf),
                os.path.join(self.remote_profile_dir, 'security', cf),
            ))
        return pairs

    def start(self, n):
        """Record the engine count, then launch remote `ipcluster engines`."""
        self.n = n
        super(SSHProxyEngineSetLauncher, self).start()
807
803
808
804
809 #-----------------------------------------------------------------------------
805 #-----------------------------------------------------------------------------
810 # Windows HPC Server 2008 scheduler launchers
806 # Windows HPC Server 2008 scheduler launchers
811 #-----------------------------------------------------------------------------
807 #-----------------------------------------------------------------------------
812
808
813
809
814 # This is only used on Windows.
810 # This is only used on Windows.
def find_job_cmd():
    """Locate the Windows HPC ``job`` command.

    On non-Windows platforms (and on Windows when lookup fails) the bare
    string ``'job'`` is returned as a best-effort fallback.
    """
    if not WINDOWS:
        return 'job'
    try:
        return find_cmd('job')
    except (FindCmdError, ImportError):
        # ImportError will be raised if win32api is not installed
        return 'job'
824
820
825
821
class WindowsHPCLauncher(BaseLauncher):
    """Base class for launchers that submit via the Windows HPC 2008 scheduler.

    Subclasses must implement :meth:`write_job_file` to emit the XML job
    description that ``job.exe submit`` consumes.
    """

    job_id_regexp = CRegExp(r'\d+', config=True,
        help="""A regular expression used to get the job id from the output of the
        submit_command. """
        )
    job_file_name = Unicode(u'ipython_job.xml', config=True,
        help="The filename of the instantiated job script.")
    # The full path to the instantiated job script. This gets made dynamically
    # by combining the work_dir with the job_file_name.
    # NOTE: this trait is shadowed by the `job_file` property defined below.
    job_file = Unicode(u'')
    scheduler = Unicode('', config=True,
        help="The hostname of the scheduler to submit the job to.")
    job_cmd = Unicode(find_job_cmd(), config=True,
        help="The command for submitting jobs.")

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(WindowsHPCLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )

    @property
    def job_file(self):
        """Full path of the job description file (work_dir/job_file_name)."""
        return os.path.join(self.work_dir, self.job_file_name)

    def write_job_file(self, n):
        raise NotImplementedError("Implement write_job_file in a subclass.")

    def find_args(self):
        return [u'job.exe']

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id.

        Raises LauncherError when no job id can be found in `output`.
        """
        m = self.job_id_regexp.search(output)
        if m is not None:
            job_id = m.group()
        else:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        self.job_id = job_id
        self.log.info('Job started with id: %r', job_id)
        return job_id

    def start(self, n):
        """Start n copies of the process using the Win HPC job scheduler."""
        self.write_job_file(n)
        args = [
            'submit',
            '/jobfile:%s' % self.job_file,
            '/scheduler:%s' % self.scheduler
        ]
        self.log.debug("Starting Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))

        output = check_output([self.job_cmd]+args,
            env=os.environ,
            cwd=self.work_dir,
            stderr=STDOUT
        )
        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        """Cancel the submitted job; best-effort if it has already stopped."""
        args = [
            'cancel',
            self.job_id,
            '/scheduler:%s' % self.scheduler
        ]
        self.log.info("Stopping Win HPC Job: %s" % (self.job_cmd + ' ' + ' '.join(args),))
        try:
            output = check_output([self.job_cmd]+args,
                env=os.environ,
                cwd=self.work_dir,
                stderr=STDOUT
            )
        except Exception:
            # FIX: was a bare `except:` (also caught KeyboardInterrupt/SystemExit);
            # FIX: message typo 'stoppped' -> 'stopped'.
            output = 'The job already appears to be stopped: %r' % self.job_id
        self.notify_stop(dict(job_id=self.job_id, output=output))  # Pass the output of the kill cmd
        return output
904
900
905
901
class WindowsHPCControllerLauncher(WindowsHPCLauncher, ClusterAppMixin):
    """Launch an ipcontroller via the Windows HPC scheduler."""

    job_file_name = Unicode(u'ipcontroller_job.xml', config=True,
        help="WinHPC xml job file.")
    controller_args = List([], config=False,
        help="extra args to pass to ipcontroller")

    def write_job_file(self, n):
        """Write the WinHPC XML job description for a single controller task."""
        job = IPControllerJob(config=self.config)

        t = IPControllerTask(config=self.config)
        # The tasks work directory is *not* the actual work directory of
        # the controller. It is used as the base path for the stdout/stderr
        # files that the scheduler redirects to.
        t.work_directory = self.profile_dir
        # Add the profile_dir and from self.start().
        t.controller_args.extend(self.cluster_args)
        t.controller_args.extend(self.controller_args)
        job.add_task(t)

        self.log.debug("Writing job description file: %s", self.job_file)
        job.write(self.job_file)

    @property
    def job_file(self):
        # unlike the base class, the job file lives in the profile dir
        return os.path.join(self.profile_dir, self.job_file_name)

    def start(self):
        """Start the controller by profile_dir."""
        return super(WindowsHPCControllerLauncher, self).start(1)
936
932
937
933
class WindowsHPCEngineSetLauncher(WindowsHPCLauncher, ClusterAppMixin):
    """Launch a set of ipengines via the Windows HPC scheduler."""

    job_file_name = Unicode(u'ipengineset_job.xml', config=True,
        help="jobfile for ipengines job")
    engine_args = List([], config=False,
        help="extra args to pass to ipengine")  # FIX: help typo 'pas' -> 'pass'

    def write_job_file(self, n):
        """Write the WinHPC XML job description containing `n` engine tasks."""
        job = IPEngineSetJob(config=self.config)

        for i in range(n):
            t = IPEngineTask(config=self.config)
            # The tasks work directory is *not* the actual work directory of
            # the engine. It is used as the base path for the stdout/stderr
            # files that the scheduler redirects to.
            t.work_directory = self.profile_dir
            # Add the profile_dir and from self.start().
            t.engine_args.extend(self.cluster_args)
            t.engine_args.extend(self.engine_args)
            job.add_task(t)

        self.log.debug("Writing job description file: %s", self.job_file)
        job.write(self.job_file)

    @property
    def job_file(self):
        # job file lives in the profile dir, not work_dir
        return os.path.join(self.profile_dir, self.job_file_name)

    def start(self, n):
        """Start n engines by profile_dir."""
        # FIX: docstring was copy-pasted from the controller launcher
        # ("Start the controller by profile_dir.")
        return super(WindowsHPCEngineSetLauncher, self).start(n)
969
965
970
966
971 #-----------------------------------------------------------------------------
967 #-----------------------------------------------------------------------------
972 # Batch (PBS) system launchers
968 # Batch (PBS) system launchers
973 #-----------------------------------------------------------------------------
969 #-----------------------------------------------------------------------------
974
970
class BatchClusterAppMixin(ClusterAppMixin):
    """ClusterApp mixin that updates the self.context dict, rather than cl-args."""

    def _profile_dir_changed(self, name, old, new):
        # mirror trait changes into the batch-template context
        self.context[name] = new

    def _cluster_id_changed(self, name, old, new):
        # same behavior as _profile_dir_changed (was an alias in the original)
        self.context[name] = new

    def _profile_dir_default(self):
        self.context['profile_dir'] = ''
        return ''

    def _cluster_id_default(self):
        self.context['cluster_id'] = ''
        return ''
987
983
988
984
class BatchSystemLauncher(BaseLauncher):
    """Launch an external process using a batch system.

    This class is designed to work with UNIX batch systems like PBS, LSF,
    GridEngine, etc. The overall model is that there are different commands
    like qsub, qdel, etc. that handle the starting and stopping of the process.

    This class also has the notion of a batch script. The ``batch_template``
    attribute can be set to a string that is a template for the batch script.
    This template is instantiated using string formatting. Thus the template can
    use {n} for the number of instances. Subclasses can add additional variables
    to the template dict.
    """

    # Subclasses must fill these in. See PBSEngineSet
    submit_command = List([''], config=True,
        help="The name of the command line program used to submit jobs.")
    delete_command = List([''], config=True,
        help="The name of the command line program used to delete jobs.")
    job_id_regexp = CRegExp('', config=True,
        help="""A regular expression used to get the job id from the output of the
        submit_command.""")
    batch_template = Unicode('', config=True,
        help="The string that is the batch script template itself.")
    batch_template_file = Unicode(u'', config=True,
        help="The file that contains the batch template.")
    batch_file_name = Unicode(u'batch_script', config=True,
        help="The filename of the instantiated batch script.")
    queue = Unicode(u'', config=True,
        help="The PBS Queue.")

    def _queue_changed(self, name, old, new):
        # keep the template context in sync with the trait
        self.context[name] = new

    # number of instances; kept in the context via the shared change handler
    n = Integer(1)
    _n_changed = _queue_changed

    # not configurable, override in subclasses
    # PBS Job Array regex
    job_array_regexp = CRegExp('')
    job_array_template = Unicode('')
    # PBS Queue regex
    queue_regexp = CRegExp('')
    queue_template = Unicode('')
    # The default batch template, override in subclasses
    default_template = Unicode('')
    # The full path to the instantiated batch script.
    batch_file = Unicode(u'')
    # the format dict used with batch_template:
    context = Dict()
    def _context_default(self):
        """load the default context with the default values for the basic keys

        because the _trait_changed methods only load the context if they
        are set to something other than the default value.
        """
        return dict(n=1, queue=u'', profile_dir=u'', cluster_id=u'')

    # the Formatter instance for rendering the templates:
    formatter = Instance(EvalFormatter, (), {})


    def find_args(self):
        # the submit command plus the rendered script path
        return self.submit_command + [self.batch_file]

    def __init__(self, work_dir=u'.', config=None, **kwargs):
        super(BatchSystemLauncher, self).__init__(
            work_dir=work_dir, config=config, **kwargs
        )
        self.batch_file = os.path.join(self.work_dir, self.batch_file_name)

    def parse_job_id(self, output):
        """Take the output of the submit command and return the job id."""
        m = self.job_id_regexp.search(output)
        if m is not None:
            job_id = m.group()
        else:
            raise LauncherError("Job id couldn't be determined: %s" % output)
        self.job_id = job_id
        self.log.info('Job submitted with job id: %r', job_id)
        return job_id

    def write_batch_script(self, n):
        """Instantiate and write the batch script to the work_dir.

        Template selection priority: batch_template, then
        batch_template_file, then default_template.  Job-array and queue
        directives are injected after the template's first line when the
        template doesn't already contain them.
        """
        self.n = n
        # first priority is batch_template if set
        if self.batch_template_file and not self.batch_template:
            # second priority is batch_template_file
            with open(self.batch_template_file) as f:
                self.batch_template = f.read()
        if not self.batch_template:
            # third (last) priority is default_template
            self.batch_template = self.default_template

        # add jobarray or queue lines to user-specified template
        # note that this is *only* when user did not specify a template.
        # print self.job_array_regexp.search(self.batch_template)
        if not self.job_array_regexp.search(self.batch_template):
            self.log.debug("adding job array settings to batch script")
            # insert after the shebang/first line
            firstline, rest = self.batch_template.split('\n',1)
            self.batch_template = u'\n'.join([firstline, self.job_array_template, rest])

        # print self.queue_regexp.search(self.batch_template)
        if self.queue and not self.queue_regexp.search(self.batch_template):
            self.log.debug("adding PBS queue settings to batch script")
            firstline, rest = self.batch_template.split('\n',1)
            self.batch_template = u'\n'.join([firstline, self.queue_template, rest])

        script_as_string = self.formatter.format(self.batch_template, **self.context)
        self.log.debug('Writing batch script: %s', self.batch_file)

        with open(self.batch_file, 'w') as f:
            f.write(script_as_string)
        # make the script executable by the owner
        os.chmod(self.batch_file, stat.S_IRUSR | stat.S_IWUSR | stat.S_IXUSR)

    def start(self, n):
        """Start n copies of the process using a batch system."""
        self.log.debug("Starting %s: %r", self.__class__.__name__, self.args)
        # Here we save profile_dir in the context so they
        # can be used in the batch script template as {profile_dir}
        self.write_batch_script(n)
        output = check_output(self.args, env=os.environ)

        job_id = self.parse_job_id(output)
        self.notify_start(job_id)
        return job_id

    def stop(self):
        output = check_output(self.delete_command+[self.job_id], env=os.environ)
        self.notify_stop(dict(job_id=self.job_id, output=output)) # Pass the output of the kill cmd
        return output
1120
1116
1121
1117
1122 class PBSLauncher(BatchSystemLauncher):
1118 class PBSLauncher(BatchSystemLauncher):
1123 """A BatchSystemLauncher subclass for PBS."""
1119 """A BatchSystemLauncher subclass for PBS."""
1124
1120
1125 submit_command = List(['qsub'], config=True,
1121 submit_command = List(['qsub'], config=True,
1126 help="The PBS submit command ['qsub']")
1122 help="The PBS submit command ['qsub']")
1127 delete_command = List(['qdel'], config=True,
1123 delete_command = List(['qdel'], config=True,
1128 help="The PBS delete command ['qsub']")
1124 help="The PBS delete command ['qsub']")
1129 job_id_regexp = CRegExp(r'\d+', config=True,
1125 job_id_regexp = CRegExp(r'\d+', config=True,
1130 help="Regular expresion for identifying the job ID [r'\d+']")
1126 help="Regular expresion for identifying the job ID [r'\d+']")
1131
1127
1132 batch_file = Unicode(u'')
1128 batch_file = Unicode(u'')
1133 job_array_regexp = CRegExp('#PBS\W+-t\W+[\w\d\-\$]+')
1129 job_array_regexp = CRegExp('#PBS\W+-t\W+[\w\d\-\$]+')
1134 job_array_template = Unicode('#PBS -t 1-{n}')
1130 job_array_template = Unicode('#PBS -t 1-{n}')
1135 queue_regexp = CRegExp('#PBS\W+-q\W+\$?\w+')
1131 queue_regexp = CRegExp('#PBS\W+-q\W+\$?\w+')
1136 queue_template = Unicode('#PBS -q {queue}')
1132 queue_template = Unicode('#PBS -q {queue}')
1137
1133
1138
1134
1139 class PBSControllerLauncher(PBSLauncher, BatchClusterAppMixin):
1135 class PBSControllerLauncher(PBSLauncher, BatchClusterAppMixin):
1140 """Launch a controller using PBS."""
1136 """Launch a controller using PBS."""
1141
1137
1142 batch_file_name = Unicode(u'pbs_controller', config=True,
1138 batch_file_name = Unicode(u'pbs_controller', config=True,
1143 help="batch file name for the controller job.")
1139 help="batch file name for the controller job.")
1144 default_template= Unicode("""#!/bin/sh
1140 default_template= Unicode("""#!/bin/sh
1145 #PBS -V
1141 #PBS -V
1146 #PBS -N ipcontroller
1142 #PBS -N ipcontroller
1147 %s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1143 %s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1148 """%(' '.join(ipcontroller_cmd_argv)))
1144 """%(' '.join(ipcontroller_cmd_argv)))
1149
1145
1150
1146
1151 def start(self):
1147 def start(self):
1152 """Start the controller by profile or profile_dir."""
1148 """Start the controller by profile or profile_dir."""
1153 return super(PBSControllerLauncher, self).start(1)
1149 return super(PBSControllerLauncher, self).start(1)
1154
1150
1155
1151
1156 class PBSEngineSetLauncher(PBSLauncher, BatchClusterAppMixin):
1152 class PBSEngineSetLauncher(PBSLauncher, BatchClusterAppMixin):
1157 """Launch Engines using PBS"""
1153 """Launch Engines using PBS"""
1158 batch_file_name = Unicode(u'pbs_engines', config=True,
1154 batch_file_name = Unicode(u'pbs_engines', config=True,
1159 help="batch file name for the engine(s) job.")
1155 help="batch file name for the engine(s) job.")
1160 default_template= Unicode(u"""#!/bin/sh
1156 default_template= Unicode(u"""#!/bin/sh
1161 #PBS -V
1157 #PBS -V
1162 #PBS -N ipengine
1158 #PBS -N ipengine
1163 %s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1159 %s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1164 """%(' '.join(ipengine_cmd_argv)))
1160 """%(' '.join(ipengine_cmd_argv)))
1165
1161
1166 def start(self, n):
1162 def start(self, n):
1167 """Start n engines by profile or profile_dir."""
1163 """Start n engines by profile or profile_dir."""
1168 return super(PBSEngineSetLauncher, self).start(n)
1164 return super(PBSEngineSetLauncher, self).start(n)
1169
1165
1170 #SGE is very similar to PBS
1166 #SGE is very similar to PBS
1171
1167
1172 class SGELauncher(PBSLauncher):
1168 class SGELauncher(PBSLauncher):
1173 """Sun GridEngine is a PBS clone with slightly different syntax"""
1169 """Sun GridEngine is a PBS clone with slightly different syntax"""
1174 job_array_regexp = CRegExp('#\$\W+\-t')
1170 job_array_regexp = CRegExp('#\$\W+\-t')
1175 job_array_template = Unicode('#$ -t 1-{n}')
1171 job_array_template = Unicode('#$ -t 1-{n}')
1176 queue_regexp = CRegExp('#\$\W+-q\W+\$?\w+')
1172 queue_regexp = CRegExp('#\$\W+-q\W+\$?\w+')
1177 queue_template = Unicode('#$ -q {queue}')
1173 queue_template = Unicode('#$ -q {queue}')
1178
1174
1179 class SGEControllerLauncher(SGELauncher, BatchClusterAppMixin):
1175 class SGEControllerLauncher(SGELauncher, BatchClusterAppMixin):
1180 """Launch a controller using SGE."""
1176 """Launch a controller using SGE."""
1181
1177
1182 batch_file_name = Unicode(u'sge_controller', config=True,
1178 batch_file_name = Unicode(u'sge_controller', config=True,
1183 help="batch file name for the ipontroller job.")
1179 help="batch file name for the ipontroller job.")
1184 default_template= Unicode(u"""#$ -V
1180 default_template= Unicode(u"""#$ -V
1185 #$ -S /bin/sh
1181 #$ -S /bin/sh
1186 #$ -N ipcontroller
1182 #$ -N ipcontroller
1187 %s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1183 %s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1188 """%(' '.join(ipcontroller_cmd_argv)))
1184 """%(' '.join(ipcontroller_cmd_argv)))
1189
1185
1190 def start(self):
1186 def start(self):
1191 """Start the controller by profile or profile_dir."""
1187 """Start the controller by profile or profile_dir."""
1192 return super(SGEControllerLauncher, self).start(1)
1188 return super(SGEControllerLauncher, self).start(1)
1193
1189
1194 class SGEEngineSetLauncher(SGELauncher, BatchClusterAppMixin):
1190 class SGEEngineSetLauncher(SGELauncher, BatchClusterAppMixin):
1195 """Launch Engines with SGE"""
1191 """Launch Engines with SGE"""
1196 batch_file_name = Unicode(u'sge_engines', config=True,
1192 batch_file_name = Unicode(u'sge_engines', config=True,
1197 help="batch file name for the engine(s) job.")
1193 help="batch file name for the engine(s) job.")
1198 default_template = Unicode("""#$ -V
1194 default_template = Unicode("""#$ -V
1199 #$ -S /bin/sh
1195 #$ -S /bin/sh
1200 #$ -N ipengine
1196 #$ -N ipengine
1201 %s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1197 %s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1202 """%(' '.join(ipengine_cmd_argv)))
1198 """%(' '.join(ipengine_cmd_argv)))
1203
1199
1204 def start(self, n):
1200 def start(self, n):
1205 """Start n engines by profile or profile_dir."""
1201 """Start n engines by profile or profile_dir."""
1206 return super(SGEEngineSetLauncher, self).start(n)
1202 return super(SGEEngineSetLauncher, self).start(n)
1207
1203
1208
1204
1209 # LSF launchers
1205 # LSF launchers
1210
1206
1211 class LSFLauncher(BatchSystemLauncher):
1207 class LSFLauncher(BatchSystemLauncher):
1212 """A BatchSystemLauncher subclass for LSF."""
1208 """A BatchSystemLauncher subclass for LSF."""
1213
1209
1214 submit_command = List(['bsub'], config=True,
1210 submit_command = List(['bsub'], config=True,
1215 help="The PBS submit command ['bsub']")
1211 help="The PBS submit command ['bsub']")
1216 delete_command = List(['bkill'], config=True,
1212 delete_command = List(['bkill'], config=True,
1217 help="The PBS delete command ['bkill']")
1213 help="The PBS delete command ['bkill']")
1218 job_id_regexp = CRegExp(r'\d+', config=True,
1214 job_id_regexp = CRegExp(r'\d+', config=True,
1219 help="Regular expresion for identifying the job ID [r'\d+']")
1215 help="Regular expresion for identifying the job ID [r'\d+']")
1220
1216
1221 batch_file = Unicode(u'')
1217 batch_file = Unicode(u'')
1222 job_array_regexp = CRegExp('#BSUB[ \t]-J+\w+\[\d+-\d+\]')
1218 job_array_regexp = CRegExp('#BSUB[ \t]-J+\w+\[\d+-\d+\]')
1223 job_array_template = Unicode('#BSUB -J ipengine[1-{n}]')
1219 job_array_template = Unicode('#BSUB -J ipengine[1-{n}]')
1224 queue_regexp = CRegExp('#BSUB[ \t]+-q[ \t]+\w+')
1220 queue_regexp = CRegExp('#BSUB[ \t]+-q[ \t]+\w+')
1225 queue_template = Unicode('#BSUB -q {queue}')
1221 queue_template = Unicode('#BSUB -q {queue}')
1226
1222
1227 def start(self, n):
1223 def start(self, n):
1228 """Start n copies of the process using LSF batch system.
1224 """Start n copies of the process using LSF batch system.
1229 This cant inherit from the base class because bsub expects
1225 This cant inherit from the base class because bsub expects
1230 to be piped a shell script in order to honor the #BSUB directives :
1226 to be piped a shell script in order to honor the #BSUB directives :
1231 bsub < script
1227 bsub < script
1232 """
1228 """
1233 # Here we save profile_dir in the context so they
1229 # Here we save profile_dir in the context so they
1234 # can be used in the batch script template as {profile_dir}
1230 # can be used in the batch script template as {profile_dir}
1235 self.write_batch_script(n)
1231 self.write_batch_script(n)
1236 #output = check_output(self.args, env=os.environ)
1232 #output = check_output(self.args, env=os.environ)
1237 piped_cmd = self.args[0]+'<\"'+self.args[1]+'\"'
1233 piped_cmd = self.args[0]+'<\"'+self.args[1]+'\"'
1238 self.log.debug("Starting %s: %s", self.__class__.__name__, piped_cmd)
1234 self.log.debug("Starting %s: %s", self.__class__.__name__, piped_cmd)
1239 p = Popen(piped_cmd, shell=True,env=os.environ,stdout=PIPE)
1235 p = Popen(piped_cmd, shell=True,env=os.environ,stdout=PIPE)
1240 output,err = p.communicate()
1236 output,err = p.communicate()
1241 job_id = self.parse_job_id(output)
1237 job_id = self.parse_job_id(output)
1242 self.notify_start(job_id)
1238 self.notify_start(job_id)
1243 return job_id
1239 return job_id
1244
1240
1245
1241
1246 class LSFControllerLauncher(LSFLauncher, BatchClusterAppMixin):
1242 class LSFControllerLauncher(LSFLauncher, BatchClusterAppMixin):
1247 """Launch a controller using LSF."""
1243 """Launch a controller using LSF."""
1248
1244
1249 batch_file_name = Unicode(u'lsf_controller', config=True,
1245 batch_file_name = Unicode(u'lsf_controller', config=True,
1250 help="batch file name for the controller job.")
1246 help="batch file name for the controller job.")
1251 default_template= Unicode("""#!/bin/sh
1247 default_template= Unicode("""#!/bin/sh
1252 #BSUB -J ipcontroller
1248 #BSUB -J ipcontroller
1253 #BSUB -oo ipcontroller.o.%%J
1249 #BSUB -oo ipcontroller.o.%%J
1254 #BSUB -eo ipcontroller.e.%%J
1250 #BSUB -eo ipcontroller.e.%%J
1255 %s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1251 %s --log-to-file --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1256 """%(' '.join(ipcontroller_cmd_argv)))
1252 """%(' '.join(ipcontroller_cmd_argv)))
1257
1253
1258 def start(self):
1254 def start(self):
1259 """Start the controller by profile or profile_dir."""
1255 """Start the controller by profile or profile_dir."""
1260 return super(LSFControllerLauncher, self).start(1)
1256 return super(LSFControllerLauncher, self).start(1)
1261
1257
1262
1258
1263 class LSFEngineSetLauncher(LSFLauncher, BatchClusterAppMixin):
1259 class LSFEngineSetLauncher(LSFLauncher, BatchClusterAppMixin):
1264 """Launch Engines using LSF"""
1260 """Launch Engines using LSF"""
1265 batch_file_name = Unicode(u'lsf_engines', config=True,
1261 batch_file_name = Unicode(u'lsf_engines', config=True,
1266 help="batch file name for the engine(s) job.")
1262 help="batch file name for the engine(s) job.")
1267 default_template= Unicode(u"""#!/bin/sh
1263 default_template= Unicode(u"""#!/bin/sh
1268 #BSUB -oo ipengine.o.%%J
1264 #BSUB -oo ipengine.o.%%J
1269 #BSUB -eo ipengine.e.%%J
1265 #BSUB -eo ipengine.e.%%J
1270 %s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1266 %s --profile-dir="{profile_dir}" --cluster-id="{cluster_id}"
1271 """%(' '.join(ipengine_cmd_argv)))
1267 """%(' '.join(ipengine_cmd_argv)))
1272
1268
1273 def start(self, n):
1269 def start(self, n):
1274 """Start n engines by profile or profile_dir."""
1270 """Start n engines by profile or profile_dir."""
1275 return super(LSFEngineSetLauncher, self).start(n)
1271 return super(LSFEngineSetLauncher, self).start(n)
1276
1272
1277
1273
1278 #-----------------------------------------------------------------------------
1274 #-----------------------------------------------------------------------------
1279 # A launcher for ipcluster itself!
1275 # A launcher for ipcluster itself!
1280 #-----------------------------------------------------------------------------
1276 #-----------------------------------------------------------------------------
1281
1277
1282
1278
1283 class IPClusterLauncher(LocalProcessLauncher):
1279 class IPClusterLauncher(LocalProcessLauncher):
1284 """Launch the ipcluster program in an external process."""
1280 """Launch the ipcluster program in an external process."""
1285
1281
1286 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1282 ipcluster_cmd = List(ipcluster_cmd_argv, config=True,
1287 help="Popen command for ipcluster")
1283 help="Popen command for ipcluster")
1288 ipcluster_args = List(
1284 ipcluster_args = List(
1289 ['--clean-logs=True', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1285 ['--clean-logs=True', '--log-to-file', '--log-level=%i'%logging.INFO], config=True,
1290 help="Command line arguments to pass to ipcluster.")
1286 help="Command line arguments to pass to ipcluster.")
1291 ipcluster_subcommand = Unicode('start')
1287 ipcluster_subcommand = Unicode('start')
1292 profile = Unicode('default')
1288 profile = Unicode('default')
1293 n = Integer(2)
1289 n = Integer(2)
1294
1290
1295 def find_args(self):
1291 def find_args(self):
1296 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1292 return self.ipcluster_cmd + [self.ipcluster_subcommand] + \
1297 ['--n=%i'%self.n, '--profile=%s'%self.profile] + \
1293 ['--n=%i'%self.n, '--profile=%s'%self.profile] + \
1298 self.ipcluster_args
1294 self.ipcluster_args
1299
1295
1300 def start(self):
1296 def start(self):
1301 return super(IPClusterLauncher, self).start()
1297 return super(IPClusterLauncher, self).start()
1302
1298
1303 #-----------------------------------------------------------------------------
1299 #-----------------------------------------------------------------------------
1304 # Collections of launchers
1300 # Collections of launchers
1305 #-----------------------------------------------------------------------------
1301 #-----------------------------------------------------------------------------
1306
1302
1307 local_launchers = [
1303 local_launchers = [
1308 LocalControllerLauncher,
1304 LocalControllerLauncher,
1309 LocalEngineLauncher,
1305 LocalEngineLauncher,
1310 LocalEngineSetLauncher,
1306 LocalEngineSetLauncher,
1311 ]
1307 ]
1312 mpi_launchers = [
1308 mpi_launchers = [
1313 MPILauncher,
1309 MPILauncher,
1314 MPIControllerLauncher,
1310 MPIControllerLauncher,
1315 MPIEngineSetLauncher,
1311 MPIEngineSetLauncher,
1316 ]
1312 ]
1317 ssh_launchers = [
1313 ssh_launchers = [
1318 SSHLauncher,
1314 SSHLauncher,
1319 SSHControllerLauncher,
1315 SSHControllerLauncher,
1320 SSHEngineLauncher,
1316 SSHEngineLauncher,
1321 SSHEngineSetLauncher,
1317 SSHEngineSetLauncher,
1322 ]
1318 ]
1323 winhpc_launchers = [
1319 winhpc_launchers = [
1324 WindowsHPCLauncher,
1320 WindowsHPCLauncher,
1325 WindowsHPCControllerLauncher,
1321 WindowsHPCControllerLauncher,
1326 WindowsHPCEngineSetLauncher,
1322 WindowsHPCEngineSetLauncher,
1327 ]
1323 ]
1328 pbs_launchers = [
1324 pbs_launchers = [
1329 PBSLauncher,
1325 PBSLauncher,
1330 PBSControllerLauncher,
1326 PBSControllerLauncher,
1331 PBSEngineSetLauncher,
1327 PBSEngineSetLauncher,
1332 ]
1328 ]
1333 sge_launchers = [
1329 sge_launchers = [
1334 SGELauncher,
1330 SGELauncher,
1335 SGEControllerLauncher,
1331 SGEControllerLauncher,
1336 SGEEngineSetLauncher,
1332 SGEEngineSetLauncher,
1337 ]
1333 ]
1338 lsf_launchers = [
1334 lsf_launchers = [
1339 LSFLauncher,
1335 LSFLauncher,
1340 LSFControllerLauncher,
1336 LSFControllerLauncher,
1341 LSFEngineSetLauncher,
1337 LSFEngineSetLauncher,
1342 ]
1338 ]
1343 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1339 all_launchers = local_launchers + mpi_launchers + ssh_launchers + winhpc_launchers\
1344 + pbs_launchers + sge_launchers + lsf_launchers
1340 + pbs_launchers + sge_launchers + lsf_launchers
1345
1341
@@ -1,1505 +1,1628 b''
1 """A semi-synchronous Client for the ZMQ cluster
1 """A semi-synchronous Client for the ZMQ cluster
2
2
3 Authors:
3 Authors:
4
4
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import os
18 import os
19 import json
19 import json
20 import sys
20 import sys
21 from threading import Thread, Event
21 from threading import Thread, Event
22 import time
22 import time
23 import warnings
23 import warnings
24 from datetime import datetime
24 from datetime import datetime
25 from getpass import getpass
25 from getpass import getpass
26 from pprint import pprint
26 from pprint import pprint
27
27
28 pjoin = os.path.join
28 pjoin = os.path.join
29
29
30 import zmq
30 import zmq
31 # from zmq.eventloop import ioloop, zmqstream
31 # from zmq.eventloop import ioloop, zmqstream
32
32
33 from IPython.config.configurable import MultipleInstanceError
33 from IPython.config.configurable import MultipleInstanceError
34 from IPython.core.application import BaseIPythonApplication
34 from IPython.core.application import BaseIPythonApplication
35
35
36 from IPython.utils.jsonutil import rekey
36 from IPython.utils.jsonutil import rekey
37 from IPython.utils.localinterfaces import LOCAL_IPS
37 from IPython.utils.localinterfaces import LOCAL_IPS
38 from IPython.utils.path import get_ipython_dir
38 from IPython.utils.path import get_ipython_dir
39 from IPython.utils.py3compat import cast_bytes
39 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
40 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
40 Dict, List, Bool, Set, Any)
41 Dict, List, Bool, Set, Any)
41 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
42 from IPython.external.ssh import tunnel
43 from IPython.external.ssh import tunnel
43
44
44 from IPython.parallel import Reference
45 from IPython.parallel import Reference
45 from IPython.parallel import error
46 from IPython.parallel import error
46 from IPython.parallel import util
47 from IPython.parallel import util
47
48
48 from IPython.zmq.session import Session, Message
49 from IPython.zmq.session import Session, Message
49
50
50 from .asyncresult import AsyncResult, AsyncHubResult
51 from .asyncresult import AsyncResult, AsyncHubResult
51 from IPython.core.profiledir import ProfileDir, ProfileDirError
52 from IPython.core.profiledir import ProfileDir, ProfileDirError
52 from .view import DirectView, LoadBalancedView
53 from .view import DirectView, LoadBalancedView
53
54
54 if sys.version_info[0] >= 3:
55 if sys.version_info[0] >= 3:
55 # xrange is used in a couple 'isinstance' tests in py2
56 # xrange is used in a couple 'isinstance' tests in py2
56 # should be just 'range' in 3k
57 # should be just 'range' in 3k
57 xrange = range
58 xrange = range
58
59
59 #--------------------------------------------------------------------------
60 #--------------------------------------------------------------------------
60 # Decorators for Client methods
61 # Decorators for Client methods
61 #--------------------------------------------------------------------------
62 #--------------------------------------------------------------------------
62
63
63 @decorator
64 @decorator
64 def spin_first(f, self, *args, **kwargs):
65 def spin_first(f, self, *args, **kwargs):
65 """Call spin() to sync state prior to calling the method."""
66 """Call spin() to sync state prior to calling the method."""
66 self.spin()
67 self.spin()
67 return f(self, *args, **kwargs)
68 return f(self, *args, **kwargs)
68
69
69
70
70 #--------------------------------------------------------------------------
71 #--------------------------------------------------------------------------
71 # Classes
72 # Classes
72 #--------------------------------------------------------------------------
73 #--------------------------------------------------------------------------
73
74
75
76 class ExecuteReply(object):
77 """wrapper for finished Execute results"""
78 def __init__(self, msg_id, content, metadata):
79 self.msg_id = msg_id
80 self._content = content
81 self.execution_count = content['execution_count']
82 self.metadata = metadata
83
84 def __getitem__(self, key):
85 return self.metadata[key]
86
87 def __getattr__(self, key):
88 if key not in self.metadata:
89 raise AttributeError(key)
90 return self.metadata[key]
91
92 def __repr__(self):
93 pyout = self.metadata['pyout'] or {}
94 text_out = pyout.get('data', {}).get('text/plain', '')
95 if len(text_out) > 32:
96 text_out = text_out[:29] + '...'
97
98 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
99
100 def _repr_html_(self):
101 pyout = self.metadata['pyout'] or {'data':{}}
102 return pyout['data'].get("text/html")
103
104 def _repr_latex_(self):
105 pyout = self.metadata['pyout'] or {'data':{}}
106 return pyout['data'].get("text/latex")
107
108 def _repr_json_(self):
109 pyout = self.metadata['pyout'] or {'data':{}}
110 return pyout['data'].get("application/json")
111
112 def _repr_javascript_(self):
113 pyout = self.metadata['pyout'] or {'data':{}}
114 return pyout['data'].get("application/javascript")
115
116 def _repr_png_(self):
117 pyout = self.metadata['pyout'] or {'data':{}}
118 return pyout['data'].get("image/png")
119
120 def _repr_jpeg_(self):
121 pyout = self.metadata['pyout'] or {'data':{}}
122 return pyout['data'].get("image/jpeg")
123
124 def _repr_svg_(self):
125 pyout = self.metadata['pyout'] or {'data':{}}
126 return pyout['data'].get("image/svg+xml")
127
128
74 class Metadata(dict):
129 class Metadata(dict):
75 """Subclass of dict for initializing metadata values.
130 """Subclass of dict for initializing metadata values.
76
131
77 Attribute access works on keys.
132 Attribute access works on keys.
78
133
79 These objects have a strict set of keys - errors will raise if you try
134 These objects have a strict set of keys - errors will raise if you try
80 to add new keys.
135 to add new keys.
81 """
136 """
82 def __init__(self, *args, **kwargs):
137 def __init__(self, *args, **kwargs):
83 dict.__init__(self)
138 dict.__init__(self)
84 md = {'msg_id' : None,
139 md = {'msg_id' : None,
85 'submitted' : None,
140 'submitted' : None,
86 'started' : None,
141 'started' : None,
87 'completed' : None,
142 'completed' : None,
88 'received' : None,
143 'received' : None,
89 'engine_uuid' : None,
144 'engine_uuid' : None,
90 'engine_id' : None,
145 'engine_id' : None,
91 'follow' : None,
146 'follow' : None,
92 'after' : None,
147 'after' : None,
93 'status' : None,
148 'status' : None,
94
149
95 'pyin' : None,
150 'pyin' : None,
96 'pyout' : None,
151 'pyout' : None,
97 'pyerr' : None,
152 'pyerr' : None,
98 'stdout' : '',
153 'stdout' : '',
99 'stderr' : '',
154 'stderr' : '',
155 'outputs' : [],
100 }
156 }
101 self.update(md)
157 self.update(md)
102 self.update(dict(*args, **kwargs))
158 self.update(dict(*args, **kwargs))
103
159
104 def __getattr__(self, key):
160 def __getattr__(self, key):
105 """getattr aliased to getitem"""
161 """getattr aliased to getitem"""
106 if key in self.iterkeys():
162 if key in self.iterkeys():
107 return self[key]
163 return self[key]
108 else:
164 else:
109 raise AttributeError(key)
165 raise AttributeError(key)
110
166
111 def __setattr__(self, key, value):
167 def __setattr__(self, key, value):
112 """setattr aliased to setitem, with strict"""
168 """setattr aliased to setitem, with strict"""
113 if key in self.iterkeys():
169 if key in self.iterkeys():
114 self[key] = value
170 self[key] = value
115 else:
171 else:
116 raise AttributeError(key)
172 raise AttributeError(key)
117
173
118 def __setitem__(self, key, value):
174 def __setitem__(self, key, value):
119 """strict static key enforcement"""
175 """strict static key enforcement"""
120 if key in self.iterkeys():
176 if key in self.iterkeys():
121 dict.__setitem__(self, key, value)
177 dict.__setitem__(self, key, value)
122 else:
178 else:
123 raise KeyError(key)
179 raise KeyError(key)
124
180
125
181
126 class Client(HasTraits):
182 class Client(HasTraits):
127 """A semi-synchronous client to the IPython ZMQ cluster
183 """A semi-synchronous client to the IPython ZMQ cluster
128
184
129 Parameters
185 Parameters
130 ----------
186 ----------
131
187
132 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
188 url_or_file : bytes or unicode; zmq url or path to ipcontroller-client.json
133 Connection information for the Hub's registration. If a json connector
189 Connection information for the Hub's registration. If a json connector
134 file is given, then likely no further configuration is necessary.
190 file is given, then likely no further configuration is necessary.
135 [Default: use profile]
191 [Default: use profile]
136 profile : bytes
192 profile : bytes
137 The name of the Cluster profile to be used to find connector information.
193 The name of the Cluster profile to be used to find connector information.
138 If run from an IPython application, the default profile will be the same
194 If run from an IPython application, the default profile will be the same
139 as the running application, otherwise it will be 'default'.
195 as the running application, otherwise it will be 'default'.
140 context : zmq.Context
196 context : zmq.Context
141 Pass an existing zmq.Context instance, otherwise the client will create its own.
197 Pass an existing zmq.Context instance, otherwise the client will create its own.
142 debug : bool
198 debug : bool
143 flag for lots of message printing for debug purposes
199 flag for lots of message printing for debug purposes
144 timeout : int/float
200 timeout : int/float
145 time (in seconds) to wait for connection replies from the Hub
201 time (in seconds) to wait for connection replies from the Hub
146 [Default: 10]
202 [Default: 10]
147
203
148 #-------------- session related args ----------------
204 #-------------- session related args ----------------
149
205
150 config : Config object
206 config : Config object
151 If specified, this will be relayed to the Session for configuration
207 If specified, this will be relayed to the Session for configuration
152 username : str
208 username : str
153 set username for the session object
209 set username for the session object
154 packer : str (import_string) or callable
210 packer : str (import_string) or callable
155 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
211 Can be either the simple keyword 'json' or 'pickle', or an import_string to a
156 function to serialize messages. Must support same input as
212 function to serialize messages. Must support same input as
157 JSON, and output must be bytes.
213 JSON, and output must be bytes.
158 You can pass a callable directly as `pack`
214 You can pass a callable directly as `pack`
159 unpacker : str (import_string) or callable
215 unpacker : str (import_string) or callable
160 The inverse of packer. Only necessary if packer is specified as *not* one
216 The inverse of packer. Only necessary if packer is specified as *not* one
161 of 'json' or 'pickle'.
217 of 'json' or 'pickle'.
162
218
163 #-------------- ssh related args ----------------
219 #-------------- ssh related args ----------------
164 # These are args for configuring the ssh tunnel to be used
220 # These are args for configuring the ssh tunnel to be used
165 # credentials are used to forward connections over ssh to the Controller
221 # credentials are used to forward connections over ssh to the Controller
166 # Note that the ip given in `addr` needs to be relative to sshserver
222 # Note that the ip given in `addr` needs to be relative to sshserver
167 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
223 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
168 # and set sshserver as the same machine the Controller is on. However,
224 # and set sshserver as the same machine the Controller is on. However,
169 # the only requirement is that sshserver is able to see the Controller
225 # the only requirement is that sshserver is able to see the Controller
170 # (i.e. is within the same trusted network).
226 # (i.e. is within the same trusted network).
171
227
172 sshserver : str
228 sshserver : str
173 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
229 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
174 If keyfile or password is specified, and this is not, it will default to
230 If keyfile or password is specified, and this is not, it will default to
175 the ip given in addr.
231 the ip given in addr.
176 sshkey : str; path to ssh private key file
232 sshkey : str; path to ssh private key file
177 This specifies a key to be used in ssh login, default None.
233 This specifies a key to be used in ssh login, default None.
178 Regular default ssh keys will be used without specifying this argument.
234 Regular default ssh keys will be used without specifying this argument.
179 password : str
235 password : str
180 Your ssh password to sshserver. Note that if this is left None,
236 Your ssh password to sshserver. Note that if this is left None,
181 you will be prompted for it if passwordless key based login is unavailable.
237 you will be prompted for it if passwordless key based login is unavailable.
182 paramiko : bool
238 paramiko : bool
183 flag for whether to use paramiko instead of shell ssh for tunneling.
239 flag for whether to use paramiko instead of shell ssh for tunneling.
184 [default: True on win32, False else]
240 [default: True on win32, False else]
185
241
186 ------- exec authentication args -------
242 ------- exec authentication args -------
187 If even localhost is untrusted, you can have some protection against
243 If even localhost is untrusted, you can have some protection against
188 unauthorized execution by signing messages with HMAC digests.
244 unauthorized execution by signing messages with HMAC digests.
189 Messages are still sent as cleartext, so if someone can snoop your
245 Messages are still sent as cleartext, so if someone can snoop your
190 loopback traffic this will not protect your privacy, but will prevent
246 loopback traffic this will not protect your privacy, but will prevent
191 unauthorized execution.
247 unauthorized execution.
192
248
193 exec_key : str
249 exec_key : str
194 an authentication key or file containing a key
250 an authentication key or file containing a key
195 default: None
251 default: None
196
252
197
253
198 Attributes
254 Attributes
199 ----------
255 ----------
200
256
201 ids : list of int engine IDs
257 ids : list of int engine IDs
202 requesting the ids attribute always synchronizes
258 requesting the ids attribute always synchronizes
203 the registration state. To request ids without synchronization,
259 the registration state. To request ids without synchronization,
204 use semi-private _ids attributes.
260 use semi-private _ids attributes.
205
261
206 history : list of msg_ids
262 history : list of msg_ids
207 a list of msg_ids, keeping track of all the execution
263 a list of msg_ids, keeping track of all the execution
208 messages you have submitted in order.
264 messages you have submitted in order.
209
265
210 outstanding : set of msg_ids
266 outstanding : set of msg_ids
211 a set of msg_ids that have been submitted, but whose
267 a set of msg_ids that have been submitted, but whose
212 results have not yet been received.
268 results have not yet been received.
213
269
214 results : dict
270 results : dict
215 a dict of all our results, keyed by msg_id
271 a dict of all our results, keyed by msg_id
216
272
217 block : bool
273 block : bool
218 determines default behavior when block not specified
274 determines default behavior when block not specified
219 in execution methods
275 in execution methods
220
276
221 Methods
277 Methods
222 -------
278 -------
223
279
224 spin
280 spin
225 flushes incoming results and registration state changes
281 flushes incoming results and registration state changes
226 control methods spin, and requesting `ids` also ensures up to date
282 control methods spin, and requesting `ids` also ensures up to date
227
283
228 wait
284 wait
229 wait on one or more msg_ids
285 wait on one or more msg_ids
230
286
231 execution methods
287 execution methods
232 apply
288 apply
233 legacy: execute, run
289 legacy: execute, run
234
290
235 data movement
291 data movement
236 push, pull, scatter, gather
292 push, pull, scatter, gather
237
293
238 query methods
294 query methods
239 queue_status, get_result, purge, result_status
295 queue_status, get_result, purge, result_status
240
296
241 control methods
297 control methods
242 abort, shutdown
298 abort, shutdown
243
299
244 """
300 """
245
301
246
302
247 block = Bool(False)
303 block = Bool(False)
248 outstanding = Set()
304 outstanding = Set()
249 results = Instance('collections.defaultdict', (dict,))
305 results = Instance('collections.defaultdict', (dict,))
250 metadata = Instance('collections.defaultdict', (Metadata,))
306 metadata = Instance('collections.defaultdict', (Metadata,))
251 history = List()
307 history = List()
252 debug = Bool(False)
308 debug = Bool(False)
253 _spin_thread = Any()
309 _spin_thread = Any()
254 _stop_spinning = Any()
310 _stop_spinning = Any()
255
311
256 profile=Unicode()
312 profile=Unicode()
257 def _profile_default(self):
313 def _profile_default(self):
258 if BaseIPythonApplication.initialized():
314 if BaseIPythonApplication.initialized():
259 # an IPython app *might* be running, try to get its profile
315 # an IPython app *might* be running, try to get its profile
260 try:
316 try:
261 return BaseIPythonApplication.instance().profile
317 return BaseIPythonApplication.instance().profile
262 except (AttributeError, MultipleInstanceError):
318 except (AttributeError, MultipleInstanceError):
263 # could be a *different* subclass of config.Application,
319 # could be a *different* subclass of config.Application,
264 # which would raise one of these two errors.
320 # which would raise one of these two errors.
265 return u'default'
321 return u'default'
266 else:
322 else:
267 return u'default'
323 return u'default'
268
324
269
325
270 _outstanding_dict = Instance('collections.defaultdict', (set,))
326 _outstanding_dict = Instance('collections.defaultdict', (set,))
271 _ids = List()
327 _ids = List()
272 _connected=Bool(False)
328 _connected=Bool(False)
273 _ssh=Bool(False)
329 _ssh=Bool(False)
274 _context = Instance('zmq.Context')
330 _context = Instance('zmq.Context')
275 _config = Dict()
331 _config = Dict()
276 _engines=Instance(util.ReverseDict, (), {})
332 _engines=Instance(util.ReverseDict, (), {})
277 # _hub_socket=Instance('zmq.Socket')
333 # _hub_socket=Instance('zmq.Socket')
278 _query_socket=Instance('zmq.Socket')
334 _query_socket=Instance('zmq.Socket')
279 _control_socket=Instance('zmq.Socket')
335 _control_socket=Instance('zmq.Socket')
280 _iopub_socket=Instance('zmq.Socket')
336 _iopub_socket=Instance('zmq.Socket')
281 _notification_socket=Instance('zmq.Socket')
337 _notification_socket=Instance('zmq.Socket')
282 _mux_socket=Instance('zmq.Socket')
338 _mux_socket=Instance('zmq.Socket')
283 _task_socket=Instance('zmq.Socket')
339 _task_socket=Instance('zmq.Socket')
284 _task_scheme=Unicode()
340 _task_scheme=Unicode()
285 _closed = False
341 _closed = False
286 _ignored_control_replies=Integer(0)
342 _ignored_control_replies=Integer(0)
287 _ignored_hub_replies=Integer(0)
343 _ignored_hub_replies=Integer(0)
288
344
289 def __new__(self, *args, **kw):
345 def __new__(self, *args, **kw):
290 # don't raise on positional args
346 # don't raise on positional args
291 return HasTraits.__new__(self, **kw)
347 return HasTraits.__new__(self, **kw)
292
348
293 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
349 def __init__(self, url_or_file=None, profile=None, profile_dir=None, ipython_dir=None,
294 context=None, debug=False, exec_key=None,
350 context=None, debug=False, exec_key=None,
295 sshserver=None, sshkey=None, password=None, paramiko=None,
351 sshserver=None, sshkey=None, password=None, paramiko=None,
296 timeout=10, **extra_args
352 timeout=10, **extra_args
297 ):
353 ):
298 if profile:
354 if profile:
299 super(Client, self).__init__(debug=debug, profile=profile)
355 super(Client, self).__init__(debug=debug, profile=profile)
300 else:
356 else:
301 super(Client, self).__init__(debug=debug)
357 super(Client, self).__init__(debug=debug)
302 if context is None:
358 if context is None:
303 context = zmq.Context.instance()
359 context = zmq.Context.instance()
304 self._context = context
360 self._context = context
305 self._stop_spinning = Event()
361 self._stop_spinning = Event()
306
362
307 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
363 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
308 if self._cd is not None:
364 if self._cd is not None:
309 if url_or_file is None:
365 if url_or_file is None:
310 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
366 url_or_file = pjoin(self._cd.security_dir, 'ipcontroller-client.json')
311 assert url_or_file is not None, "I can't find enough information to connect to a hub!"\
367 if url_or_file is None:
312 " Please specify at least one of url_or_file or profile."
368 raise ValueError(
369 "I can't find enough information to connect to a hub!"
370 " Please specify at least one of url_or_file or profile."
371 )
313
372
314 if not util.is_url(url_or_file):
373 if not util.is_url(url_or_file):
315 # it's not a url, try for a file
374 # it's not a url, try for a file
316 if not os.path.exists(url_or_file):
375 if not os.path.exists(url_or_file):
317 if self._cd:
376 if self._cd:
318 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
377 url_or_file = os.path.join(self._cd.security_dir, url_or_file)
319 assert os.path.exists(url_or_file), "Not a valid connection file or url: %r"%url_or_file
378 if not os.path.exists(url_or_file):
379 raise IOError("Connection file not found: %r" % url_or_file)
320 with open(url_or_file) as f:
380 with open(url_or_file) as f:
321 cfg = json.loads(f.read())
381 cfg = json.loads(f.read())
322 else:
382 else:
323 cfg = {'url':url_or_file}
383 cfg = {'url':url_or_file}
324
384
325 # sync defaults from args, json:
385 # sync defaults from args, json:
326 if sshserver:
386 if sshserver:
327 cfg['ssh'] = sshserver
387 cfg['ssh'] = sshserver
328 if exec_key:
388 if exec_key:
329 cfg['exec_key'] = exec_key
389 cfg['exec_key'] = exec_key
330 exec_key = cfg['exec_key']
390 exec_key = cfg['exec_key']
331 location = cfg.setdefault('location', None)
391 location = cfg.setdefault('location', None)
332 cfg['url'] = util.disambiguate_url(cfg['url'], location)
392 cfg['url'] = util.disambiguate_url(cfg['url'], location)
333 url = cfg['url']
393 url = cfg['url']
334 proto,addr,port = util.split_url(url)
394 proto,addr,port = util.split_url(url)
335 if location is not None and addr == '127.0.0.1':
395 if location is not None and addr == '127.0.0.1':
336 # location specified, and connection is expected to be local
396 # location specified, and connection is expected to be local
337 if location not in LOCAL_IPS and not sshserver:
397 if location not in LOCAL_IPS and not sshserver:
338 # load ssh from JSON *only* if the controller is not on
398 # load ssh from JSON *only* if the controller is not on
339 # this machine
399 # this machine
340 sshserver=cfg['ssh']
400 sshserver=cfg['ssh']
341 if location not in LOCAL_IPS and not sshserver:
401 if location not in LOCAL_IPS and not sshserver:
342 # warn if no ssh specified, but SSH is probably needed
402 # warn if no ssh specified, but SSH is probably needed
343 # This is only a warning, because the most likely cause
403 # This is only a warning, because the most likely cause
344 # is a local Controller on a laptop whose IP is dynamic
404 # is a local Controller on a laptop whose IP is dynamic
345 warnings.warn("""
405 warnings.warn("""
346 Controller appears to be listening on localhost, but not on this machine.
406 Controller appears to be listening on localhost, but not on this machine.
347 If this is true, you should specify Client(...,sshserver='you@%s')
407 If this is true, you should specify Client(...,sshserver='you@%s')
348 or instruct your controller to listen on an external IP."""%location,
408 or instruct your controller to listen on an external IP."""%location,
349 RuntimeWarning)
409 RuntimeWarning)
350 elif not sshserver:
410 elif not sshserver:
351 # otherwise sync with cfg
411 # otherwise sync with cfg
352 sshserver = cfg['ssh']
412 sshserver = cfg['ssh']
353
413
354 self._config = cfg
414 self._config = cfg
355
415
356 self._ssh = bool(sshserver or sshkey or password)
416 self._ssh = bool(sshserver or sshkey or password)
357 if self._ssh and sshserver is None:
417 if self._ssh and sshserver is None:
358 # default to ssh via localhost
418 # default to ssh via localhost
359 sshserver = url.split('://')[1].split(':')[0]
419 sshserver = url.split('://')[1].split(':')[0]
360 if self._ssh and password is None:
420 if self._ssh and password is None:
361 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
421 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
362 password=False
422 password=False
363 else:
423 else:
364 password = getpass("SSH Password for %s: "%sshserver)
424 password = getpass("SSH Password for %s: "%sshserver)
365 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
425 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
366
426
367 # configure and construct the session
427 # configure and construct the session
368 if exec_key is not None:
428 if exec_key is not None:
369 if os.path.isfile(exec_key):
429 if os.path.isfile(exec_key):
370 extra_args['keyfile'] = exec_key
430 extra_args['keyfile'] = exec_key
371 else:
431 else:
372 exec_key = util.asbytes(exec_key)
432 exec_key = cast_bytes(exec_key)
373 extra_args['key'] = exec_key
433 extra_args['key'] = exec_key
374 self.session = Session(**extra_args)
434 self.session = Session(**extra_args)
375
435
376 self._query_socket = self._context.socket(zmq.DEALER)
436 self._query_socket = self._context.socket(zmq.DEALER)
377 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
437 self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
378 if self._ssh:
438 if self._ssh:
379 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
439 tunnel.tunnel_connection(self._query_socket, url, sshserver, **ssh_kwargs)
380 else:
440 else:
381 self._query_socket.connect(url)
441 self._query_socket.connect(url)
382
442
383 self.session.debug = self.debug
443 self.session.debug = self.debug
384
444
385 self._notification_handlers = {'registration_notification' : self._register_engine,
445 self._notification_handlers = {'registration_notification' : self._register_engine,
386 'unregistration_notification' : self._unregister_engine,
446 'unregistration_notification' : self._unregister_engine,
387 'shutdown_notification' : lambda msg: self.close(),
447 'shutdown_notification' : lambda msg: self.close(),
388 }
448 }
389 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
449 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
390 'apply_reply' : self._handle_apply_reply}
450 'apply_reply' : self._handle_apply_reply}
391 self._connect(sshserver, ssh_kwargs, timeout)
451 self._connect(sshserver, ssh_kwargs, timeout)
392
452
393 def __del__(self):
453 def __del__(self):
394 """cleanup sockets, but _not_ context."""
454 """cleanup sockets, but _not_ context."""
395 self.close()
455 self.close()
396
456
397 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
457 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
398 if ipython_dir is None:
458 if ipython_dir is None:
399 ipython_dir = get_ipython_dir()
459 ipython_dir = get_ipython_dir()
400 if profile_dir is not None:
460 if profile_dir is not None:
401 try:
461 try:
402 self._cd = ProfileDir.find_profile_dir(profile_dir)
462 self._cd = ProfileDir.find_profile_dir(profile_dir)
403 return
463 return
404 except ProfileDirError:
464 except ProfileDirError:
405 pass
465 pass
406 elif profile is not None:
466 elif profile is not None:
407 try:
467 try:
408 self._cd = ProfileDir.find_profile_dir_by_name(
468 self._cd = ProfileDir.find_profile_dir_by_name(
409 ipython_dir, profile)
469 ipython_dir, profile)
410 return
470 return
411 except ProfileDirError:
471 except ProfileDirError:
412 pass
472 pass
413 self._cd = None
473 self._cd = None
414
474
415 def _update_engines(self, engines):
475 def _update_engines(self, engines):
416 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
476 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
417 for k,v in engines.iteritems():
477 for k,v in engines.iteritems():
418 eid = int(k)
478 eid = int(k)
419 self._engines[eid] = v
479 self._engines[eid] = v
420 self._ids.append(eid)
480 self._ids.append(eid)
421 self._ids = sorted(self._ids)
481 self._ids = sorted(self._ids)
422 if sorted(self._engines.keys()) != range(len(self._engines)) and \
482 if sorted(self._engines.keys()) != range(len(self._engines)) and \
423 self._task_scheme == 'pure' and self._task_socket:
483 self._task_scheme == 'pure' and self._task_socket:
424 self._stop_scheduling_tasks()
484 self._stop_scheduling_tasks()
425
485
426 def _stop_scheduling_tasks(self):
486 def _stop_scheduling_tasks(self):
427 """Stop scheduling tasks because an engine has been unregistered
487 """Stop scheduling tasks because an engine has been unregistered
428 from a pure ZMQ scheduler.
488 from a pure ZMQ scheduler.
429 """
489 """
430 self._task_socket.close()
490 self._task_socket.close()
431 self._task_socket = None
491 self._task_socket = None
432 msg = "An engine has been unregistered, and we are using pure " +\
492 msg = "An engine has been unregistered, and we are using pure " +\
433 "ZMQ task scheduling. Task farming will be disabled."
493 "ZMQ task scheduling. Task farming will be disabled."
434 if self.outstanding:
494 if self.outstanding:
435 msg += " If you were running tasks when this happened, " +\
495 msg += " If you were running tasks when this happened, " +\
436 "some `outstanding` msg_ids may never resolve."
496 "some `outstanding` msg_ids may never resolve."
437 warnings.warn(msg, RuntimeWarning)
497 warnings.warn(msg, RuntimeWarning)
438
498
439 def _build_targets(self, targets):
499 def _build_targets(self, targets):
440 """Turn valid target IDs or 'all' into two lists:
500 """Turn valid target IDs or 'all' into two lists:
441 (int_ids, uuids).
501 (int_ids, uuids).
442 """
502 """
443 if not self._ids:
503 if not self._ids:
444 # flush notification socket if no engines yet, just in case
504 # flush notification socket if no engines yet, just in case
445 if not self.ids:
505 if not self.ids:
446 raise error.NoEnginesRegistered("Can't build targets without any engines")
506 raise error.NoEnginesRegistered("Can't build targets without any engines")
447
507
448 if targets is None:
508 if targets is None:
449 targets = self._ids
509 targets = self._ids
450 elif isinstance(targets, basestring):
510 elif isinstance(targets, basestring):
451 if targets.lower() == 'all':
511 if targets.lower() == 'all':
452 targets = self._ids
512 targets = self._ids
453 else:
513 else:
454 raise TypeError("%r not valid str target, must be 'all'"%(targets))
514 raise TypeError("%r not valid str target, must be 'all'"%(targets))
455 elif isinstance(targets, int):
515 elif isinstance(targets, int):
456 if targets < 0:
516 if targets < 0:
457 targets = self.ids[targets]
517 targets = self.ids[targets]
458 if targets not in self._ids:
518 if targets not in self._ids:
459 raise IndexError("No such engine: %i"%targets)
519 raise IndexError("No such engine: %i"%targets)
460 targets = [targets]
520 targets = [targets]
461
521
462 if isinstance(targets, slice):
522 if isinstance(targets, slice):
463 indices = range(len(self._ids))[targets]
523 indices = range(len(self._ids))[targets]
464 ids = self.ids
524 ids = self.ids
465 targets = [ ids[i] for i in indices ]
525 targets = [ ids[i] for i in indices ]
466
526
467 if not isinstance(targets, (tuple, list, xrange)):
527 if not isinstance(targets, (tuple, list, xrange)):
468 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
528 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
469
529
470 return [util.asbytes(self._engines[t]) for t in targets], list(targets)
530 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
471
531
472 def _connect(self, sshserver, ssh_kwargs, timeout):
532 def _connect(self, sshserver, ssh_kwargs, timeout):
473 """setup all our socket connections to the cluster. This is called from
533 """setup all our socket connections to the cluster. This is called from
474 __init__."""
534 __init__."""
475
535
476 # Maybe allow reconnecting?
536 # Maybe allow reconnecting?
477 if self._connected:
537 if self._connected:
478 return
538 return
479 self._connected=True
539 self._connected=True
480
540
481 def connect_socket(s, url):
541 def connect_socket(s, url):
482 url = util.disambiguate_url(url, self._config['location'])
542 url = util.disambiguate_url(url, self._config['location'])
483 if self._ssh:
543 if self._ssh:
484 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
544 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
485 else:
545 else:
486 return s.connect(url)
546 return s.connect(url)
487
547
488 self.session.send(self._query_socket, 'connection_request')
548 self.session.send(self._query_socket, 'connection_request')
489 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
549 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
490 poller = zmq.Poller()
550 poller = zmq.Poller()
491 poller.register(self._query_socket, zmq.POLLIN)
551 poller.register(self._query_socket, zmq.POLLIN)
492 # poll expects milliseconds, timeout is seconds
552 # poll expects milliseconds, timeout is seconds
493 evts = poller.poll(timeout*1000)
553 evts = poller.poll(timeout*1000)
494 if not evts:
554 if not evts:
495 raise error.TimeoutError("Hub connection request timed out")
555 raise error.TimeoutError("Hub connection request timed out")
496 idents,msg = self.session.recv(self._query_socket,mode=0)
556 idents,msg = self.session.recv(self._query_socket,mode=0)
497 if self.debug:
557 if self.debug:
498 pprint(msg)
558 pprint(msg)
499 msg = Message(msg)
559 msg = Message(msg)
500 content = msg.content
560 content = msg.content
501 self._config['registration'] = dict(content)
561 self._config['registration'] = dict(content)
502 if content.status == 'ok':
562 if content.status == 'ok':
503 ident = self.session.bsession
563 ident = self.session.bsession
504 if content.mux:
564 if content.mux:
505 self._mux_socket = self._context.socket(zmq.DEALER)
565 self._mux_socket = self._context.socket(zmq.DEALER)
506 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
566 self._mux_socket.setsockopt(zmq.IDENTITY, ident)
507 connect_socket(self._mux_socket, content.mux)
567 connect_socket(self._mux_socket, content.mux)
508 if content.task:
568 if content.task:
509 self._task_scheme, task_addr = content.task
569 self._task_scheme, task_addr = content.task
510 self._task_socket = self._context.socket(zmq.DEALER)
570 self._task_socket = self._context.socket(zmq.DEALER)
511 self._task_socket.setsockopt(zmq.IDENTITY, ident)
571 self._task_socket.setsockopt(zmq.IDENTITY, ident)
512 connect_socket(self._task_socket, task_addr)
572 connect_socket(self._task_socket, task_addr)
513 if content.notification:
573 if content.notification:
514 self._notification_socket = self._context.socket(zmq.SUB)
574 self._notification_socket = self._context.socket(zmq.SUB)
515 connect_socket(self._notification_socket, content.notification)
575 connect_socket(self._notification_socket, content.notification)
516 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
576 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
517 # if content.query:
577 # if content.query:
518 # self._query_socket = self._context.socket(zmq.DEALER)
578 # self._query_socket = self._context.socket(zmq.DEALER)
519 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
579 # self._query_socket.setsockopt(zmq.IDENTITY, self.session.bsession)
520 # connect_socket(self._query_socket, content.query)
580 # connect_socket(self._query_socket, content.query)
521 if content.control:
581 if content.control:
522 self._control_socket = self._context.socket(zmq.DEALER)
582 self._control_socket = self._context.socket(zmq.DEALER)
523 self._control_socket.setsockopt(zmq.IDENTITY, ident)
583 self._control_socket.setsockopt(zmq.IDENTITY, ident)
524 connect_socket(self._control_socket, content.control)
584 connect_socket(self._control_socket, content.control)
525 if content.iopub:
585 if content.iopub:
526 self._iopub_socket = self._context.socket(zmq.SUB)
586 self._iopub_socket = self._context.socket(zmq.SUB)
527 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
587 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
528 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
588 self._iopub_socket.setsockopt(zmq.IDENTITY, ident)
529 connect_socket(self._iopub_socket, content.iopub)
589 connect_socket(self._iopub_socket, content.iopub)
530 self._update_engines(dict(content.engines))
590 self._update_engines(dict(content.engines))
531 else:
591 else:
532 self._connected = False
592 self._connected = False
533 raise Exception("Failed to connect!")
593 raise Exception("Failed to connect!")
534
594
535 #--------------------------------------------------------------------------
595 #--------------------------------------------------------------------------
536 # handlers and callbacks for incoming messages
596 # handlers and callbacks for incoming messages
537 #--------------------------------------------------------------------------
597 #--------------------------------------------------------------------------
538
598
539 def _unwrap_exception(self, content):
599 def _unwrap_exception(self, content):
540 """unwrap exception, and remap engine_id to int."""
600 """unwrap exception, and remap engine_id to int."""
541 e = error.unwrap_exception(content)
601 e = error.unwrap_exception(content)
542 # print e.traceback
602 # print e.traceback
543 if e.engine_info:
603 if e.engine_info:
544 e_uuid = e.engine_info['engine_uuid']
604 e_uuid = e.engine_info['engine_uuid']
545 eid = self._engines[e_uuid]
605 eid = self._engines[e_uuid]
546 e.engine_info['engine_id'] = eid
606 e.engine_info['engine_id'] = eid
547 return e
607 return e
548
608
549 def _extract_metadata(self, header, parent, content):
609 def _extract_metadata(self, header, parent, content):
550 md = {'msg_id' : parent['msg_id'],
610 md = {'msg_id' : parent['msg_id'],
551 'received' : datetime.now(),
611 'received' : datetime.now(),
552 'engine_uuid' : header.get('engine', None),
612 'engine_uuid' : header.get('engine', None),
553 'follow' : parent.get('follow', []),
613 'follow' : parent.get('follow', []),
554 'after' : parent.get('after', []),
614 'after' : parent.get('after', []),
555 'status' : content['status'],
615 'status' : content['status'],
556 }
616 }
557
617
558 if md['engine_uuid'] is not None:
618 if md['engine_uuid'] is not None:
559 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
619 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
560
620
561 if 'date' in parent:
621 if 'date' in parent:
562 md['submitted'] = parent['date']
622 md['submitted'] = parent['date']
563 if 'started' in header:
623 if 'started' in header:
564 md['started'] = header['started']
624 md['started'] = header['started']
565 if 'date' in header:
625 if 'date' in header:
566 md['completed'] = header['date']
626 md['completed'] = header['date']
567 return md
627 return md
568
628
569 def _register_engine(self, msg):
629 def _register_engine(self, msg):
570 """Register a new engine, and update our connection info."""
630 """Register a new engine, and update our connection info."""
571 content = msg['content']
631 content = msg['content']
572 eid = content['id']
632 eid = content['id']
573 d = {eid : content['queue']}
633 d = {eid : content['queue']}
574 self._update_engines(d)
634 self._update_engines(d)
575
635
576 def _unregister_engine(self, msg):
636 def _unregister_engine(self, msg):
577 """Unregister an engine that has died."""
637 """Unregister an engine that has died."""
578 content = msg['content']
638 content = msg['content']
579 eid = int(content['id'])
639 eid = int(content['id'])
580 if eid in self._ids:
640 if eid in self._ids:
581 self._ids.remove(eid)
641 self._ids.remove(eid)
582 uuid = self._engines.pop(eid)
642 uuid = self._engines.pop(eid)
583
643
584 self._handle_stranded_msgs(eid, uuid)
644 self._handle_stranded_msgs(eid, uuid)
585
645
586 if self._task_socket and self._task_scheme == 'pure':
646 if self._task_socket and self._task_scheme == 'pure':
587 self._stop_scheduling_tasks()
647 self._stop_scheduling_tasks()
588
648
589 def _handle_stranded_msgs(self, eid, uuid):
649 def _handle_stranded_msgs(self, eid, uuid):
590 """Handle messages known to be on an engine when the engine unregisters.
650 """Handle messages known to be on an engine when the engine unregisters.
591
651
592 It is possible that this will fire prematurely - that is, an engine will
652 It is possible that this will fire prematurely - that is, an engine will
593 go down after completing a result, and the client will be notified
653 go down after completing a result, and the client will be notified
594 of the unregistration and later receive the successful result.
654 of the unregistration and later receive the successful result.
595 """
655 """
596
656
597 outstanding = self._outstanding_dict[uuid]
657 outstanding = self._outstanding_dict[uuid]
598
658
599 for msg_id in list(outstanding):
659 for msg_id in list(outstanding):
600 if msg_id in self.results:
660 if msg_id in self.results:
601 # we already
661 # we already
602 continue
662 continue
603 try:
663 try:
604 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
664 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
605 except:
665 except:
606 content = error.wrap_exception()
666 content = error.wrap_exception()
607 # build a fake message:
667 # build a fake message:
608 parent = {}
668 parent = {}
609 header = {}
669 header = {}
610 parent['msg_id'] = msg_id
670 parent['msg_id'] = msg_id
611 header['engine'] = uuid
671 header['engine'] = uuid
612 header['date'] = datetime.now()
672 header['date'] = datetime.now()
613 msg = dict(parent_header=parent, header=header, content=content)
673 msg = dict(parent_header=parent, header=header, content=content)
614 self._handle_apply_reply(msg)
674 self._handle_apply_reply(msg)
615
675
616 def _handle_execute_reply(self, msg):
676 def _handle_execute_reply(self, msg):
617 """Save the reply to an execute_request into our results.
677 """Save the reply to an execute_request into our results.
618
678
619 execute messages are never actually used. apply is used instead.
679 execute messages are never actually used. apply is used instead.
620 """
680 """
621
681
622 parent = msg['parent_header']
682 parent = msg['parent_header']
623 msg_id = parent['msg_id']
683 msg_id = parent['msg_id']
624 if msg_id not in self.outstanding:
684 if msg_id not in self.outstanding:
625 if msg_id in self.history:
685 if msg_id in self.history:
626 print ("got stale result: %s"%msg_id)
686 print ("got stale result: %s"%msg_id)
627 else:
687 else:
628 print ("got unknown result: %s"%msg_id)
688 print ("got unknown result: %s"%msg_id)
629 else:
689 else:
630 self.outstanding.remove(msg_id)
690 self.outstanding.remove(msg_id)
631 self.results[msg_id] = self._unwrap_exception(msg['content'])
691
692 content = msg['content']
693 header = msg['header']
694
695 # construct metadata:
696 md = self.metadata[msg_id]
697 md.update(self._extract_metadata(header, parent, content))
698 # is this redundant?
699 self.metadata[msg_id] = md
700
701 e_outstanding = self._outstanding_dict[md['engine_uuid']]
702 if msg_id in e_outstanding:
703 e_outstanding.remove(msg_id)
704
705 # construct result:
706 if content['status'] == 'ok':
707 self.results[msg_id] = ExecuteReply(msg_id, content, md)
708 elif content['status'] == 'aborted':
709 self.results[msg_id] = error.TaskAborted(msg_id)
710 elif content['status'] == 'resubmitted':
711 # TODO: handle resubmission
712 pass
713 else:
714 self.results[msg_id] = self._unwrap_exception(content)
632
715
633 def _handle_apply_reply(self, msg):
716 def _handle_apply_reply(self, msg):
634 """Save the reply to an apply_request into our results."""
717 """Save the reply to an apply_request into our results."""
635 parent = msg['parent_header']
718 parent = msg['parent_header']
636 msg_id = parent['msg_id']
719 msg_id = parent['msg_id']
637 if msg_id not in self.outstanding:
720 if msg_id not in self.outstanding:
638 if msg_id in self.history:
721 if msg_id in self.history:
639 print ("got stale result: %s"%msg_id)
722 print ("got stale result: %s"%msg_id)
640 print self.results[msg_id]
723 print self.results[msg_id]
641 print msg
724 print msg
642 else:
725 else:
643 print ("got unknown result: %s"%msg_id)
726 print ("got unknown result: %s"%msg_id)
644 else:
727 else:
645 self.outstanding.remove(msg_id)
728 self.outstanding.remove(msg_id)
646 content = msg['content']
729 content = msg['content']
647 header = msg['header']
730 header = msg['header']
648
731
649 # construct metadata:
732 # construct metadata:
650 md = self.metadata[msg_id]
733 md = self.metadata[msg_id]
651 md.update(self._extract_metadata(header, parent, content))
734 md.update(self._extract_metadata(header, parent, content))
652 # is this redundant?
735 # is this redundant?
653 self.metadata[msg_id] = md
736 self.metadata[msg_id] = md
654
737
655 e_outstanding = self._outstanding_dict[md['engine_uuid']]
738 e_outstanding = self._outstanding_dict[md['engine_uuid']]
656 if msg_id in e_outstanding:
739 if msg_id in e_outstanding:
657 e_outstanding.remove(msg_id)
740 e_outstanding.remove(msg_id)
658
741
659 # construct result:
742 # construct result:
660 if content['status'] == 'ok':
743 if content['status'] == 'ok':
661 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
744 self.results[msg_id] = util.unserialize_object(msg['buffers'])[0]
662 elif content['status'] == 'aborted':
745 elif content['status'] == 'aborted':
663 self.results[msg_id] = error.TaskAborted(msg_id)
746 self.results[msg_id] = error.TaskAborted(msg_id)
664 elif content['status'] == 'resubmitted':
747 elif content['status'] == 'resubmitted':
665 # TODO: handle resubmission
748 # TODO: handle resubmission
666 pass
749 pass
667 else:
750 else:
668 self.results[msg_id] = self._unwrap_exception(content)
751 self.results[msg_id] = self._unwrap_exception(content)
669
752
670 def _flush_notifications(self):
753 def _flush_notifications(self):
671 """Flush notifications of engine registrations waiting
754 """Flush notifications of engine registrations waiting
672 in ZMQ queue."""
755 in ZMQ queue."""
673 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
756 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
674 while msg is not None:
757 while msg is not None:
675 if self.debug:
758 if self.debug:
676 pprint(msg)
759 pprint(msg)
677 msg_type = msg['header']['msg_type']
760 msg_type = msg['header']['msg_type']
678 handler = self._notification_handlers.get(msg_type, None)
761 handler = self._notification_handlers.get(msg_type, None)
679 if handler is None:
762 if handler is None:
680 raise Exception("Unhandled message type: %s"%msg.msg_type)
763 raise Exception("Unhandled message type: %s"%msg.msg_type)
681 else:
764 else:
682 handler(msg)
765 handler(msg)
683 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
766 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
684
767
685 def _flush_results(self, sock):
768 def _flush_results(self, sock):
686 """Flush task or queue results waiting in ZMQ queue."""
769 """Flush task or queue results waiting in ZMQ queue."""
687 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
770 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
688 while msg is not None:
771 while msg is not None:
689 if self.debug:
772 if self.debug:
690 pprint(msg)
773 pprint(msg)
691 msg_type = msg['header']['msg_type']
774 msg_type = msg['header']['msg_type']
692 handler = self._queue_handlers.get(msg_type, None)
775 handler = self._queue_handlers.get(msg_type, None)
693 if handler is None:
776 if handler is None:
694 raise Exception("Unhandled message type: %s"%msg.msg_type)
777 raise Exception("Unhandled message type: %s"%msg.msg_type)
695 else:
778 else:
696 handler(msg)
779 handler(msg)
697 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
780 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
698
781
699 def _flush_control(self, sock):
782 def _flush_control(self, sock):
700 """Flush replies from the control channel waiting
783 """Flush replies from the control channel waiting
701 in the ZMQ queue.
784 in the ZMQ queue.
702
785
703 Currently: ignore them."""
786 Currently: ignore them."""
704 if self._ignored_control_replies <= 0:
787 if self._ignored_control_replies <= 0:
705 return
788 return
706 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
789 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
707 while msg is not None:
790 while msg is not None:
708 self._ignored_control_replies -= 1
791 self._ignored_control_replies -= 1
709 if self.debug:
792 if self.debug:
710 pprint(msg)
793 pprint(msg)
711 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
794 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
712
795
713 def _flush_ignored_control(self):
796 def _flush_ignored_control(self):
714 """flush ignored control replies"""
797 """flush ignored control replies"""
715 while self._ignored_control_replies > 0:
798 while self._ignored_control_replies > 0:
716 self.session.recv(self._control_socket)
799 self.session.recv(self._control_socket)
717 self._ignored_control_replies -= 1
800 self._ignored_control_replies -= 1
718
801
719 def _flush_ignored_hub_replies(self):
802 def _flush_ignored_hub_replies(self):
720 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
803 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
721 while msg is not None:
804 while msg is not None:
722 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
805 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
723
806
724 def _flush_iopub(self, sock):
807 def _flush_iopub(self, sock):
725 """Flush replies from the iopub channel waiting
808 """Flush replies from the iopub channel waiting
726 in the ZMQ queue.
809 in the ZMQ queue.
727 """
810 """
728 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
811 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
729 while msg is not None:
812 while msg is not None:
730 if self.debug:
813 if self.debug:
731 pprint(msg)
814 pprint(msg)
732 parent = msg['parent_header']
815 parent = msg['parent_header']
733 # ignore IOPub messages with no parent.
816 # ignore IOPub messages with no parent.
734 # Caused by print statements or warnings from before the first execution.
817 # Caused by print statements or warnings from before the first execution.
735 if not parent:
818 if not parent:
736 continue
819 continue
737 msg_id = parent['msg_id']
820 msg_id = parent['msg_id']
738 content = msg['content']
821 content = msg['content']
739 header = msg['header']
822 header = msg['header']
740 msg_type = msg['header']['msg_type']
823 msg_type = msg['header']['msg_type']
741
824
742 # init metadata:
825 # init metadata:
743 md = self.metadata[msg_id]
826 md = self.metadata[msg_id]
744
827
745 if msg_type == 'stream':
828 if msg_type == 'stream':
746 name = content['name']
829 name = content['name']
747 s = md[name] or ''
830 s = md[name] or ''
748 md[name] = s + content['data']
831 md[name] = s + content['data']
749 elif msg_type == 'pyerr':
832 elif msg_type == 'pyerr':
750 md.update({'pyerr' : self._unwrap_exception(content)})
833 md.update({'pyerr' : self._unwrap_exception(content)})
751 elif msg_type == 'pyin':
834 elif msg_type == 'pyin':
752 md.update({'pyin' : content['code']})
835 md.update({'pyin' : content['code']})
836 elif msg_type == 'display_data':
837 md['outputs'].append(content)
838 elif msg_type == 'pyout':
839 md['pyout'] = content
753 else:
840 else:
754 md.update({msg_type : content.get('data', '')})
841 # unhandled msg_type (status, etc.)
842 pass
755
843
756 # reduntant?
844 # reduntant?
757 self.metadata[msg_id] = md
845 self.metadata[msg_id] = md
758
846
759 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
847 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
760
848
761 #--------------------------------------------------------------------------
849 #--------------------------------------------------------------------------
762 # len, getitem
850 # len, getitem
763 #--------------------------------------------------------------------------
851 #--------------------------------------------------------------------------
764
852
765 def __len__(self):
853 def __len__(self):
766 """len(client) returns # of engines."""
854 """len(client) returns # of engines."""
767 return len(self.ids)
855 return len(self.ids)
768
856
769 def __getitem__(self, key):
857 def __getitem__(self, key):
770 """index access returns DirectView multiplexer objects
858 """index access returns DirectView multiplexer objects
771
859
772 Must be int, slice, or list/tuple/xrange of ints"""
860 Must be int, slice, or list/tuple/xrange of ints"""
773 if not isinstance(key, (int, slice, tuple, list, xrange)):
861 if not isinstance(key, (int, slice, tuple, list, xrange)):
774 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
862 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
775 else:
863 else:
776 return self.direct_view(key)
864 return self.direct_view(key)
777
865
778 #--------------------------------------------------------------------------
866 #--------------------------------------------------------------------------
779 # Begin public methods
867 # Begin public methods
780 #--------------------------------------------------------------------------
868 #--------------------------------------------------------------------------
781
869
782 @property
870 @property
783 def ids(self):
871 def ids(self):
784 """Always up-to-date ids property."""
872 """Always up-to-date ids property."""
785 self._flush_notifications()
873 self._flush_notifications()
786 # always copy:
874 # always copy:
787 return list(self._ids)
875 return list(self._ids)
788
876
789 def close(self):
877 def close(self):
790 if self._closed:
878 if self._closed:
791 return
879 return
792 self.stop_spin_thread()
880 self.stop_spin_thread()
793 snames = filter(lambda n: n.endswith('socket'), dir(self))
881 snames = filter(lambda n: n.endswith('socket'), dir(self))
794 for socket in map(lambda name: getattr(self, name), snames):
882 for socket in map(lambda name: getattr(self, name), snames):
795 if isinstance(socket, zmq.Socket) and not socket.closed:
883 if isinstance(socket, zmq.Socket) and not socket.closed:
796 socket.close()
884 socket.close()
797 self._closed = True
885 self._closed = True
798
886
799 def _spin_every(self, interval=1):
887 def _spin_every(self, interval=1):
800 """target func for use in spin_thread"""
888 """target func for use in spin_thread"""
801 while True:
889 while True:
802 if self._stop_spinning.is_set():
890 if self._stop_spinning.is_set():
803 return
891 return
804 time.sleep(interval)
892 time.sleep(interval)
805 self.spin()
893 self.spin()
806
894
807 def spin_thread(self, interval=1):
895 def spin_thread(self, interval=1):
808 """call Client.spin() in a background thread on some regular interval
896 """call Client.spin() in a background thread on some regular interval
809
897
810 This helps ensure that messages don't pile up too much in the zmq queue
898 This helps ensure that messages don't pile up too much in the zmq queue
811 while you are working on other things, or just leaving an idle terminal.
899 while you are working on other things, or just leaving an idle terminal.
812
900
813 It also helps limit potential padding of the `received` timestamp
901 It also helps limit potential padding of the `received` timestamp
814 on AsyncResult objects, used for timings.
902 on AsyncResult objects, used for timings.
815
903
816 Parameters
904 Parameters
817 ----------
905 ----------
818
906
819 interval : float, optional
907 interval : float, optional
820 The interval on which to spin the client in the background thread
908 The interval on which to spin the client in the background thread
821 (simply passed to time.sleep).
909 (simply passed to time.sleep).
822
910
823 Notes
911 Notes
824 -----
912 -----
825
913
826 For precision timing, you may want to use this method to put a bound
914 For precision timing, you may want to use this method to put a bound
827 on the jitter (in seconds) in `received` timestamps used
915 on the jitter (in seconds) in `received` timestamps used
828 in AsyncResult.wall_time.
916 in AsyncResult.wall_time.
829
917
830 """
918 """
831 if self._spin_thread is not None:
919 if self._spin_thread is not None:
832 self.stop_spin_thread()
920 self.stop_spin_thread()
833 self._stop_spinning.clear()
921 self._stop_spinning.clear()
834 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
922 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
835 self._spin_thread.daemon = True
923 self._spin_thread.daemon = True
836 self._spin_thread.start()
924 self._spin_thread.start()
837
925
838 def stop_spin_thread(self):
926 def stop_spin_thread(self):
839 """stop background spin_thread, if any"""
927 """stop background spin_thread, if any"""
840 if self._spin_thread is not None:
928 if self._spin_thread is not None:
841 self._stop_spinning.set()
929 self._stop_spinning.set()
842 self._spin_thread.join()
930 self._spin_thread.join()
843 self._spin_thread = None
931 self._spin_thread = None
844
932
845 def spin(self):
933 def spin(self):
846 """Flush any registration notifications and execution results
934 """Flush any registration notifications and execution results
847 waiting in the ZMQ queue.
935 waiting in the ZMQ queue.
848 """
936 """
849 if self._notification_socket:
937 if self._notification_socket:
850 self._flush_notifications()
938 self._flush_notifications()
939 if self._iopub_socket:
940 self._flush_iopub(self._iopub_socket)
851 if self._mux_socket:
941 if self._mux_socket:
852 self._flush_results(self._mux_socket)
942 self._flush_results(self._mux_socket)
853 if self._task_socket:
943 if self._task_socket:
854 self._flush_results(self._task_socket)
944 self._flush_results(self._task_socket)
855 if self._control_socket:
945 if self._control_socket:
856 self._flush_control(self._control_socket)
946 self._flush_control(self._control_socket)
857 if self._iopub_socket:
858 self._flush_iopub(self._iopub_socket)
859 if self._query_socket:
947 if self._query_socket:
860 self._flush_ignored_hub_replies()
948 self._flush_ignored_hub_replies()
861
949
862 def wait(self, jobs=None, timeout=-1):
950 def wait(self, jobs=None, timeout=-1):
863 """waits on one or more `jobs`, for up to `timeout` seconds.
951 """waits on one or more `jobs`, for up to `timeout` seconds.
864
952
865 Parameters
953 Parameters
866 ----------
954 ----------
867
955
868 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
956 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
869 ints are indices to self.history
957 ints are indices to self.history
870 strs are msg_ids
958 strs are msg_ids
871 default: wait on all outstanding messages
959 default: wait on all outstanding messages
872 timeout : float
960 timeout : float
873 a time in seconds, after which to give up.
961 a time in seconds, after which to give up.
874 default is -1, which means no timeout
962 default is -1, which means no timeout
875
963
876 Returns
964 Returns
877 -------
965 -------
878
966
879 True : when all msg_ids are done
967 True : when all msg_ids are done
880 False : timeout reached, some msg_ids still outstanding
968 False : timeout reached, some msg_ids still outstanding
881 """
969 """
882 tic = time.time()
970 tic = time.time()
883 if jobs is None:
971 if jobs is None:
884 theids = self.outstanding
972 theids = self.outstanding
885 else:
973 else:
886 if isinstance(jobs, (int, basestring, AsyncResult)):
974 if isinstance(jobs, (int, basestring, AsyncResult)):
887 jobs = [jobs]
975 jobs = [jobs]
888 theids = set()
976 theids = set()
889 for job in jobs:
977 for job in jobs:
890 if isinstance(job, int):
978 if isinstance(job, int):
891 # index access
979 # index access
892 job = self.history[job]
980 job = self.history[job]
893 elif isinstance(job, AsyncResult):
981 elif isinstance(job, AsyncResult):
894 map(theids.add, job.msg_ids)
982 map(theids.add, job.msg_ids)
895 continue
983 continue
896 theids.add(job)
984 theids.add(job)
897 if not theids.intersection(self.outstanding):
985 if not theids.intersection(self.outstanding):
898 return True
986 return True
899 self.spin()
987 self.spin()
900 while theids.intersection(self.outstanding):
988 while theids.intersection(self.outstanding):
901 if timeout >= 0 and ( time.time()-tic ) > timeout:
989 if timeout >= 0 and ( time.time()-tic ) > timeout:
902 break
990 break
903 time.sleep(1e-3)
991 time.sleep(1e-3)
904 self.spin()
992 self.spin()
905 return len(theids.intersection(self.outstanding)) == 0
993 return len(theids.intersection(self.outstanding)) == 0
906
994
907 #--------------------------------------------------------------------------
995 #--------------------------------------------------------------------------
908 # Control methods
996 # Control methods
909 #--------------------------------------------------------------------------
997 #--------------------------------------------------------------------------
910
998
911 @spin_first
999 @spin_first
912 def clear(self, targets=None, block=None):
1000 def clear(self, targets=None, block=None):
913 """Clear the namespace in target(s)."""
1001 """Clear the namespace in target(s)."""
914 block = self.block if block is None else block
1002 block = self.block if block is None else block
915 targets = self._build_targets(targets)[0]
1003 targets = self._build_targets(targets)[0]
916 for t in targets:
1004 for t in targets:
917 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1005 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
918 error = False
1006 error = False
919 if block:
1007 if block:
920 self._flush_ignored_control()
1008 self._flush_ignored_control()
921 for i in range(len(targets)):
1009 for i in range(len(targets)):
922 idents,msg = self.session.recv(self._control_socket,0)
1010 idents,msg = self.session.recv(self._control_socket,0)
923 if self.debug:
1011 if self.debug:
924 pprint(msg)
1012 pprint(msg)
925 if msg['content']['status'] != 'ok':
1013 if msg['content']['status'] != 'ok':
926 error = self._unwrap_exception(msg['content'])
1014 error = self._unwrap_exception(msg['content'])
927 else:
1015 else:
928 self._ignored_control_replies += len(targets)
1016 self._ignored_control_replies += len(targets)
929 if error:
1017 if error:
930 raise error
1018 raise error
931
1019
932
1020
933 @spin_first
1021 @spin_first
934 def abort(self, jobs=None, targets=None, block=None):
1022 def abort(self, jobs=None, targets=None, block=None):
935 """Abort specific jobs from the execution queues of target(s).
1023 """Abort specific jobs from the execution queues of target(s).
936
1024
937 This is a mechanism to prevent jobs that have already been submitted
1025 This is a mechanism to prevent jobs that have already been submitted
938 from executing.
1026 from executing.
939
1027
940 Parameters
1028 Parameters
941 ----------
1029 ----------
942
1030
943 jobs : msg_id, list of msg_ids, or AsyncResult
1031 jobs : msg_id, list of msg_ids, or AsyncResult
944 The jobs to be aborted
1032 The jobs to be aborted
945
1033
946 If unspecified/None: abort all outstanding jobs.
1034 If unspecified/None: abort all outstanding jobs.
947
1035
948 """
1036 """
949 block = self.block if block is None else block
1037 block = self.block if block is None else block
950 jobs = jobs if jobs is not None else list(self.outstanding)
1038 jobs = jobs if jobs is not None else list(self.outstanding)
951 targets = self._build_targets(targets)[0]
1039 targets = self._build_targets(targets)[0]
952
1040
953 msg_ids = []
1041 msg_ids = []
954 if isinstance(jobs, (basestring,AsyncResult)):
1042 if isinstance(jobs, (basestring,AsyncResult)):
955 jobs = [jobs]
1043 jobs = [jobs]
956 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1044 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
957 if bad_ids:
1045 if bad_ids:
958 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1046 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
959 for j in jobs:
1047 for j in jobs:
960 if isinstance(j, AsyncResult):
1048 if isinstance(j, AsyncResult):
961 msg_ids.extend(j.msg_ids)
1049 msg_ids.extend(j.msg_ids)
962 else:
1050 else:
963 msg_ids.append(j)
1051 msg_ids.append(j)
964 content = dict(msg_ids=msg_ids)
1052 content = dict(msg_ids=msg_ids)
965 for t in targets:
1053 for t in targets:
966 self.session.send(self._control_socket, 'abort_request',
1054 self.session.send(self._control_socket, 'abort_request',
967 content=content, ident=t)
1055 content=content, ident=t)
968 error = False
1056 error = False
969 if block:
1057 if block:
970 self._flush_ignored_control()
1058 self._flush_ignored_control()
971 for i in range(len(targets)):
1059 for i in range(len(targets)):
972 idents,msg = self.session.recv(self._control_socket,0)
1060 idents,msg = self.session.recv(self._control_socket,0)
973 if self.debug:
1061 if self.debug:
974 pprint(msg)
1062 pprint(msg)
975 if msg['content']['status'] != 'ok':
1063 if msg['content']['status'] != 'ok':
976 error = self._unwrap_exception(msg['content'])
1064 error = self._unwrap_exception(msg['content'])
977 else:
1065 else:
978 self._ignored_control_replies += len(targets)
1066 self._ignored_control_replies += len(targets)
979 if error:
1067 if error:
980 raise error
1068 raise error
981
1069
982 @spin_first
1070 @spin_first
983 def shutdown(self, targets=None, restart=False, hub=False, block=None):
1071 def shutdown(self, targets=None, restart=False, hub=False, block=None):
984 """Terminates one or more engine processes, optionally including the hub."""
1072 """Terminates one or more engine processes, optionally including the hub."""
985 block = self.block if block is None else block
1073 block = self.block if block is None else block
986 if hub:
1074 if hub:
987 targets = 'all'
1075 targets = 'all'
988 targets = self._build_targets(targets)[0]
1076 targets = self._build_targets(targets)[0]
989 for t in targets:
1077 for t in targets:
990 self.session.send(self._control_socket, 'shutdown_request',
1078 self.session.send(self._control_socket, 'shutdown_request',
991 content={'restart':restart},ident=t)
1079 content={'restart':restart},ident=t)
992 error = False
1080 error = False
993 if block or hub:
1081 if block or hub:
994 self._flush_ignored_control()
1082 self._flush_ignored_control()
995 for i in range(len(targets)):
1083 for i in range(len(targets)):
996 idents,msg = self.session.recv(self._control_socket, 0)
1084 idents,msg = self.session.recv(self._control_socket, 0)
997 if self.debug:
1085 if self.debug:
998 pprint(msg)
1086 pprint(msg)
999 if msg['content']['status'] != 'ok':
1087 if msg['content']['status'] != 'ok':
1000 error = self._unwrap_exception(msg['content'])
1088 error = self._unwrap_exception(msg['content'])
1001 else:
1089 else:
1002 self._ignored_control_replies += len(targets)
1090 self._ignored_control_replies += len(targets)
1003
1091
1004 if hub:
1092 if hub:
1005 time.sleep(0.25)
1093 time.sleep(0.25)
1006 self.session.send(self._query_socket, 'shutdown_request')
1094 self.session.send(self._query_socket, 'shutdown_request')
1007 idents,msg = self.session.recv(self._query_socket, 0)
1095 idents,msg = self.session.recv(self._query_socket, 0)
1008 if self.debug:
1096 if self.debug:
1009 pprint(msg)
1097 pprint(msg)
1010 if msg['content']['status'] != 'ok':
1098 if msg['content']['status'] != 'ok':
1011 error = self._unwrap_exception(msg['content'])
1099 error = self._unwrap_exception(msg['content'])
1012
1100
1013 if error:
1101 if error:
1014 raise error
1102 raise error
1015
1103
1016 #--------------------------------------------------------------------------
1104 #--------------------------------------------------------------------------
1017 # Execution related methods
1105 # Execution related methods
1018 #--------------------------------------------------------------------------
1106 #--------------------------------------------------------------------------
1019
1107
1020 def _maybe_raise(self, result):
1108 def _maybe_raise(self, result):
1021 """wrapper for maybe raising an exception if apply failed."""
1109 """wrapper for maybe raising an exception if apply failed."""
1022 if isinstance(result, error.RemoteError):
1110 if isinstance(result, error.RemoteError):
1023 raise result
1111 raise result
1024
1112
1025 return result
1113 return result
1026
1114
1027 def send_apply_message(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1115 def send_apply_request(self, socket, f, args=None, kwargs=None, subheader=None, track=False,
1028 ident=None):
1116 ident=None):
1029 """construct and send an apply message via a socket.
1117 """construct and send an apply message via a socket.
1030
1118
1031 This is the principal method with which all engine execution is performed by views.
1119 This is the principal method with which all engine execution is performed by views.
1032 """
1120 """
1033
1121
1034 assert not self._closed, "cannot use me anymore, I'm closed!"
1122 if self._closed:
1123 raise RuntimeError("Client cannot be used after its sockets have been closed")
1124
1035 # defaults:
1125 # defaults:
1036 args = args if args is not None else []
1126 args = args if args is not None else []
1037 kwargs = kwargs if kwargs is not None else {}
1127 kwargs = kwargs if kwargs is not None else {}
1038 subheader = subheader if subheader is not None else {}
1128 subheader = subheader if subheader is not None else {}
1039
1129
1040 # validate arguments
1130 # validate arguments
1041 if not callable(f) and not isinstance(f, Reference):
1131 if not callable(f) and not isinstance(f, Reference):
1042 raise TypeError("f must be callable, not %s"%type(f))
1132 raise TypeError("f must be callable, not %s"%type(f))
1043 if not isinstance(args, (tuple, list)):
1133 if not isinstance(args, (tuple, list)):
1044 raise TypeError("args must be tuple or list, not %s"%type(args))
1134 raise TypeError("args must be tuple or list, not %s"%type(args))
1045 if not isinstance(kwargs, dict):
1135 if not isinstance(kwargs, dict):
1046 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1136 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1047 if not isinstance(subheader, dict):
1137 if not isinstance(subheader, dict):
1048 raise TypeError("subheader must be dict, not %s"%type(subheader))
1138 raise TypeError("subheader must be dict, not %s"%type(subheader))
1049
1139
1050 bufs = util.pack_apply_message(f,args,kwargs)
1140 bufs = util.pack_apply_message(f,args,kwargs)
1051
1141
1052 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1142 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1053 subheader=subheader, track=track)
1143 subheader=subheader, track=track)
1054
1144
1055 msg_id = msg['header']['msg_id']
1145 msg_id = msg['header']['msg_id']
1056 self.outstanding.add(msg_id)
1146 self.outstanding.add(msg_id)
1057 if ident:
1147 if ident:
1058 # possibly routed to a specific engine
1148 # possibly routed to a specific engine
1059 if isinstance(ident, list):
1149 if isinstance(ident, list):
1060 ident = ident[-1]
1150 ident = ident[-1]
1061 if ident in self._engines.values():
1151 if ident in self._engines.values():
1062 # save for later, in case of engine death
1152 # save for later, in case of engine death
1063 self._outstanding_dict[ident].add(msg_id)
1153 self._outstanding_dict[ident].add(msg_id)
1064 self.history.append(msg_id)
1154 self.history.append(msg_id)
1065 self.metadata[msg_id]['submitted'] = datetime.now()
1155 self.metadata[msg_id]['submitted'] = datetime.now()
1066
1156
1067 return msg
1157 return msg
1068
1158
1159 def send_execute_request(self, socket, code, silent=True, subheader=None, ident=None):
1160 """construct and send an execute request via a socket.
1161
1162 """
1163
1164 if self._closed:
1165 raise RuntimeError("Client cannot be used after its sockets have been closed")
1166
1167 # defaults:
1168 subheader = subheader if subheader is not None else {}
1169
1170 # validate arguments
1171 if not isinstance(code, basestring):
1172 raise TypeError("code must be text, not %s" % type(code))
1173 if not isinstance(subheader, dict):
1174 raise TypeError("subheader must be dict, not %s" % type(subheader))
1175
1176 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1177
1178
1179 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1180 subheader=subheader)
1181
1182 msg_id = msg['header']['msg_id']
1183 self.outstanding.add(msg_id)
1184 if ident:
1185 # possibly routed to a specific engine
1186 if isinstance(ident, list):
1187 ident = ident[-1]
1188 if ident in self._engines.values():
1189 # save for later, in case of engine death
1190 self._outstanding_dict[ident].add(msg_id)
1191 self.history.append(msg_id)
1192 self.metadata[msg_id]['submitted'] = datetime.now()
1193
1194 return msg
1195
1069 #--------------------------------------------------------------------------
1196 #--------------------------------------------------------------------------
1070 # construct a View object
1197 # construct a View object
1071 #--------------------------------------------------------------------------
1198 #--------------------------------------------------------------------------
1072
1199
1073 def load_balanced_view(self, targets=None):
1200 def load_balanced_view(self, targets=None):
1074 """construct a DirectView object.
1201 """construct a DirectView object.
1075
1202
1076 If no arguments are specified, create a LoadBalancedView
1203 If no arguments are specified, create a LoadBalancedView
1077 using all engines.
1204 using all engines.
1078
1205
1079 Parameters
1206 Parameters
1080 ----------
1207 ----------
1081
1208
1082 targets: list,slice,int,etc. [default: use all engines]
1209 targets: list,slice,int,etc. [default: use all engines]
1083 The subset of engines across which to load-balance
1210 The subset of engines across which to load-balance
1084 """
1211 """
1085 if targets == 'all':
1212 if targets == 'all':
1086 targets = None
1213 targets = None
1087 if targets is not None:
1214 if targets is not None:
1088 targets = self._build_targets(targets)[1]
1215 targets = self._build_targets(targets)[1]
1089 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1216 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1090
1217
1091 def direct_view(self, targets='all'):
1218 def direct_view(self, targets='all'):
1092 """construct a DirectView object.
1219 """construct a DirectView object.
1093
1220
1094 If no targets are specified, create a DirectView using all engines.
1221 If no targets are specified, create a DirectView using all engines.
1095
1222
1096 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1223 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1097 evaluate the target engines at each execution, whereas rc[:] will connect to
1224 evaluate the target engines at each execution, whereas rc[:] will connect to
1098 all *current* engines, and that list will not change.
1225 all *current* engines, and that list will not change.
1099
1226
1100 That is, 'all' will always use all engines, whereas rc[:] will not use
1227 That is, 'all' will always use all engines, whereas rc[:] will not use
1101 engines added after the DirectView is constructed.
1228 engines added after the DirectView is constructed.
1102
1229
1103 Parameters
1230 Parameters
1104 ----------
1231 ----------
1105
1232
1106 targets: list,slice,int,etc. [default: use all engines]
1233 targets: list,slice,int,etc. [default: use all engines]
1107 The engines to use for the View
1234 The engines to use for the View
1108 """
1235 """
1109 single = isinstance(targets, int)
1236 single = isinstance(targets, int)
1110 # allow 'all' to be lazily evaluated at each execution
1237 # allow 'all' to be lazily evaluated at each execution
1111 if targets != 'all':
1238 if targets != 'all':
1112 targets = self._build_targets(targets)[1]
1239 targets = self._build_targets(targets)[1]
1113 if single:
1240 if single:
1114 targets = targets[0]
1241 targets = targets[0]
1115 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1242 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1116
1243
1117 #--------------------------------------------------------------------------
1244 #--------------------------------------------------------------------------
1118 # Query methods
1245 # Query methods
1119 #--------------------------------------------------------------------------
1246 #--------------------------------------------------------------------------
1120
1247
1121 @spin_first
1248 @spin_first
1122 def get_result(self, indices_or_msg_ids=None, block=None):
1249 def get_result(self, indices_or_msg_ids=None, block=None):
1123 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1250 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1124
1251
1125 If the client already has the results, no request to the Hub will be made.
1252 If the client already has the results, no request to the Hub will be made.
1126
1253
1127 This is a convenient way to construct AsyncResult objects, which are wrappers
1254 This is a convenient way to construct AsyncResult objects, which are wrappers
1128 that include metadata about execution, and allow for awaiting results that
1255 that include metadata about execution, and allow for awaiting results that
1129 were not submitted by this Client.
1256 were not submitted by this Client.
1130
1257
1131 It can also be a convenient way to retrieve the metadata associated with
1258 It can also be a convenient way to retrieve the metadata associated with
1132 blocking execution, since it always retrieves
1259 blocking execution, since it always retrieves
1133
1260
1134 Examples
1261 Examples
1135 --------
1262 --------
1136 ::
1263 ::
1137
1264
1138 In [10]: r = client.apply()
1265 In [10]: r = client.apply()
1139
1266
1140 Parameters
1267 Parameters
1141 ----------
1268 ----------
1142
1269
1143 indices_or_msg_ids : integer history index, str msg_id, or list of either
1270 indices_or_msg_ids : integer history index, str msg_id, or list of either
1144 The indices or msg_ids of indices to be retrieved
1271 The indices or msg_ids of indices to be retrieved
1145
1272
1146 block : bool
1273 block : bool
1147 Whether to wait for the result to be done
1274 Whether to wait for the result to be done
1148
1275
1149 Returns
1276 Returns
1150 -------
1277 -------
1151
1278
1152 AsyncResult
1279 AsyncResult
1153 A single AsyncResult object will always be returned.
1280 A single AsyncResult object will always be returned.
1154
1281
1155 AsyncHubResult
1282 AsyncHubResult
1156 A subclass of AsyncResult that retrieves results from the Hub
1283 A subclass of AsyncResult that retrieves results from the Hub
1157
1284
1158 """
1285 """
1159 block = self.block if block is None else block
1286 block = self.block if block is None else block
1160 if indices_or_msg_ids is None:
1287 if indices_or_msg_ids is None:
1161 indices_or_msg_ids = -1
1288 indices_or_msg_ids = -1
1162
1289
1163 if not isinstance(indices_or_msg_ids, (list,tuple)):
1290 if not isinstance(indices_or_msg_ids, (list,tuple)):
1164 indices_or_msg_ids = [indices_or_msg_ids]
1291 indices_or_msg_ids = [indices_or_msg_ids]
1165
1292
1166 theids = []
1293 theids = []
1167 for id in indices_or_msg_ids:
1294 for id in indices_or_msg_ids:
1168 if isinstance(id, int):
1295 if isinstance(id, int):
1169 id = self.history[id]
1296 id = self.history[id]
1170 if not isinstance(id, basestring):
1297 if not isinstance(id, basestring):
1171 raise TypeError("indices must be str or int, not %r"%id)
1298 raise TypeError("indices must be str or int, not %r"%id)
1172 theids.append(id)
1299 theids.append(id)
1173
1300
1174 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1301 local_ids = filter(lambda msg_id: msg_id in self.history or msg_id in self.results, theids)
1175 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1302 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1176
1303
1177 if remote_ids:
1304 if remote_ids:
1178 ar = AsyncHubResult(self, msg_ids=theids)
1305 ar = AsyncHubResult(self, msg_ids=theids)
1179 else:
1306 else:
1180 ar = AsyncResult(self, msg_ids=theids)
1307 ar = AsyncResult(self, msg_ids=theids)
1181
1308
1182 if block:
1309 if block:
1183 ar.wait()
1310 ar.wait()
1184
1311
1185 return ar
1312 return ar
1186
1313
1187 @spin_first
1314 @spin_first
1188 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1315 def resubmit(self, indices_or_msg_ids=None, subheader=None, block=None):
1189 """Resubmit one or more tasks.
1316 """Resubmit one or more tasks.
1190
1317
1191 in-flight tasks may not be resubmitted.
1318 in-flight tasks may not be resubmitted.
1192
1319
1193 Parameters
1320 Parameters
1194 ----------
1321 ----------
1195
1322
1196 indices_or_msg_ids : integer history index, str msg_id, or list of either
1323 indices_or_msg_ids : integer history index, str msg_id, or list of either
1197 The indices or msg_ids of indices to be retrieved
1324 The indices or msg_ids of indices to be retrieved
1198
1325
1199 block : bool
1326 block : bool
1200 Whether to wait for the result to be done
1327 Whether to wait for the result to be done
1201
1328
1202 Returns
1329 Returns
1203 -------
1330 -------
1204
1331
1205 AsyncHubResult
1332 AsyncHubResult
1206 A subclass of AsyncResult that retrieves results from the Hub
1333 A subclass of AsyncResult that retrieves results from the Hub
1207
1334
1208 """
1335 """
1209 block = self.block if block is None else block
1336 block = self.block if block is None else block
1210 if indices_or_msg_ids is None:
1337 if indices_or_msg_ids is None:
1211 indices_or_msg_ids = -1
1338 indices_or_msg_ids = -1
1212
1339
1213 if not isinstance(indices_or_msg_ids, (list,tuple)):
1340 if not isinstance(indices_or_msg_ids, (list,tuple)):
1214 indices_or_msg_ids = [indices_or_msg_ids]
1341 indices_or_msg_ids = [indices_or_msg_ids]
1215
1342
1216 theids = []
1343 theids = []
1217 for id in indices_or_msg_ids:
1344 for id in indices_or_msg_ids:
1218 if isinstance(id, int):
1345 if isinstance(id, int):
1219 id = self.history[id]
1346 id = self.history[id]
1220 if not isinstance(id, basestring):
1347 if not isinstance(id, basestring):
1221 raise TypeError("indices must be str or int, not %r"%id)
1348 raise TypeError("indices must be str or int, not %r"%id)
1222 theids.append(id)
1349 theids.append(id)
1223
1350
1224 for msg_id in theids:
1225 self.outstanding.discard(msg_id)
1226 if msg_id in self.history:
1227 self.history.remove(msg_id)
1228 self.results.pop(msg_id, None)
1229 self.metadata.pop(msg_id, None)
1230 content = dict(msg_ids = theids)
1351 content = dict(msg_ids = theids)
1231
1352
1232 self.session.send(self._query_socket, 'resubmit_request', content)
1353 self.session.send(self._query_socket, 'resubmit_request', content)
1233
1354
1234 zmq.select([self._query_socket], [], [])
1355 zmq.select([self._query_socket], [], [])
1235 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1356 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1236 if self.debug:
1357 if self.debug:
1237 pprint(msg)
1358 pprint(msg)
1238 content = msg['content']
1359 content = msg['content']
1239 if content['status'] != 'ok':
1360 if content['status'] != 'ok':
1240 raise self._unwrap_exception(content)
1361 raise self._unwrap_exception(content)
1362 mapping = content['resubmitted']
1363 new_ids = [ mapping[msg_id] for msg_id in theids ]
1241
1364
1242 ar = AsyncHubResult(self, msg_ids=theids)
1365 ar = AsyncHubResult(self, msg_ids=new_ids)
1243
1366
1244 if block:
1367 if block:
1245 ar.wait()
1368 ar.wait()
1246
1369
1247 return ar
1370 return ar
1248
1371
1249 @spin_first
1372 @spin_first
1250 def result_status(self, msg_ids, status_only=True):
1373 def result_status(self, msg_ids, status_only=True):
1251 """Check on the status of the result(s) of the apply request with `msg_ids`.
1374 """Check on the status of the result(s) of the apply request with `msg_ids`.
1252
1375
1253 If status_only is False, then the actual results will be retrieved, else
1376 If status_only is False, then the actual results will be retrieved, else
1254 only the status of the results will be checked.
1377 only the status of the results will be checked.
1255
1378
1256 Parameters
1379 Parameters
1257 ----------
1380 ----------
1258
1381
1259 msg_ids : list of msg_ids
1382 msg_ids : list of msg_ids
1260 if int:
1383 if int:
1261 Passed as index to self.history for convenience.
1384 Passed as index to self.history for convenience.
1262 status_only : bool (default: True)
1385 status_only : bool (default: True)
1263 if False:
1386 if False:
1264 Retrieve the actual results of completed tasks.
1387 Retrieve the actual results of completed tasks.
1265
1388
1266 Returns
1389 Returns
1267 -------
1390 -------
1268
1391
1269 results : dict
1392 results : dict
1270 There will always be the keys 'pending' and 'completed', which will
1393 There will always be the keys 'pending' and 'completed', which will
1271 be lists of msg_ids that are incomplete or complete. If `status_only`
1394 be lists of msg_ids that are incomplete or complete. If `status_only`
1272 is False, then completed results will be keyed by their `msg_id`.
1395 is False, then completed results will be keyed by their `msg_id`.
1273 """
1396 """
1274 if not isinstance(msg_ids, (list,tuple)):
1397 if not isinstance(msg_ids, (list,tuple)):
1275 msg_ids = [msg_ids]
1398 msg_ids = [msg_ids]
1276
1399
1277 theids = []
1400 theids = []
1278 for msg_id in msg_ids:
1401 for msg_id in msg_ids:
1279 if isinstance(msg_id, int):
1402 if isinstance(msg_id, int):
1280 msg_id = self.history[msg_id]
1403 msg_id = self.history[msg_id]
1281 if not isinstance(msg_id, basestring):
1404 if not isinstance(msg_id, basestring):
1282 raise TypeError("msg_ids must be str, not %r"%msg_id)
1405 raise TypeError("msg_ids must be str, not %r"%msg_id)
1283 theids.append(msg_id)
1406 theids.append(msg_id)
1284
1407
1285 completed = []
1408 completed = []
1286 local_results = {}
1409 local_results = {}
1287
1410
1288 # comment this block out to temporarily disable local shortcut:
1411 # comment this block out to temporarily disable local shortcut:
1289 for msg_id in theids:
1412 for msg_id in theids:
1290 if msg_id in self.results:
1413 if msg_id in self.results:
1291 completed.append(msg_id)
1414 completed.append(msg_id)
1292 local_results[msg_id] = self.results[msg_id]
1415 local_results[msg_id] = self.results[msg_id]
1293 theids.remove(msg_id)
1416 theids.remove(msg_id)
1294
1417
1295 if theids: # some not locally cached
1418 if theids: # some not locally cached
1296 content = dict(msg_ids=theids, status_only=status_only)
1419 content = dict(msg_ids=theids, status_only=status_only)
1297 msg = self.session.send(self._query_socket, "result_request", content=content)
1420 msg = self.session.send(self._query_socket, "result_request", content=content)
1298 zmq.select([self._query_socket], [], [])
1421 zmq.select([self._query_socket], [], [])
1299 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1422 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1300 if self.debug:
1423 if self.debug:
1301 pprint(msg)
1424 pprint(msg)
1302 content = msg['content']
1425 content = msg['content']
1303 if content['status'] != 'ok':
1426 if content['status'] != 'ok':
1304 raise self._unwrap_exception(content)
1427 raise self._unwrap_exception(content)
1305 buffers = msg['buffers']
1428 buffers = msg['buffers']
1306 else:
1429 else:
1307 content = dict(completed=[],pending=[])
1430 content = dict(completed=[],pending=[])
1308
1431
1309 content['completed'].extend(completed)
1432 content['completed'].extend(completed)
1310
1433
1311 if status_only:
1434 if status_only:
1312 return content
1435 return content
1313
1436
1314 failures = []
1437 failures = []
1315 # load cached results into result:
1438 # load cached results into result:
1316 content.update(local_results)
1439 content.update(local_results)
1317
1440
1318 # update cache with results:
1441 # update cache with results:
1319 for msg_id in sorted(theids):
1442 for msg_id in sorted(theids):
1320 if msg_id in content['completed']:
1443 if msg_id in content['completed']:
1321 rec = content[msg_id]
1444 rec = content[msg_id]
1322 parent = rec['header']
1445 parent = rec['header']
1323 header = rec['result_header']
1446 header = rec['result_header']
1324 rcontent = rec['result_content']
1447 rcontent = rec['result_content']
1325 iodict = rec['io']
1448 iodict = rec['io']
1326 if isinstance(rcontent, str):
1449 if isinstance(rcontent, str):
1327 rcontent = self.session.unpack(rcontent)
1450 rcontent = self.session.unpack(rcontent)
1328
1451
1329 md = self.metadata[msg_id]
1452 md = self.metadata[msg_id]
1330 md.update(self._extract_metadata(header, parent, rcontent))
1453 md.update(self._extract_metadata(header, parent, rcontent))
1331 if rec.get('received'):
1454 if rec.get('received'):
1332 md['received'] = rec['received']
1455 md['received'] = rec['received']
1333 md.update(iodict)
1456 md.update(iodict)
1334
1457
1335 if rcontent['status'] == 'ok':
1458 if rcontent['status'] == 'ok':
1336 res,buffers = util.unserialize_object(buffers)
1459 res,buffers = util.unserialize_object(buffers)
1337 else:
1460 else:
1338 print rcontent
1461 print rcontent
1339 res = self._unwrap_exception(rcontent)
1462 res = self._unwrap_exception(rcontent)
1340 failures.append(res)
1463 failures.append(res)
1341
1464
1342 self.results[msg_id] = res
1465 self.results[msg_id] = res
1343 content[msg_id] = res
1466 content[msg_id] = res
1344
1467
1345 if len(theids) == 1 and failures:
1468 if len(theids) == 1 and failures:
1346 raise failures[0]
1469 raise failures[0]
1347
1470
1348 error.collect_exceptions(failures, "result_status")
1471 error.collect_exceptions(failures, "result_status")
1349 return content
1472 return content
1350
1473
1351 @spin_first
1474 @spin_first
1352 def queue_status(self, targets='all', verbose=False):
1475 def queue_status(self, targets='all', verbose=False):
1353 """Fetch the status of engine queues.
1476 """Fetch the status of engine queues.
1354
1477
1355 Parameters
1478 Parameters
1356 ----------
1479 ----------
1357
1480
1358 targets : int/str/list of ints/strs
1481 targets : int/str/list of ints/strs
1359 the engines whose states are to be queried.
1482 the engines whose states are to be queried.
1360 default : all
1483 default : all
1361 verbose : bool
1484 verbose : bool
1362 Whether to return lengths only, or lists of ids for each element
1485 Whether to return lengths only, or lists of ids for each element
1363 """
1486 """
1364 if targets == 'all':
1487 if targets == 'all':
1365 # allow 'all' to be evaluated on the engine
1488 # allow 'all' to be evaluated on the engine
1366 engine_ids = None
1489 engine_ids = None
1367 else:
1490 else:
1368 engine_ids = self._build_targets(targets)[1]
1491 engine_ids = self._build_targets(targets)[1]
1369 content = dict(targets=engine_ids, verbose=verbose)
1492 content = dict(targets=engine_ids, verbose=verbose)
1370 self.session.send(self._query_socket, "queue_request", content=content)
1493 self.session.send(self._query_socket, "queue_request", content=content)
1371 idents,msg = self.session.recv(self._query_socket, 0)
1494 idents,msg = self.session.recv(self._query_socket, 0)
1372 if self.debug:
1495 if self.debug:
1373 pprint(msg)
1496 pprint(msg)
1374 content = msg['content']
1497 content = msg['content']
1375 status = content.pop('status')
1498 status = content.pop('status')
1376 if status != 'ok':
1499 if status != 'ok':
1377 raise self._unwrap_exception(content)
1500 raise self._unwrap_exception(content)
1378 content = rekey(content)
1501 content = rekey(content)
1379 if isinstance(targets, int):
1502 if isinstance(targets, int):
1380 return content[targets]
1503 return content[targets]
1381 else:
1504 else:
1382 return content
1505 return content
1383
1506
1384 @spin_first
1507 @spin_first
1385 def purge_results(self, jobs=[], targets=[]):
1508 def purge_results(self, jobs=[], targets=[]):
1386 """Tell the Hub to forget results.
1509 """Tell the Hub to forget results.
1387
1510
1388 Individual results can be purged by msg_id, or the entire
1511 Individual results can be purged by msg_id, or the entire
1389 history of specific targets can be purged.
1512 history of specific targets can be purged.
1390
1513
1391 Use `purge_results('all')` to scrub everything from the Hub's db.
1514 Use `purge_results('all')` to scrub everything from the Hub's db.
1392
1515
1393 Parameters
1516 Parameters
1394 ----------
1517 ----------
1395
1518
1396 jobs : str or list of str or AsyncResult objects
1519 jobs : str or list of str or AsyncResult objects
1397 the msg_ids whose results should be forgotten.
1520 the msg_ids whose results should be forgotten.
1398 targets : int/str/list of ints/strs
1521 targets : int/str/list of ints/strs
1399 The targets, by int_id, whose entire history is to be purged.
1522 The targets, by int_id, whose entire history is to be purged.
1400
1523
1401 default : None
1524 default : None
1402 """
1525 """
1403 if not targets and not jobs:
1526 if not targets and not jobs:
1404 raise ValueError("Must specify at least one of `targets` and `jobs`")
1527 raise ValueError("Must specify at least one of `targets` and `jobs`")
1405 if targets:
1528 if targets:
1406 targets = self._build_targets(targets)[1]
1529 targets = self._build_targets(targets)[1]
1407
1530
1408 # construct msg_ids from jobs
1531 # construct msg_ids from jobs
1409 if jobs == 'all':
1532 if jobs == 'all':
1410 msg_ids = jobs
1533 msg_ids = jobs
1411 else:
1534 else:
1412 msg_ids = []
1535 msg_ids = []
1413 if isinstance(jobs, (basestring,AsyncResult)):
1536 if isinstance(jobs, (basestring,AsyncResult)):
1414 jobs = [jobs]
1537 jobs = [jobs]
1415 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1538 bad_ids = filter(lambda obj: not isinstance(obj, (basestring, AsyncResult)), jobs)
1416 if bad_ids:
1539 if bad_ids:
1417 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1540 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1418 for j in jobs:
1541 for j in jobs:
1419 if isinstance(j, AsyncResult):
1542 if isinstance(j, AsyncResult):
1420 msg_ids.extend(j.msg_ids)
1543 msg_ids.extend(j.msg_ids)
1421 else:
1544 else:
1422 msg_ids.append(j)
1545 msg_ids.append(j)
1423
1546
1424 content = dict(engine_ids=targets, msg_ids=msg_ids)
1547 content = dict(engine_ids=targets, msg_ids=msg_ids)
1425 self.session.send(self._query_socket, "purge_request", content=content)
1548 self.session.send(self._query_socket, "purge_request", content=content)
1426 idents, msg = self.session.recv(self._query_socket, 0)
1549 idents, msg = self.session.recv(self._query_socket, 0)
1427 if self.debug:
1550 if self.debug:
1428 pprint(msg)
1551 pprint(msg)
1429 content = msg['content']
1552 content = msg['content']
1430 if content['status'] != 'ok':
1553 if content['status'] != 'ok':
1431 raise self._unwrap_exception(content)
1554 raise self._unwrap_exception(content)
1432
1555
1433 @spin_first
1556 @spin_first
1434 def hub_history(self):
1557 def hub_history(self):
1435 """Get the Hub's history
1558 """Get the Hub's history
1436
1559
1437 Just like the Client, the Hub has a history, which is a list of msg_ids.
1560 Just like the Client, the Hub has a history, which is a list of msg_ids.
1438 This will contain the history of all clients, and, depending on configuration,
1561 This will contain the history of all clients, and, depending on configuration,
1439 may contain history across multiple cluster sessions.
1562 may contain history across multiple cluster sessions.
1440
1563
1441 Any msg_id returned here is a valid argument to `get_result`.
1564 Any msg_id returned here is a valid argument to `get_result`.
1442
1565
1443 Returns
1566 Returns
1444 -------
1567 -------
1445
1568
1446 msg_ids : list of strs
1569 msg_ids : list of strs
1447 list of all msg_ids, ordered by task submission time.
1570 list of all msg_ids, ordered by task submission time.
1448 """
1571 """
1449
1572
1450 self.session.send(self._query_socket, "history_request", content={})
1573 self.session.send(self._query_socket, "history_request", content={})
1451 idents, msg = self.session.recv(self._query_socket, 0)
1574 idents, msg = self.session.recv(self._query_socket, 0)
1452
1575
1453 if self.debug:
1576 if self.debug:
1454 pprint(msg)
1577 pprint(msg)
1455 content = msg['content']
1578 content = msg['content']
1456 if content['status'] != 'ok':
1579 if content['status'] != 'ok':
1457 raise self._unwrap_exception(content)
1580 raise self._unwrap_exception(content)
1458 else:
1581 else:
1459 return content['history']
1582 return content['history']
1460
1583
1461 @spin_first
1584 @spin_first
1462 def db_query(self, query, keys=None):
1585 def db_query(self, query, keys=None):
1463 """Query the Hub's TaskRecord database
1586 """Query the Hub's TaskRecord database
1464
1587
1465 This will return a list of task record dicts that match `query`
1588 This will return a list of task record dicts that match `query`
1466
1589
1467 Parameters
1590 Parameters
1468 ----------
1591 ----------
1469
1592
1470 query : mongodb query dict
1593 query : mongodb query dict
1471 The search dict. See mongodb query docs for details.
1594 The search dict. See mongodb query docs for details.
1472 keys : list of strs [optional]
1595 keys : list of strs [optional]
1473 The subset of keys to be returned. The default is to fetch everything but buffers.
1596 The subset of keys to be returned. The default is to fetch everything but buffers.
1474 'msg_id' will *always* be included.
1597 'msg_id' will *always* be included.
1475 """
1598 """
1476 if isinstance(keys, basestring):
1599 if isinstance(keys, basestring):
1477 keys = [keys]
1600 keys = [keys]
1478 content = dict(query=query, keys=keys)
1601 content = dict(query=query, keys=keys)
1479 self.session.send(self._query_socket, "db_request", content=content)
1602 self.session.send(self._query_socket, "db_request", content=content)
1480 idents, msg = self.session.recv(self._query_socket, 0)
1603 idents, msg = self.session.recv(self._query_socket, 0)
1481 if self.debug:
1604 if self.debug:
1482 pprint(msg)
1605 pprint(msg)
1483 content = msg['content']
1606 content = msg['content']
1484 if content['status'] != 'ok':
1607 if content['status'] != 'ok':
1485 raise self._unwrap_exception(content)
1608 raise self._unwrap_exception(content)
1486
1609
1487 records = content['records']
1610 records = content['records']
1488
1611
1489 buffer_lens = content['buffer_lens']
1612 buffer_lens = content['buffer_lens']
1490 result_buffer_lens = content['result_buffer_lens']
1613 result_buffer_lens = content['result_buffer_lens']
1491 buffers = msg['buffers']
1614 buffers = msg['buffers']
1492 has_bufs = buffer_lens is not None
1615 has_bufs = buffer_lens is not None
1493 has_rbufs = result_buffer_lens is not None
1616 has_rbufs = result_buffer_lens is not None
1494 for i,rec in enumerate(records):
1617 for i,rec in enumerate(records):
1495 # relink buffers
1618 # relink buffers
1496 if has_bufs:
1619 if has_bufs:
1497 blen = buffer_lens[i]
1620 blen = buffer_lens[i]
1498 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1621 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1499 if has_rbufs:
1622 if has_rbufs:
1500 blen = result_buffer_lens[i]
1623 blen = result_buffer_lens[i]
1501 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1624 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1502
1625
1503 return records
1626 return records
1504
1627
1505 __all__ = [ 'Client' ]
1628 __all__ = [ 'Client' ]
@@ -1,1075 +1,1100 b''
1 """Views of remote engines.
1 """Views of remote engines.
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 import imp
18 import imp
19 import sys
19 import sys
20 import warnings
20 import warnings
21 from contextlib import contextmanager
21 from contextlib import contextmanager
22 from types import ModuleType
22 from types import ModuleType
23
23
24 import zmq
24 import zmq
25
25
26 from IPython.testing.skipdoctest import skip_doctest
26 from IPython.testing.skipdoctest import skip_doctest
27 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
28 HasTraits, Any, Bool, List, Dict, Set, Instance, CFloat, Integer
29 )
29 )
30 from IPython.external.decorator import decorator
30 from IPython.external.decorator import decorator
31
31
32 from IPython.parallel import util
32 from IPython.parallel import util
33 from IPython.parallel.controller.dependency import Dependency, dependent
33 from IPython.parallel.controller.dependency import Dependency, dependent
34
34
35 from . import map as Map
35 from . import map as Map
36 from .asyncresult import AsyncResult, AsyncMapResult
36 from .asyncresult import AsyncResult, AsyncMapResult
37 from .remotefunction import ParallelFunction, parallel, remote, getname
37 from .remotefunction import ParallelFunction, parallel, remote, getname
38
38
39 #-----------------------------------------------------------------------------
39 #-----------------------------------------------------------------------------
40 # Decorators
40 # Decorators
41 #-----------------------------------------------------------------------------
41 #-----------------------------------------------------------------------------
42
42
43 @decorator
43 @decorator
44 def save_ids(f, self, *args, **kwargs):
44 def save_ids(f, self, *args, **kwargs):
45 """Keep our history and outstanding attributes up to date after a method call."""
45 """Keep our history and outstanding attributes up to date after a method call."""
46 n_previous = len(self.client.history)
46 n_previous = len(self.client.history)
47 try:
47 try:
48 ret = f(self, *args, **kwargs)
48 ret = f(self, *args, **kwargs)
49 finally:
49 finally:
50 nmsgs = len(self.client.history) - n_previous
50 nmsgs = len(self.client.history) - n_previous
51 msg_ids = self.client.history[-nmsgs:]
51 msg_ids = self.client.history[-nmsgs:]
52 self.history.extend(msg_ids)
52 self.history.extend(msg_ids)
53 map(self.outstanding.add, msg_ids)
53 map(self.outstanding.add, msg_ids)
54 return ret
54 return ret
55
55
56 @decorator
56 @decorator
57 def sync_results(f, self, *args, **kwargs):
57 def sync_results(f, self, *args, **kwargs):
58 """sync relevant results from self.client to our results attribute."""
58 """sync relevant results from self.client to our results attribute."""
59 ret = f(self, *args, **kwargs)
59 ret = f(self, *args, **kwargs)
60 delta = self.outstanding.difference(self.client.outstanding)
60 delta = self.outstanding.difference(self.client.outstanding)
61 completed = self.outstanding.intersection(delta)
61 completed = self.outstanding.intersection(delta)
62 self.outstanding = self.outstanding.difference(completed)
62 self.outstanding = self.outstanding.difference(completed)
63 for msg_id in completed:
63 for msg_id in completed:
64 self.results[msg_id] = self.client.results[msg_id]
64 self.results[msg_id] = self.client.results[msg_id]
65 return ret
65 return ret
66
66
67 @decorator
67 @decorator
68 def spin_after(f, self, *args, **kwargs):
68 def spin_after(f, self, *args, **kwargs):
69 """call spin after the method."""
69 """call spin after the method."""
70 ret = f(self, *args, **kwargs)
70 ret = f(self, *args, **kwargs)
71 self.spin()
71 self.spin()
72 return ret
72 return ret
73
73
74 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
75 # Classes
75 # Classes
76 #-----------------------------------------------------------------------------
76 #-----------------------------------------------------------------------------
77
77
78 @skip_doctest
78 @skip_doctest
79 class View(HasTraits):
79 class View(HasTraits):
80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
80 """Base View class for more convenint apply(f,*args,**kwargs) syntax via attributes.
81
81
82 Don't use this class, use subclasses.
82 Don't use this class, use subclasses.
83
83
84 Methods
84 Methods
85 -------
85 -------
86
86
87 spin
87 spin
88 flushes incoming results and registration state changes
88 flushes incoming results and registration state changes
89 control methods spin, and requesting `ids` also ensures up to date
89 control methods spin, and requesting `ids` also ensures up to date
90
90
91 wait
91 wait
92 wait on one or more msg_ids
92 wait on one or more msg_ids
93
93
94 execution methods
94 execution methods
95 apply
95 apply
96 legacy: execute, run
96 legacy: execute, run
97
97
98 data movement
98 data movement
99 push, pull, scatter, gather
99 push, pull, scatter, gather
100
100
101 query methods
101 query methods
102 get_result, queue_status, purge_results, result_status
102 get_result, queue_status, purge_results, result_status
103
103
104 control methods
104 control methods
105 abort, shutdown
105 abort, shutdown
106
106
107 """
107 """
108 # flags
108 # flags
109 block=Bool(False)
109 block=Bool(False)
110 track=Bool(True)
110 track=Bool(True)
111 targets = Any()
111 targets = Any()
112
112
113 history=List()
113 history=List()
114 outstanding = Set()
114 outstanding = Set()
115 results = Dict()
115 results = Dict()
116 client = Instance('IPython.parallel.Client')
116 client = Instance('IPython.parallel.Client')
117
117
118 _socket = Instance('zmq.Socket')
118 _socket = Instance('zmq.Socket')
119 _flag_names = List(['targets', 'block', 'track'])
119 _flag_names = List(['targets', 'block', 'track'])
120 _targets = Any()
120 _targets = Any()
121 _idents = Any()
121 _idents = Any()
122
122
123 def __init__(self, client=None, socket=None, **flags):
123 def __init__(self, client=None, socket=None, **flags):
124 super(View, self).__init__(client=client, _socket=socket)
124 super(View, self).__init__(client=client, _socket=socket)
125 self.block = client.block
125 self.block = client.block
126
126
127 self.set_flags(**flags)
127 self.set_flags(**flags)
128
128
129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
129 assert not self.__class__ is View, "Don't use base View objects, use subclasses"
130
130
131
132 def __repr__(self):
131 def __repr__(self):
133 strtargets = str(self.targets)
132 strtargets = str(self.targets)
134 if len(strtargets) > 16:
133 if len(strtargets) > 16:
135 strtargets = strtargets[:12]+'...]'
134 strtargets = strtargets[:12]+'...]'
136 return "<%s %s>"%(self.__class__.__name__, strtargets)
135 return "<%s %s>"%(self.__class__.__name__, strtargets)
137
136
137 def __len__(self):
138 if isinstance(self.targets, list):
139 return len(self.targets)
140 elif isinstance(self.targets, int):
141 return 1
142 else:
143 return len(self.client)
144
138 def set_flags(self, **kwargs):
145 def set_flags(self, **kwargs):
139 """set my attribute flags by keyword.
146 """set my attribute flags by keyword.
140
147
141 Views determine behavior with a few attributes (`block`, `track`, etc.).
148 Views determine behavior with a few attributes (`block`, `track`, etc.).
142 These attributes can be set all at once by name with this method.
149 These attributes can be set all at once by name with this method.
143
150
144 Parameters
151 Parameters
145 ----------
152 ----------
146
153
147 block : bool
154 block : bool
148 whether to wait for results
155 whether to wait for results
149 track : bool
156 track : bool
150 whether to create a MessageTracker to allow the user to
157 whether to create a MessageTracker to allow the user to
151 safely edit after arrays and buffers during non-copying
158 safely edit after arrays and buffers during non-copying
152 sends.
159 sends.
153 """
160 """
154 for name, value in kwargs.iteritems():
161 for name, value in kwargs.iteritems():
155 if name not in self._flag_names:
162 if name not in self._flag_names:
156 raise KeyError("Invalid name: %r"%name)
163 raise KeyError("Invalid name: %r"%name)
157 else:
164 else:
158 setattr(self, name, value)
165 setattr(self, name, value)
159
166
160 @contextmanager
167 @contextmanager
161 def temp_flags(self, **kwargs):
168 def temp_flags(self, **kwargs):
162 """temporarily set flags, for use in `with` statements.
169 """temporarily set flags, for use in `with` statements.
163
170
164 See set_flags for permanent setting of flags
171 See set_flags for permanent setting of flags
165
172
166 Examples
173 Examples
167 --------
174 --------
168
175
169 >>> view.track=False
176 >>> view.track=False
170 ...
177 ...
171 >>> with view.temp_flags(track=True):
178 >>> with view.temp_flags(track=True):
172 ... ar = view.apply(dostuff, my_big_array)
179 ... ar = view.apply(dostuff, my_big_array)
173 ... ar.tracker.wait() # wait for send to finish
180 ... ar.tracker.wait() # wait for send to finish
174 >>> view.track
181 >>> view.track
175 False
182 False
176
183
177 """
184 """
178 # preflight: save flags, and set temporaries
185 # preflight: save flags, and set temporaries
179 saved_flags = {}
186 saved_flags = {}
180 for f in self._flag_names:
187 for f in self._flag_names:
181 saved_flags[f] = getattr(self, f)
188 saved_flags[f] = getattr(self, f)
182 self.set_flags(**kwargs)
189 self.set_flags(**kwargs)
183 # yield to the with-statement block
190 # yield to the with-statement block
184 try:
191 try:
185 yield
192 yield
186 finally:
193 finally:
187 # postflight: restore saved flags
194 # postflight: restore saved flags
188 self.set_flags(**saved_flags)
195 self.set_flags(**saved_flags)
189
196
190
197
191 #----------------------------------------------------------------
198 #----------------------------------------------------------------
192 # apply
199 # apply
193 #----------------------------------------------------------------
200 #----------------------------------------------------------------
194
201
195 @sync_results
202 @sync_results
196 @save_ids
203 @save_ids
197 def _really_apply(self, f, args, kwargs, block=None, **options):
204 def _really_apply(self, f, args, kwargs, block=None, **options):
198 """wrapper for client.send_apply_message"""
205 """wrapper for client.send_apply_request"""
199 raise NotImplementedError("Implement in subclasses")
206 raise NotImplementedError("Implement in subclasses")
200
207
201 def apply(self, f, *args, **kwargs):
208 def apply(self, f, *args, **kwargs):
202 """calls f(*args, **kwargs) on remote engines, returning the result.
209 """calls f(*args, **kwargs) on remote engines, returning the result.
203
210
204 This method sets all apply flags via this View's attributes.
211 This method sets all apply flags via this View's attributes.
205
212
206 if self.block is False:
213 if self.block is False:
207 returns AsyncResult
214 returns AsyncResult
208 else:
215 else:
209 returns actual result of f(*args, **kwargs)
216 returns actual result of f(*args, **kwargs)
210 """
217 """
211 return self._really_apply(f, args, kwargs)
218 return self._really_apply(f, args, kwargs)
212
219
213 def apply_async(self, f, *args, **kwargs):
220 def apply_async(self, f, *args, **kwargs):
214 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
221 """calls f(*args, **kwargs) on remote engines in a nonblocking manner.
215
222
216 returns AsyncResult
223 returns AsyncResult
217 """
224 """
218 return self._really_apply(f, args, kwargs, block=False)
225 return self._really_apply(f, args, kwargs, block=False)
219
226
220 @spin_after
227 @spin_after
221 def apply_sync(self, f, *args, **kwargs):
228 def apply_sync(self, f, *args, **kwargs):
222 """calls f(*args, **kwargs) on remote engines in a blocking manner,
229 """calls f(*args, **kwargs) on remote engines in a blocking manner,
223 returning the result.
230 returning the result.
224
231
225 returns: actual result of f(*args, **kwargs)
232 returns: actual result of f(*args, **kwargs)
226 """
233 """
227 return self._really_apply(f, args, kwargs, block=True)
234 return self._really_apply(f, args, kwargs, block=True)
228
235
229 #----------------------------------------------------------------
236 #----------------------------------------------------------------
230 # wrappers for client and control methods
237 # wrappers for client and control methods
231 #----------------------------------------------------------------
238 #----------------------------------------------------------------
232 @sync_results
239 @sync_results
233 def spin(self):
240 def spin(self):
234 """spin the client, and sync"""
241 """spin the client, and sync"""
235 self.client.spin()
242 self.client.spin()
236
243
237 @sync_results
244 @sync_results
238 def wait(self, jobs=None, timeout=-1):
245 def wait(self, jobs=None, timeout=-1):
239 """waits on one or more `jobs`, for up to `timeout` seconds.
246 """waits on one or more `jobs`, for up to `timeout` seconds.
240
247
241 Parameters
248 Parameters
242 ----------
249 ----------
243
250
244 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
251 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
245 ints are indices to self.history
252 ints are indices to self.history
246 strs are msg_ids
253 strs are msg_ids
247 default: wait on all outstanding messages
254 default: wait on all outstanding messages
248 timeout : float
255 timeout : float
249 a time in seconds, after which to give up.
256 a time in seconds, after which to give up.
250 default is -1, which means no timeout
257 default is -1, which means no timeout
251
258
252 Returns
259 Returns
253 -------
260 -------
254
261
255 True : when all msg_ids are done
262 True : when all msg_ids are done
256 False : timeout reached, some msg_ids still outstanding
263 False : timeout reached, some msg_ids still outstanding
257 """
264 """
258 if jobs is None:
265 if jobs is None:
259 jobs = self.history
266 jobs = self.history
260 return self.client.wait(jobs, timeout)
267 return self.client.wait(jobs, timeout)
261
268
262 def abort(self, jobs=None, targets=None, block=None):
269 def abort(self, jobs=None, targets=None, block=None):
263 """Abort jobs on my engines.
270 """Abort jobs on my engines.
264
271
265 Parameters
272 Parameters
266 ----------
273 ----------
267
274
268 jobs : None, str, list of strs, optional
275 jobs : None, str, list of strs, optional
269 if None: abort all jobs.
276 if None: abort all jobs.
270 else: abort specific msg_id(s).
277 else: abort specific msg_id(s).
271 """
278 """
272 block = block if block is not None else self.block
279 block = block if block is not None else self.block
273 targets = targets if targets is not None else self.targets
280 targets = targets if targets is not None else self.targets
274 jobs = jobs if jobs is not None else list(self.outstanding)
281 jobs = jobs if jobs is not None else list(self.outstanding)
275
282
276 return self.client.abort(jobs=jobs, targets=targets, block=block)
283 return self.client.abort(jobs=jobs, targets=targets, block=block)
277
284
278 def queue_status(self, targets=None, verbose=False):
285 def queue_status(self, targets=None, verbose=False):
279 """Fetch the Queue status of my engines"""
286 """Fetch the Queue status of my engines"""
280 targets = targets if targets is not None else self.targets
287 targets = targets if targets is not None else self.targets
281 return self.client.queue_status(targets=targets, verbose=verbose)
288 return self.client.queue_status(targets=targets, verbose=verbose)
282
289
283 def purge_results(self, jobs=[], targets=[]):
290 def purge_results(self, jobs=[], targets=[]):
284 """Instruct the controller to forget specific results."""
291 """Instruct the controller to forget specific results."""
285 if targets is None or targets == 'all':
292 if targets is None or targets == 'all':
286 targets = self.targets
293 targets = self.targets
287 return self.client.purge_results(jobs=jobs, targets=targets)
294 return self.client.purge_results(jobs=jobs, targets=targets)
288
295
289 def shutdown(self, targets=None, restart=False, hub=False, block=None):
296 def shutdown(self, targets=None, restart=False, hub=False, block=None):
290 """Terminates one or more engine processes, optionally including the hub.
297 """Terminates one or more engine processes, optionally including the hub.
291 """
298 """
292 block = self.block if block is None else block
299 block = self.block if block is None else block
293 if targets is None or targets == 'all':
300 if targets is None or targets == 'all':
294 targets = self.targets
301 targets = self.targets
295 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
302 return self.client.shutdown(targets=targets, restart=restart, hub=hub, block=block)
296
303
297 @spin_after
304 @spin_after
298 def get_result(self, indices_or_msg_ids=None):
305 def get_result(self, indices_or_msg_ids=None):
299 """return one or more results, specified by history index or msg_id.
306 """return one or more results, specified by history index or msg_id.
300
307
301 See client.get_result for details.
308 See client.get_result for details.
302
309
303 """
310 """
304
311
305 if indices_or_msg_ids is None:
312 if indices_or_msg_ids is None:
306 indices_or_msg_ids = -1
313 indices_or_msg_ids = -1
307 if isinstance(indices_or_msg_ids, int):
314 if isinstance(indices_or_msg_ids, int):
308 indices_or_msg_ids = self.history[indices_or_msg_ids]
315 indices_or_msg_ids = self.history[indices_or_msg_ids]
309 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
316 elif isinstance(indices_or_msg_ids, (list,tuple,set)):
310 indices_or_msg_ids = list(indices_or_msg_ids)
317 indices_or_msg_ids = list(indices_or_msg_ids)
311 for i,index in enumerate(indices_or_msg_ids):
318 for i,index in enumerate(indices_or_msg_ids):
312 if isinstance(index, int):
319 if isinstance(index, int):
313 indices_or_msg_ids[i] = self.history[index]
320 indices_or_msg_ids[i] = self.history[index]
314 return self.client.get_result(indices_or_msg_ids)
321 return self.client.get_result(indices_or_msg_ids)
315
322
316 #-------------------------------------------------------------------
323 #-------------------------------------------------------------------
317 # Map
324 # Map
318 #-------------------------------------------------------------------
325 #-------------------------------------------------------------------
319
326
320 def map(self, f, *sequences, **kwargs):
327 def map(self, f, *sequences, **kwargs):
321 """override in subclasses"""
328 """override in subclasses"""
322 raise NotImplementedError
329 raise NotImplementedError
323
330
324 def map_async(self, f, *sequences, **kwargs):
331 def map_async(self, f, *sequences, **kwargs):
325 """Parallel version of builtin `map`, using this view's engines.
332 """Parallel version of builtin `map`, using this view's engines.
326
333
327 This is equivalent to map(...block=False)
334 This is equivalent to map(...block=False)
328
335
329 See `self.map` for details.
336 See `self.map` for details.
330 """
337 """
331 if 'block' in kwargs:
338 if 'block' in kwargs:
332 raise TypeError("map_async doesn't take a `block` keyword argument.")
339 raise TypeError("map_async doesn't take a `block` keyword argument.")
333 kwargs['block'] = False
340 kwargs['block'] = False
334 return self.map(f,*sequences,**kwargs)
341 return self.map(f,*sequences,**kwargs)
335
342
336 def map_sync(self, f, *sequences, **kwargs):
343 def map_sync(self, f, *sequences, **kwargs):
337 """Parallel version of builtin `map`, using this view's engines.
344 """Parallel version of builtin `map`, using this view's engines.
338
345
339 This is equivalent to map(...block=True)
346 This is equivalent to map(...block=True)
340
347
341 See `self.map` for details.
348 See `self.map` for details.
342 """
349 """
343 if 'block' in kwargs:
350 if 'block' in kwargs:
344 raise TypeError("map_sync doesn't take a `block` keyword argument.")
351 raise TypeError("map_sync doesn't take a `block` keyword argument.")
345 kwargs['block'] = True
352 kwargs['block'] = True
346 return self.map(f,*sequences,**kwargs)
353 return self.map(f,*sequences,**kwargs)
347
354
348 def imap(self, f, *sequences, **kwargs):
355 def imap(self, f, *sequences, **kwargs):
349 """Parallel version of `itertools.imap`.
356 """Parallel version of `itertools.imap`.
350
357
351 See `self.map` for details.
358 See `self.map` for details.
352
359
353 """
360 """
354
361
355 return iter(self.map_async(f,*sequences, **kwargs))
362 return iter(self.map_async(f,*sequences, **kwargs))
356
363
357 #-------------------------------------------------------------------
364 #-------------------------------------------------------------------
358 # Decorators
365 # Decorators
359 #-------------------------------------------------------------------
366 #-------------------------------------------------------------------
360
367
361 def remote(self, block=True, **flags):
368 def remote(self, block=True, **flags):
362 """Decorator for making a RemoteFunction"""
369 """Decorator for making a RemoteFunction"""
363 block = self.block if block is None else block
370 block = self.block if block is None else block
364 return remote(self, block=block, **flags)
371 return remote(self, block=block, **flags)
365
372
366 def parallel(self, dist='b', block=None, **flags):
373 def parallel(self, dist='b', block=None, **flags):
367 """Decorator for making a ParallelFunction"""
374 """Decorator for making a ParallelFunction"""
368 block = self.block if block is None else block
375 block = self.block if block is None else block
369 return parallel(self, dist=dist, block=block, **flags)
376 return parallel(self, dist=dist, block=block, **flags)
370
377
371 @skip_doctest
378 @skip_doctest
372 class DirectView(View):
379 class DirectView(View):
373 """Direct Multiplexer View of one or more engines.
380 """Direct Multiplexer View of one or more engines.
374
381
375 These are created via indexed access to a client:
382 These are created via indexed access to a client:
376
383
377 >>> dv_1 = client[1]
384 >>> dv_1 = client[1]
378 >>> dv_all = client[:]
385 >>> dv_all = client[:]
379 >>> dv_even = client[::2]
386 >>> dv_even = client[::2]
380 >>> dv_some = client[1:3]
387 >>> dv_some = client[1:3]
381
388
382 This object provides dictionary access to engine namespaces:
389 This object provides dictionary access to engine namespaces:
383
390
384 # push a=5:
391 # push a=5:
385 >>> dv['a'] = 5
392 >>> dv['a'] = 5
386 # pull 'foo':
393 # pull 'foo':
387 >>> db['foo']
394 >>> db['foo']
388
395
389 """
396 """
390
397
391 def __init__(self, client=None, socket=None, targets=None):
398 def __init__(self, client=None, socket=None, targets=None):
392 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
399 super(DirectView, self).__init__(client=client, socket=socket, targets=targets)
393
400
394 @property
401 @property
395 def importer(self):
402 def importer(self):
396 """sync_imports(local=True) as a property.
403 """sync_imports(local=True) as a property.
397
404
398 See sync_imports for details.
405 See sync_imports for details.
399
406
400 """
407 """
401 return self.sync_imports(True)
408 return self.sync_imports(True)
402
409
403 @contextmanager
410 @contextmanager
404 def sync_imports(self, local=True, quiet=False):
411 def sync_imports(self, local=True, quiet=False):
405 """Context Manager for performing simultaneous local and remote imports.
412 """Context Manager for performing simultaneous local and remote imports.
406
413
407 'import x as y' will *not* work. The 'as y' part will simply be ignored.
414 'import x as y' will *not* work. The 'as y' part will simply be ignored.
408
415
409 If `local=True`, then the package will also be imported locally.
416 If `local=True`, then the package will also be imported locally.
410
417
411 If `quiet=True`, no output will be produced when attempting remote
418 If `quiet=True`, no output will be produced when attempting remote
412 imports.
419 imports.
413
420
414 Note that remote-only (`local=False`) imports have not been implemented.
421 Note that remote-only (`local=False`) imports have not been implemented.
415
422
416 >>> with view.sync_imports():
423 >>> with view.sync_imports():
417 ... from numpy import recarray
424 ... from numpy import recarray
418 importing recarray from numpy on engine(s)
425 importing recarray from numpy on engine(s)
419
426
420 """
427 """
421 import __builtin__
428 import __builtin__
422 local_import = __builtin__.__import__
429 local_import = __builtin__.__import__
423 modules = set()
430 modules = set()
424 results = []
431 results = []
425 @util.interactive
432 @util.interactive
426 def remote_import(name, fromlist, level):
433 def remote_import(name, fromlist, level):
427 """the function to be passed to apply, that actually performs the import
434 """the function to be passed to apply, that actually performs the import
428 on the engine, and loads up the user namespace.
435 on the engine, and loads up the user namespace.
429 """
436 """
430 import sys
437 import sys
431 user_ns = globals()
438 user_ns = globals()
432 mod = __import__(name, fromlist=fromlist, level=level)
439 mod = __import__(name, fromlist=fromlist, level=level)
433 if fromlist:
440 if fromlist:
434 for key in fromlist:
441 for key in fromlist:
435 user_ns[key] = getattr(mod, key)
442 user_ns[key] = getattr(mod, key)
436 else:
443 else:
437 user_ns[name] = sys.modules[name]
444 user_ns[name] = sys.modules[name]
438
445
439 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
446 def view_import(name, globals={}, locals={}, fromlist=[], level=-1):
440 """the drop-in replacement for __import__, that optionally imports
447 """the drop-in replacement for __import__, that optionally imports
441 locally as well.
448 locally as well.
442 """
449 """
443 # don't override nested imports
450 # don't override nested imports
444 save_import = __builtin__.__import__
451 save_import = __builtin__.__import__
445 __builtin__.__import__ = local_import
452 __builtin__.__import__ = local_import
446
453
447 if imp.lock_held():
454 if imp.lock_held():
448 # this is a side-effect import, don't do it remotely, or even
455 # this is a side-effect import, don't do it remotely, or even
449 # ignore the local effects
456 # ignore the local effects
450 return local_import(name, globals, locals, fromlist, level)
457 return local_import(name, globals, locals, fromlist, level)
451
458
452 imp.acquire_lock()
459 imp.acquire_lock()
453 if local:
460 if local:
454 mod = local_import(name, globals, locals, fromlist, level)
461 mod = local_import(name, globals, locals, fromlist, level)
455 else:
462 else:
456 raise NotImplementedError("remote-only imports not yet implemented")
463 raise NotImplementedError("remote-only imports not yet implemented")
457 imp.release_lock()
464 imp.release_lock()
458
465
459 key = name+':'+','.join(fromlist or [])
466 key = name+':'+','.join(fromlist or [])
460 if level == -1 and key not in modules:
467 if level == -1 and key not in modules:
461 modules.add(key)
468 modules.add(key)
462 if not quiet:
469 if not quiet:
463 if fromlist:
470 if fromlist:
464 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
471 print "importing %s from %s on engine(s)"%(','.join(fromlist), name)
465 else:
472 else:
466 print "importing %s on engine(s)"%name
473 print "importing %s on engine(s)"%name
467 results.append(self.apply_async(remote_import, name, fromlist, level))
474 results.append(self.apply_async(remote_import, name, fromlist, level))
468 # restore override
475 # restore override
469 __builtin__.__import__ = save_import
476 __builtin__.__import__ = save_import
470
477
471 return mod
478 return mod
472
479
473 # override __import__
480 # override __import__
474 __builtin__.__import__ = view_import
481 __builtin__.__import__ = view_import
475 try:
482 try:
476 # enter the block
483 # enter the block
477 yield
484 yield
478 except ImportError:
485 except ImportError:
479 if local:
486 if local:
480 raise
487 raise
481 else:
488 else:
482 # ignore import errors if not doing local imports
489 # ignore import errors if not doing local imports
483 pass
490 pass
484 finally:
491 finally:
485 # always restore __import__
492 # always restore __import__
486 __builtin__.__import__ = local_import
493 __builtin__.__import__ = local_import
487
494
488 for r in results:
495 for r in results:
489 # raise possible remote ImportErrors here
496 # raise possible remote ImportErrors here
490 r.get()
497 r.get()
491
498
492
499
493 @sync_results
500 @sync_results
494 @save_ids
501 @save_ids
495 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
502 def _really_apply(self, f, args=None, kwargs=None, targets=None, block=None, track=None):
496 """calls f(*args, **kwargs) on remote engines, returning the result.
503 """calls f(*args, **kwargs) on remote engines, returning the result.
497
504
498 This method sets all of `apply`'s flags via this View's attributes.
505 This method sets all of `apply`'s flags via this View's attributes.
499
506
500 Parameters
507 Parameters
501 ----------
508 ----------
502
509
503 f : callable
510 f : callable
504
511
505 args : list [default: empty]
512 args : list [default: empty]
506
513
507 kwargs : dict [default: empty]
514 kwargs : dict [default: empty]
508
515
509 targets : target list [default: self.targets]
516 targets : target list [default: self.targets]
510 where to run
517 where to run
511 block : bool [default: self.block]
518 block : bool [default: self.block]
512 whether to block
519 whether to block
513 track : bool [default: self.track]
520 track : bool [default: self.track]
514 whether to ask zmq to track the message, for safe non-copying sends
521 whether to ask zmq to track the message, for safe non-copying sends
515
522
516 Returns
523 Returns
517 -------
524 -------
518
525
519 if self.block is False:
526 if self.block is False:
520 returns AsyncResult
527 returns AsyncResult
521 else:
528 else:
522 returns actual result of f(*args, **kwargs) on the engine(s)
529 returns actual result of f(*args, **kwargs) on the engine(s)
523 This will be a list of self.targets is also a list (even length 1), or
530 This will be a list of self.targets is also a list (even length 1), or
524 the single result if self.targets is an integer engine id
531 the single result if self.targets is an integer engine id
525 """
532 """
526 args = [] if args is None else args
533 args = [] if args is None else args
527 kwargs = {} if kwargs is None else kwargs
534 kwargs = {} if kwargs is None else kwargs
528 block = self.block if block is None else block
535 block = self.block if block is None else block
529 track = self.track if track is None else track
536 track = self.track if track is None else track
530 targets = self.targets if targets is None else targets
537 targets = self.targets if targets is None else targets
531
538
532 _idents = self.client._build_targets(targets)[0]
539 _idents = self.client._build_targets(targets)[0]
533 msg_ids = []
540 msg_ids = []
534 trackers = []
541 trackers = []
535 for ident in _idents:
542 for ident in _idents:
536 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
543 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
537 ident=ident)
544 ident=ident)
538 if track:
545 if track:
539 trackers.append(msg['tracker'])
546 trackers.append(msg['tracker'])
540 msg_ids.append(msg['header']['msg_id'])
547 msg_ids.append(msg['header']['msg_id'])
541 tracker = None if track is False else zmq.MessageTracker(*trackers)
548 tracker = None if track is False else zmq.MessageTracker(*trackers)
542 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
549 ar = AsyncResult(self.client, msg_ids, fname=getname(f), targets=targets, tracker=tracker)
543 if block:
550 if block:
544 try:
551 try:
545 return ar.get()
552 return ar.get()
546 except KeyboardInterrupt:
553 except KeyboardInterrupt:
547 pass
554 pass
548 return ar
555 return ar
549
556
557
550 @spin_after
558 @spin_after
551 def map(self, f, *sequences, **kwargs):
559 def map(self, f, *sequences, **kwargs):
552 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
560 """view.map(f, *sequences, block=self.block) => list|AsyncMapResult
553
561
554 Parallel version of builtin `map`, using this View's `targets`.
562 Parallel version of builtin `map`, using this View's `targets`.
555
563
556 There will be one task per target, so work will be chunked
564 There will be one task per target, so work will be chunked
557 if the sequences are longer than `targets`.
565 if the sequences are longer than `targets`.
558
566
559 Results can be iterated as they are ready, but will become available in chunks.
567 Results can be iterated as they are ready, but will become available in chunks.
560
568
561 Parameters
569 Parameters
562 ----------
570 ----------
563
571
564 f : callable
572 f : callable
565 function to be mapped
573 function to be mapped
566 *sequences: one or more sequences of matching length
574 *sequences: one or more sequences of matching length
567 the sequences to be distributed and passed to `f`
575 the sequences to be distributed and passed to `f`
568 block : bool
576 block : bool
569 whether to wait for the result or not [default self.block]
577 whether to wait for the result or not [default self.block]
570
578
571 Returns
579 Returns
572 -------
580 -------
573
581
574 if block=False:
582 if block=False:
575 AsyncMapResult
583 AsyncMapResult
576 An object like AsyncResult, but which reassembles the sequence of results
584 An object like AsyncResult, but which reassembles the sequence of results
577 into a single list. AsyncMapResults can be iterated through before all
585 into a single list. AsyncMapResults can be iterated through before all
578 results are complete.
586 results are complete.
579 else:
587 else:
580 list
588 list
581 the result of map(f,*sequences)
589 the result of map(f,*sequences)
582 """
590 """
583
591
584 block = kwargs.pop('block', self.block)
592 block = kwargs.pop('block', self.block)
585 for k in kwargs.keys():
593 for k in kwargs.keys():
586 if k not in ['block', 'track']:
594 if k not in ['block', 'track']:
587 raise TypeError("invalid keyword arg, %r"%k)
595 raise TypeError("invalid keyword arg, %r"%k)
588
596
589 assert len(sequences) > 0, "must have some sequences to map onto!"
597 assert len(sequences) > 0, "must have some sequences to map onto!"
590 pf = ParallelFunction(self, f, block=block, **kwargs)
598 pf = ParallelFunction(self, f, block=block, **kwargs)
591 return pf.map(*sequences)
599 return pf.map(*sequences)
592
600
593 def execute(self, code, targets=None, block=None):
601 @sync_results
602 @save_ids
603 def execute(self, code, silent=True, targets=None, block=None):
594 """Executes `code` on `targets` in blocking or nonblocking manner.
604 """Executes `code` on `targets` in blocking or nonblocking manner.
595
605
596 ``execute`` is always `bound` (affects engine namespace)
606 ``execute`` is always `bound` (affects engine namespace)
597
607
598 Parameters
608 Parameters
599 ----------
609 ----------
600
610
601 code : str
611 code : str
602 the code string to be executed
612 the code string to be executed
603 block : bool
613 block : bool
604 whether or not to wait until done to return
614 whether or not to wait until done to return
605 default: self.block
615 default: self.block
606 """
616 """
607 return self._really_apply(util._execute, args=(code,), block=block, targets=targets)
617 block = self.block if block is None else block
618 targets = self.targets if targets is None else targets
619
620 _idents = self.client._build_targets(targets)[0]
621 msg_ids = []
622 trackers = []
623 for ident in _idents:
624 msg = self.client.send_execute_request(self._socket, code, silent=silent, ident=ident)
625 msg_ids.append(msg['header']['msg_id'])
626 ar = AsyncResult(self.client, msg_ids, fname='execute', targets=targets)
627 if block:
628 try:
629 ar.get()
630 except KeyboardInterrupt:
631 pass
632 return ar
608
633
609 def run(self, filename, targets=None, block=None):
634 def run(self, filename, targets=None, block=None):
610 """Execute contents of `filename` on my engine(s).
635 """Execute contents of `filename` on my engine(s).
611
636
612 This simply reads the contents of the file and calls `execute`.
637 This simply reads the contents of the file and calls `execute`.
613
638
614 Parameters
639 Parameters
615 ----------
640 ----------
616
641
617 filename : str
642 filename : str
618 The path to the file
643 The path to the file
619 targets : int/str/list of ints/strs
644 targets : int/str/list of ints/strs
620 the engines on which to execute
645 the engines on which to execute
621 default : all
646 default : all
622 block : bool
647 block : bool
623 whether or not to wait until done
648 whether or not to wait until done
624 default: self.block
649 default: self.block
625
650
626 """
651 """
627 with open(filename, 'r') as f:
652 with open(filename, 'r') as f:
628 # add newline in case of trailing indented whitespace
653 # add newline in case of trailing indented whitespace
629 # which will cause SyntaxError
654 # which will cause SyntaxError
630 code = f.read()+'\n'
655 code = f.read()+'\n'
631 return self.execute(code, block=block, targets=targets)
656 return self.execute(code, block=block, targets=targets)
632
657
633 def update(self, ns):
658 def update(self, ns):
634 """update remote namespace with dict `ns`
659 """update remote namespace with dict `ns`
635
660
636 See `push` for details.
661 See `push` for details.
637 """
662 """
638 return self.push(ns, block=self.block, track=self.track)
663 return self.push(ns, block=self.block, track=self.track)
639
664
640 def push(self, ns, targets=None, block=None, track=None):
665 def push(self, ns, targets=None, block=None, track=None):
641 """update remote namespace with dict `ns`
666 """update remote namespace with dict `ns`
642
667
643 Parameters
668 Parameters
644 ----------
669 ----------
645
670
646 ns : dict
671 ns : dict
647 dict of keys with which to update engine namespace(s)
672 dict of keys with which to update engine namespace(s)
648 block : bool [default : self.block]
673 block : bool [default : self.block]
649 whether to wait to be notified of engine receipt
674 whether to wait to be notified of engine receipt
650
675
651 """
676 """
652
677
653 block = block if block is not None else self.block
678 block = block if block is not None else self.block
654 track = track if track is not None else self.track
679 track = track if track is not None else self.track
655 targets = targets if targets is not None else self.targets
680 targets = targets if targets is not None else self.targets
656 # applier = self.apply_sync if block else self.apply_async
681 # applier = self.apply_sync if block else self.apply_async
657 if not isinstance(ns, dict):
682 if not isinstance(ns, dict):
658 raise TypeError("Must be a dict, not %s"%type(ns))
683 raise TypeError("Must be a dict, not %s"%type(ns))
659 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
684 return self._really_apply(util._push, kwargs=ns, block=block, track=track, targets=targets)
660
685
661 def get(self, key_s):
686 def get(self, key_s):
662 """get object(s) by `key_s` from remote namespace
687 """get object(s) by `key_s` from remote namespace
663
688
664 see `pull` for details.
689 see `pull` for details.
665 """
690 """
666 # block = block if block is not None else self.block
691 # block = block if block is not None else self.block
667 return self.pull(key_s, block=True)
692 return self.pull(key_s, block=True)
668
693
669 def pull(self, names, targets=None, block=None):
694 def pull(self, names, targets=None, block=None):
670 """get object(s) by `name` from remote namespace
695 """get object(s) by `name` from remote namespace
671
696
672 will return one object if it is a key.
697 will return one object if it is a key.
673 can also take a list of keys, in which case it will return a list of objects.
698 can also take a list of keys, in which case it will return a list of objects.
674 """
699 """
675 block = block if block is not None else self.block
700 block = block if block is not None else self.block
676 targets = targets if targets is not None else self.targets
701 targets = targets if targets is not None else self.targets
677 applier = self.apply_sync if block else self.apply_async
702 applier = self.apply_sync if block else self.apply_async
678 if isinstance(names, basestring):
703 if isinstance(names, basestring):
679 pass
704 pass
680 elif isinstance(names, (list,tuple,set)):
705 elif isinstance(names, (list,tuple,set)):
681 for key in names:
706 for key in names:
682 if not isinstance(key, basestring):
707 if not isinstance(key, basestring):
683 raise TypeError("keys must be str, not type %r"%type(key))
708 raise TypeError("keys must be str, not type %r"%type(key))
684 else:
709 else:
685 raise TypeError("names must be strs, not %r"%names)
710 raise TypeError("names must be strs, not %r"%names)
686 return self._really_apply(util._pull, (names,), block=block, targets=targets)
711 return self._really_apply(util._pull, (names,), block=block, targets=targets)
687
712
688 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
713 def scatter(self, key, seq, dist='b', flatten=False, targets=None, block=None, track=None):
689 """
714 """
690 Partition a Python sequence and send the partitions to a set of engines.
715 Partition a Python sequence and send the partitions to a set of engines.
691 """
716 """
692 block = block if block is not None else self.block
717 block = block if block is not None else self.block
693 track = track if track is not None else self.track
718 track = track if track is not None else self.track
694 targets = targets if targets is not None else self.targets
719 targets = targets if targets is not None else self.targets
695
720
696 # construct integer ID list:
721 # construct integer ID list:
697 targets = self.client._build_targets(targets)[1]
722 targets = self.client._build_targets(targets)[1]
698
723
699 mapObject = Map.dists[dist]()
724 mapObject = Map.dists[dist]()
700 nparts = len(targets)
725 nparts = len(targets)
701 msg_ids = []
726 msg_ids = []
702 trackers = []
727 trackers = []
703 for index, engineid in enumerate(targets):
728 for index, engineid in enumerate(targets):
704 partition = mapObject.getPartition(seq, index, nparts)
729 partition = mapObject.getPartition(seq, index, nparts)
705 if flatten and len(partition) == 1:
730 if flatten and len(partition) == 1:
706 ns = {key: partition[0]}
731 ns = {key: partition[0]}
707 else:
732 else:
708 ns = {key: partition}
733 ns = {key: partition}
709 r = self.push(ns, block=False, track=track, targets=engineid)
734 r = self.push(ns, block=False, track=track, targets=engineid)
710 msg_ids.extend(r.msg_ids)
735 msg_ids.extend(r.msg_ids)
711 if track:
736 if track:
712 trackers.append(r._tracker)
737 trackers.append(r._tracker)
713
738
714 if track:
739 if track:
715 tracker = zmq.MessageTracker(*trackers)
740 tracker = zmq.MessageTracker(*trackers)
716 else:
741 else:
717 tracker = None
742 tracker = None
718
743
719 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
744 r = AsyncResult(self.client, msg_ids, fname='scatter', targets=targets, tracker=tracker)
720 if block:
745 if block:
721 r.wait()
746 r.wait()
722 else:
747 else:
723 return r
748 return r
724
749
725 @sync_results
750 @sync_results
726 @save_ids
751 @save_ids
727 def gather(self, key, dist='b', targets=None, block=None):
752 def gather(self, key, dist='b', targets=None, block=None):
728 """
753 """
729 Gather a partitioned sequence on a set of engines as a single local seq.
754 Gather a partitioned sequence on a set of engines as a single local seq.
730 """
755 """
731 block = block if block is not None else self.block
756 block = block if block is not None else self.block
732 targets = targets if targets is not None else self.targets
757 targets = targets if targets is not None else self.targets
733 mapObject = Map.dists[dist]()
758 mapObject = Map.dists[dist]()
734 msg_ids = []
759 msg_ids = []
735
760
736 # construct integer ID list:
761 # construct integer ID list:
737 targets = self.client._build_targets(targets)[1]
762 targets = self.client._build_targets(targets)[1]
738
763
739 for index, engineid in enumerate(targets):
764 for index, engineid in enumerate(targets):
740 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
765 msg_ids.extend(self.pull(key, block=False, targets=engineid).msg_ids)
741
766
742 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
767 r = AsyncMapResult(self.client, msg_ids, mapObject, fname='gather')
743
768
744 if block:
769 if block:
745 try:
770 try:
746 return r.get()
771 return r.get()
747 except KeyboardInterrupt:
772 except KeyboardInterrupt:
748 pass
773 pass
749 return r
774 return r
750
775
751 def __getitem__(self, key):
776 def __getitem__(self, key):
752 return self.get(key)
777 return self.get(key)
753
778
754 def __setitem__(self,key, value):
779 def __setitem__(self,key, value):
755 self.update({key:value})
780 self.update({key:value})
756
781
757 def clear(self, targets=None, block=False):
782 def clear(self, targets=None, block=False):
758 """Clear the remote namespaces on my engines."""
783 """Clear the remote namespaces on my engines."""
759 block = block if block is not None else self.block
784 block = block if block is not None else self.block
760 targets = targets if targets is not None else self.targets
785 targets = targets if targets is not None else self.targets
761 return self.client.clear(targets=targets, block=block)
786 return self.client.clear(targets=targets, block=block)
762
787
763 def kill(self, targets=None, block=True):
788 def kill(self, targets=None, block=True):
764 """Kill my engines."""
789 """Kill my engines."""
765 block = block if block is not None else self.block
790 block = block if block is not None else self.block
766 targets = targets if targets is not None else self.targets
791 targets = targets if targets is not None else self.targets
767 return self.client.kill(targets=targets, block=block)
792 return self.client.kill(targets=targets, block=block)
768
793
769 #----------------------------------------
794 #----------------------------------------
770 # activate for %px,%autopx magics
795 # activate for %px,%autopx magics
771 #----------------------------------------
796 #----------------------------------------
772 def activate(self):
797 def activate(self):
773 """Make this `View` active for parallel magic commands.
798 """Make this `View` active for parallel magic commands.
774
799
775 IPython has a magic command syntax to work with `MultiEngineClient` objects.
800 IPython has a magic command syntax to work with `MultiEngineClient` objects.
776 In a given IPython session there is a single active one. While
801 In a given IPython session there is a single active one. While
777 there can be many `Views` created and used by the user,
802 there can be many `Views` created and used by the user,
778 there is only one active one. The active `View` is used whenever
803 there is only one active one. The active `View` is used whenever
779 the magic commands %px and %autopx are used.
804 the magic commands %px and %autopx are used.
780
805
781 The activate() method is called on a given `View` to make it
806 The activate() method is called on a given `View` to make it
782 active. Once this has been done, the magic commands can be used.
807 active. Once this has been done, the magic commands can be used.
783 """
808 """
784
809
785 try:
810 try:
786 # This is injected into __builtins__.
811 # This is injected into __builtins__.
787 ip = get_ipython()
812 ip = get_ipython()
788 except NameError:
813 except NameError:
789 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
814 print "The IPython parallel magics (%result, %px, %autopx) only work within IPython."
790 else:
815 else:
791 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
816 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
792 if pmagic is None:
817 if pmagic is None:
793 ip.magic_load_ext('parallelmagic')
818 ip.magic_load_ext('parallelmagic')
794 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
819 pmagic = ip.plugin_manager.get_plugin('parallelmagic')
795
820
796 pmagic.active_view = self
821 pmagic.active_view = self
797
822
798
823
@skip_doctest
class LoadBalancedView(View):
    """A load-balancing View that only executes via the Task scheduler.

    Load-balanced views can be created with the client's `view` method:

    >>> v = client.load_balanced_view()

    or targets can be specified, to restrict the potential destinations:

    >>> v = client.client.load_balanced_view([1,3])

    which would restrict loadbalancing to between engines 1 and 3.

    """

    # per-task scheduling flags, applied to every submission via this view;
    # see set_flags for their meanings
    follow=Any()      # location-based dependency (run where these msg_ids ran)
    after=Any()       # time-based dependency (run after these msg_ids finish)
    timeout=CFloat()  # seconds the scheduler waits for dependencies
    retries = Integer(0)  # times a failed task is resubmitted

    # scheduler scheme in use, copied from the client at construction
    _task_scheme = Any()
    _flag_names = List(['targets', 'block', 'track', 'follow', 'after', 'timeout', 'retries'])
822
847
823 def __init__(self, client=None, socket=None, **flags):
848 def __init__(self, client=None, socket=None, **flags):
824 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
849 super(LoadBalancedView, self).__init__(client=client, socket=socket, **flags)
825 self._task_scheme=client._task_scheme
850 self._task_scheme=client._task_scheme
826
851
827 def _validate_dependency(self, dep):
852 def _validate_dependency(self, dep):
828 """validate a dependency.
853 """validate a dependency.
829
854
830 For use in `set_flags`.
855 For use in `set_flags`.
831 """
856 """
832 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
857 if dep is None or isinstance(dep, (basestring, AsyncResult, Dependency)):
833 return True
858 return True
834 elif isinstance(dep, (list,set, tuple)):
859 elif isinstance(dep, (list,set, tuple)):
835 for d in dep:
860 for d in dep:
836 if not isinstance(d, (basestring, AsyncResult)):
861 if not isinstance(d, (basestring, AsyncResult)):
837 return False
862 return False
838 elif isinstance(dep, dict):
863 elif isinstance(dep, dict):
839 if set(dep.keys()) != set(Dependency().as_dict().keys()):
864 if set(dep.keys()) != set(Dependency().as_dict().keys()):
840 return False
865 return False
841 if not isinstance(dep['msg_ids'], list):
866 if not isinstance(dep['msg_ids'], list):
842 return False
867 return False
843 for d in dep['msg_ids']:
868 for d in dep['msg_ids']:
844 if not isinstance(d, basestring):
869 if not isinstance(d, basestring):
845 return False
870 return False
846 else:
871 else:
847 return False
872 return False
848
873
849 return True
874 return True
850
875
851 def _render_dependency(self, dep):
876 def _render_dependency(self, dep):
852 """helper for building jsonable dependencies from various input forms."""
877 """helper for building jsonable dependencies from various input forms."""
853 if isinstance(dep, Dependency):
878 if isinstance(dep, Dependency):
854 return dep.as_dict()
879 return dep.as_dict()
855 elif isinstance(dep, AsyncResult):
880 elif isinstance(dep, AsyncResult):
856 return dep.msg_ids
881 return dep.msg_ids
857 elif dep is None:
882 elif dep is None:
858 return []
883 return []
859 else:
884 else:
860 # pass to Dependency constructor
885 # pass to Dependency constructor
861 return list(Dependency(dep))
886 return list(Dependency(dep))
862
887
863 def set_flags(self, **kwargs):
888 def set_flags(self, **kwargs):
864 """set my attribute flags by keyword.
889 """set my attribute flags by keyword.
865
890
866 A View is a wrapper for the Client's apply method, but with attributes
891 A View is a wrapper for the Client's apply method, but with attributes
867 that specify keyword arguments, those attributes can be set by keyword
892 that specify keyword arguments, those attributes can be set by keyword
868 argument with this method.
893 argument with this method.
869
894
870 Parameters
895 Parameters
871 ----------
896 ----------
872
897
873 block : bool
898 block : bool
874 whether to wait for results
899 whether to wait for results
875 track : bool
900 track : bool
876 whether to create a MessageTracker to allow the user to
901 whether to create a MessageTracker to allow the user to
877 safely edit after arrays and buffers during non-copying
902 safely edit after arrays and buffers during non-copying
878 sends.
903 sends.
879
904
880 after : Dependency or collection of msg_ids
905 after : Dependency or collection of msg_ids
881 Only for load-balanced execution (targets=None)
906 Only for load-balanced execution (targets=None)
882 Specify a list of msg_ids as a time-based dependency.
907 Specify a list of msg_ids as a time-based dependency.
883 This job will only be run *after* the dependencies
908 This job will only be run *after* the dependencies
884 have been met.
909 have been met.
885
910
886 follow : Dependency or collection of msg_ids
911 follow : Dependency or collection of msg_ids
887 Only for load-balanced execution (targets=None)
912 Only for load-balanced execution (targets=None)
888 Specify a list of msg_ids as a location-based dependency.
913 Specify a list of msg_ids as a location-based dependency.
889 This job will only be run on an engine where this dependency
914 This job will only be run on an engine where this dependency
890 is met.
915 is met.
891
916
892 timeout : float/int or None
917 timeout : float/int or None
893 Only for load-balanced execution (targets=None)
918 Only for load-balanced execution (targets=None)
894 Specify an amount of time (in seconds) for the scheduler to
919 Specify an amount of time (in seconds) for the scheduler to
895 wait for dependencies to be met before failing with a
920 wait for dependencies to be met before failing with a
896 DependencyTimeout.
921 DependencyTimeout.
897
922
898 retries : int
923 retries : int
899 Number of times a task will be retried on failure.
924 Number of times a task will be retried on failure.
900 """
925 """
901
926
902 super(LoadBalancedView, self).set_flags(**kwargs)
927 super(LoadBalancedView, self).set_flags(**kwargs)
903 for name in ('follow', 'after'):
928 for name in ('follow', 'after'):
904 if name in kwargs:
929 if name in kwargs:
905 value = kwargs[name]
930 value = kwargs[name]
906 if self._validate_dependency(value):
931 if self._validate_dependency(value):
907 setattr(self, name, value)
932 setattr(self, name, value)
908 else:
933 else:
909 raise ValueError("Invalid dependency: %r"%value)
934 raise ValueError("Invalid dependency: %r"%value)
910 if 'timeout' in kwargs:
935 if 'timeout' in kwargs:
911 t = kwargs['timeout']
936 t = kwargs['timeout']
912 if not isinstance(t, (int, long, float, type(None))):
937 if not isinstance(t, (int, long, float, type(None))):
913 raise TypeError("Invalid type for timeout: %r"%type(t))
938 raise TypeError("Invalid type for timeout: %r"%type(t))
914 if t is not None:
939 if t is not None:
915 if t < 0:
940 if t < 0:
916 raise ValueError("Invalid timeout: %s"%t)
941 raise ValueError("Invalid timeout: %s"%t)
917 self.timeout = t
942 self.timeout = t
918
943
919 @sync_results
944 @sync_results
920 @save_ids
945 @save_ids
921 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
946 def _really_apply(self, f, args=None, kwargs=None, block=None, track=None,
922 after=None, follow=None, timeout=None,
947 after=None, follow=None, timeout=None,
923 targets=None, retries=None):
948 targets=None, retries=None):
924 """calls f(*args, **kwargs) on a remote engine, returning the result.
949 """calls f(*args, **kwargs) on a remote engine, returning the result.
925
950
926 This method temporarily sets all of `apply`'s flags for a single call.
951 This method temporarily sets all of `apply`'s flags for a single call.
927
952
928 Parameters
953 Parameters
929 ----------
954 ----------
930
955
931 f : callable
956 f : callable
932
957
933 args : list [default: empty]
958 args : list [default: empty]
934
959
935 kwargs : dict [default: empty]
960 kwargs : dict [default: empty]
936
961
937 block : bool [default: self.block]
962 block : bool [default: self.block]
938 whether to block
963 whether to block
939 track : bool [default: self.track]
964 track : bool [default: self.track]
940 whether to ask zmq to track the message, for safe non-copying sends
965 whether to ask zmq to track the message, for safe non-copying sends
941
966
942 !!!!!! TODO: THE REST HERE !!!!
967 !!!!!! TODO: THE REST HERE !!!!
943
968
944 Returns
969 Returns
945 -------
970 -------
946
971
947 if self.block is False:
972 if self.block is False:
948 returns AsyncResult
973 returns AsyncResult
949 else:
974 else:
950 returns actual result of f(*args, **kwargs) on the engine(s)
975 returns actual result of f(*args, **kwargs) on the engine(s)
951 This will be a list of self.targets is also a list (even length 1), or
976 This will be a list of self.targets is also a list (even length 1), or
952 the single result if self.targets is an integer engine id
977 the single result if self.targets is an integer engine id
953 """
978 """
954
979
955 # validate whether we can run
980 # validate whether we can run
956 if self._socket.closed:
981 if self._socket.closed:
957 msg = "Task farming is disabled"
982 msg = "Task farming is disabled"
958 if self._task_scheme == 'pure':
983 if self._task_scheme == 'pure':
959 msg += " because the pure ZMQ scheduler cannot handle"
984 msg += " because the pure ZMQ scheduler cannot handle"
960 msg += " disappearing engines."
985 msg += " disappearing engines."
961 raise RuntimeError(msg)
986 raise RuntimeError(msg)
962
987
963 if self._task_scheme == 'pure':
988 if self._task_scheme == 'pure':
964 # pure zmq scheme doesn't support extra features
989 # pure zmq scheme doesn't support extra features
965 msg = "Pure ZMQ scheduler doesn't support the following flags:"
990 msg = "Pure ZMQ scheduler doesn't support the following flags:"
966 "follow, after, retries, targets, timeout"
991 "follow, after, retries, targets, timeout"
967 if (follow or after or retries or targets or timeout):
992 if (follow or after or retries or targets or timeout):
968 # hard fail on Scheduler flags
993 # hard fail on Scheduler flags
969 raise RuntimeError(msg)
994 raise RuntimeError(msg)
970 if isinstance(f, dependent):
995 if isinstance(f, dependent):
971 # soft warn on functional dependencies
996 # soft warn on functional dependencies
972 warnings.warn(msg, RuntimeWarning)
997 warnings.warn(msg, RuntimeWarning)
973
998
974 # build args
999 # build args
975 args = [] if args is None else args
1000 args = [] if args is None else args
976 kwargs = {} if kwargs is None else kwargs
1001 kwargs = {} if kwargs is None else kwargs
977 block = self.block if block is None else block
1002 block = self.block if block is None else block
978 track = self.track if track is None else track
1003 track = self.track if track is None else track
979 after = self.after if after is None else after
1004 after = self.after if after is None else after
980 retries = self.retries if retries is None else retries
1005 retries = self.retries if retries is None else retries
981 follow = self.follow if follow is None else follow
1006 follow = self.follow if follow is None else follow
982 timeout = self.timeout if timeout is None else timeout
1007 timeout = self.timeout if timeout is None else timeout
983 targets = self.targets if targets is None else targets
1008 targets = self.targets if targets is None else targets
984
1009
985 if not isinstance(retries, int):
1010 if not isinstance(retries, int):
986 raise TypeError('retries must be int, not %r'%type(retries))
1011 raise TypeError('retries must be int, not %r'%type(retries))
987
1012
988 if targets is None:
1013 if targets is None:
989 idents = []
1014 idents = []
990 else:
1015 else:
991 idents = self.client._build_targets(targets)[0]
1016 idents = self.client._build_targets(targets)[0]
992 # ensure *not* bytes
1017 # ensure *not* bytes
993 idents = [ ident.decode() for ident in idents ]
1018 idents = [ ident.decode() for ident in idents ]
994
1019
995 after = self._render_dependency(after)
1020 after = self._render_dependency(after)
996 follow = self._render_dependency(follow)
1021 follow = self._render_dependency(follow)
997 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
1022 subheader = dict(after=after, follow=follow, timeout=timeout, targets=idents, retries=retries)
998
1023
999 msg = self.client.send_apply_message(self._socket, f, args, kwargs, track=track,
1024 msg = self.client.send_apply_request(self._socket, f, args, kwargs, track=track,
1000 subheader=subheader)
1025 subheader=subheader)
1001 tracker = None if track is False else msg['tracker']
1026 tracker = None if track is False else msg['tracker']
1002
1027
1003 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1028 ar = AsyncResult(self.client, msg['header']['msg_id'], fname=getname(f), targets=None, tracker=tracker)
1004
1029
1005 if block:
1030 if block:
1006 try:
1031 try:
1007 return ar.get()
1032 return ar.get()
1008 except KeyboardInterrupt:
1033 except KeyboardInterrupt:
1009 pass
1034 pass
1010 return ar
1035 return ar
1011
1036
1012 @spin_after
1037 @spin_after
1013 @save_ids
1038 @save_ids
1014 def map(self, f, *sequences, **kwargs):
1039 def map(self, f, *sequences, **kwargs):
1015 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1040 """view.map(f, *sequences, block=self.block, chunksize=1, ordered=True) => list|AsyncMapResult
1016
1041
1017 Parallel version of builtin `map`, load-balanced by this View.
1042 Parallel version of builtin `map`, load-balanced by this View.
1018
1043
1019 `block`, and `chunksize` can be specified by keyword only.
1044 `block`, and `chunksize` can be specified by keyword only.
1020
1045
1021 Each `chunksize` elements will be a separate task, and will be
1046 Each `chunksize` elements will be a separate task, and will be
1022 load-balanced. This lets individual elements be available for iteration
1047 load-balanced. This lets individual elements be available for iteration
1023 as soon as they arrive.
1048 as soon as they arrive.
1024
1049
1025 Parameters
1050 Parameters
1026 ----------
1051 ----------
1027
1052
1028 f : callable
1053 f : callable
1029 function to be mapped
1054 function to be mapped
1030 *sequences: one or more sequences of matching length
1055 *sequences: one or more sequences of matching length
1031 the sequences to be distributed and passed to `f`
1056 the sequences to be distributed and passed to `f`
1032 block : bool [default self.block]
1057 block : bool [default self.block]
1033 whether to wait for the result or not
1058 whether to wait for the result or not
1034 track : bool
1059 track : bool
1035 whether to create a MessageTracker to allow the user to
1060 whether to create a MessageTracker to allow the user to
1036 safely edit after arrays and buffers during non-copying
1061 safely edit after arrays and buffers during non-copying
1037 sends.
1062 sends.
1038 chunksize : int [default 1]
1063 chunksize : int [default 1]
1039 how many elements should be in each task.
1064 how many elements should be in each task.
1040 ordered : bool [default True]
1065 ordered : bool [default True]
1041 Whether the results should be gathered as they arrive, or enforce
1066 Whether the results should be gathered as they arrive, or enforce
1042 the order of submission.
1067 the order of submission.
1043
1068
1044 Only applies when iterating through AsyncMapResult as results arrive.
1069 Only applies when iterating through AsyncMapResult as results arrive.
1045 Has no effect when block=True.
1070 Has no effect when block=True.
1046
1071
1047 Returns
1072 Returns
1048 -------
1073 -------
1049
1074
1050 if block=False:
1075 if block=False:
1051 AsyncMapResult
1076 AsyncMapResult
1052 An object like AsyncResult, but which reassembles the sequence of results
1077 An object like AsyncResult, but which reassembles the sequence of results
1053 into a single list. AsyncMapResults can be iterated through before all
1078 into a single list. AsyncMapResults can be iterated through before all
1054 results are complete.
1079 results are complete.
1055 else:
1080 else:
1056 the result of map(f,*sequences)
1081 the result of map(f,*sequences)
1057
1082
1058 """
1083 """
1059
1084
1060 # default
1085 # default
1061 block = kwargs.get('block', self.block)
1086 block = kwargs.get('block', self.block)
1062 chunksize = kwargs.get('chunksize', 1)
1087 chunksize = kwargs.get('chunksize', 1)
1063 ordered = kwargs.get('ordered', True)
1088 ordered = kwargs.get('ordered', True)
1064
1089
1065 keyset = set(kwargs.keys())
1090 keyset = set(kwargs.keys())
1066 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1091 extra_keys = keyset.difference_update(set(['block', 'chunksize']))
1067 if extra_keys:
1092 if extra_keys:
1068 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1093 raise TypeError("Invalid kwargs: %s"%list(extra_keys))
1069
1094
1070 assert len(sequences) > 0, "must have some sequences to map onto!"
1095 assert len(sequences) > 0, "must have some sequences to map onto!"
1071
1096
1072 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1097 pf = ParallelFunction(self, f, block=block, chunksize=chunksize, ordered=ordered)
1073 return pf.map(*sequences)
1098 return pf.map(*sequences)
1074
1099
# names exported by ``from ... import *``
__all__ = ['LoadBalancedView', 'DirectView']
@@ -1,181 +1,182 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """
2 """
3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
3 A multi-heart Heartbeat system using PUB and XREP sockets. pings are sent out on the PUB,
4 and hearts are tracked based on their XREQ identities.
4 and hearts are tracked based on their XREQ identities.
5
5
6 Authors:
6 Authors:
7
7
8 * Min RK
8 * Min RK
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010-2011 The IPython Development Team
11 # Copyright (C) 2010-2011 The IPython Development Team
12 #
12 #
13 # Distributed under the terms of the BSD License. The full license is in
13 # Distributed under the terms of the BSD License. The full license is in
14 # the file COPYING, distributed as part of this software.
14 # the file COPYING, distributed as part of this software.
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 from __future__ import print_function
17 from __future__ import print_function
18 import time
18 import time
19 import uuid
19 import uuid
20
20
21 import zmq
21 import zmq
22 from zmq.devices import ThreadDevice
22 from zmq.devices import ThreadDevice
23 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
24
24
25 from IPython.config.configurable import LoggingConfigurable
25 from IPython.config.configurable import LoggingConfigurable
26 from IPython.utils.py3compat import str_to_bytes
26 from IPython.utils.traitlets import Set, Instance, CFloat, Integer
27 from IPython.utils.traitlets import Set, Instance, CFloat, Integer
27
28
28 from IPython.parallel.util import asbytes, log_errors
29 from IPython.parallel.util import log_errors
29
30
30 class Heart(object):
31 class Heart(object):
31 """A basic heart object for responding to a HeartMonitor.
32 """A basic heart object for responding to a HeartMonitor.
32 This is a simple wrapper with defaults for the most common
33 This is a simple wrapper with defaults for the most common
33 Device model for responding to heartbeats.
34 Device model for responding to heartbeats.
34
35
35 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
36 It simply builds a threadsafe zmq.FORWARDER Device, defaulting to using
36 SUB/XREQ for in/out.
37 SUB/XREQ for in/out.
37
38
38 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
39 You can specify the XREQ's IDENTITY via the optional heart_id argument."""
39 device=None
40 device=None
40 id=None
41 id=None
41 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.DEALER, heart_id=None):
42 def __init__(self, in_addr, out_addr, in_type=zmq.SUB, out_type=zmq.DEALER, heart_id=None):
42 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
43 self.device = ThreadDevice(zmq.FORWARDER, in_type, out_type)
43 # do not allow the device to share global Context.instance,
44 # do not allow the device to share global Context.instance,
44 # which is the default behavior in pyzmq > 2.1.10
45 # which is the default behavior in pyzmq > 2.1.10
45 self.device.context_factory = zmq.Context
46 self.device.context_factory = zmq.Context
46
47
47 self.device.daemon=True
48 self.device.daemon=True
48 self.device.connect_in(in_addr)
49 self.device.connect_in(in_addr)
49 self.device.connect_out(out_addr)
50 self.device.connect_out(out_addr)
50 if in_type == zmq.SUB:
51 if in_type == zmq.SUB:
51 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
52 self.device.setsockopt_in(zmq.SUBSCRIBE, b"")
52 if heart_id is None:
53 if heart_id is None:
53 heart_id = uuid.uuid4().bytes
54 heart_id = uuid.uuid4().bytes
54 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
55 self.device.setsockopt_out(zmq.IDENTITY, heart_id)
55 self.id = heart_id
56 self.id = heart_id
56
57
57 def start(self):
58 def start(self):
58 return self.device.start()
59 return self.device.start()
59
60
60
61
61 class HeartMonitor(LoggingConfigurable):
62 class HeartMonitor(LoggingConfigurable):
62 """A basic HeartMonitor class
63 """A basic HeartMonitor class
63 pingstream: a PUB stream
64 pingstream: a PUB stream
64 pongstream: an XREP stream
65 pongstream: an XREP stream
65 period: the period of the heartbeat in milliseconds"""
66 period: the period of the heartbeat in milliseconds"""
66
67
67 period = Integer(3000, config=True,
68 period = Integer(3000, config=True,
68 help='The frequency at which the Hub pings the engines for heartbeats '
69 help='The frequency at which the Hub pings the engines for heartbeats '
69 '(in ms)',
70 '(in ms)',
70 )
71 )
71
72
72 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
73 pingstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
73 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
74 pongstream=Instance('zmq.eventloop.zmqstream.ZMQStream')
74 loop = Instance('zmq.eventloop.ioloop.IOLoop')
75 loop = Instance('zmq.eventloop.ioloop.IOLoop')
75 def _loop_default(self):
76 def _loop_default(self):
76 return ioloop.IOLoop.instance()
77 return ioloop.IOLoop.instance()
77
78
78 # not settable:
79 # not settable:
79 hearts=Set()
80 hearts=Set()
80 responses=Set()
81 responses=Set()
81 on_probation=Set()
82 on_probation=Set()
82 last_ping=CFloat(0)
83 last_ping=CFloat(0)
83 _new_handlers = Set()
84 _new_handlers = Set()
84 _failure_handlers = Set()
85 _failure_handlers = Set()
85 lifetime = CFloat(0)
86 lifetime = CFloat(0)
86 tic = CFloat(0)
87 tic = CFloat(0)
87
88
88 def __init__(self, **kwargs):
89 def __init__(self, **kwargs):
89 super(HeartMonitor, self).__init__(**kwargs)
90 super(HeartMonitor, self).__init__(**kwargs)
90
91
91 self.pongstream.on_recv(self.handle_pong)
92 self.pongstream.on_recv(self.handle_pong)
92
93
93 def start(self):
94 def start(self):
94 self.tic = time.time()
95 self.tic = time.time()
95 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
96 self.caller = ioloop.PeriodicCallback(self.beat, self.period, self.loop)
96 self.caller.start()
97 self.caller.start()
97
98
98 def add_new_heart_handler(self, handler):
99 def add_new_heart_handler(self, handler):
99 """add a new handler for new hearts"""
100 """add a new handler for new hearts"""
100 self.log.debug("heartbeat::new_heart_handler: %s", handler)
101 self.log.debug("heartbeat::new_heart_handler: %s", handler)
101 self._new_handlers.add(handler)
102 self._new_handlers.add(handler)
102
103
103 def add_heart_failure_handler(self, handler):
104 def add_heart_failure_handler(self, handler):
104 """add a new handler for heart failure"""
105 """add a new handler for heart failure"""
105 self.log.debug("heartbeat::new heart failure handler: %s", handler)
106 self.log.debug("heartbeat::new heart failure handler: %s", handler)
106 self._failure_handlers.add(handler)
107 self._failure_handlers.add(handler)
107
108
108 def beat(self):
109 def beat(self):
109 self.pongstream.flush()
110 self.pongstream.flush()
110 self.last_ping = self.lifetime
111 self.last_ping = self.lifetime
111
112
112 toc = time.time()
113 toc = time.time()
113 self.lifetime += toc-self.tic
114 self.lifetime += toc-self.tic
114 self.tic = toc
115 self.tic = toc
115 self.log.debug("heartbeat::sending %s", self.lifetime)
116 self.log.debug("heartbeat::sending %s", self.lifetime)
116 goodhearts = self.hearts.intersection(self.responses)
117 goodhearts = self.hearts.intersection(self.responses)
117 missed_beats = self.hearts.difference(goodhearts)
118 missed_beats = self.hearts.difference(goodhearts)
118 heartfailures = self.on_probation.intersection(missed_beats)
119 heartfailures = self.on_probation.intersection(missed_beats)
119 newhearts = self.responses.difference(goodhearts)
120 newhearts = self.responses.difference(goodhearts)
120 map(self.handle_new_heart, newhearts)
121 map(self.handle_new_heart, newhearts)
121 map(self.handle_heart_failure, heartfailures)
122 map(self.handle_heart_failure, heartfailures)
122 self.on_probation = missed_beats.intersection(self.hearts)
123 self.on_probation = missed_beats.intersection(self.hearts)
123 self.responses = set()
124 self.responses = set()
124 # print self.on_probation, self.hearts
125 # print self.on_probation, self.hearts
125 # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
126 # self.log.debug("heartbeat::beat %.3f, %i beating hearts", self.lifetime, len(self.hearts))
126 self.pingstream.send(asbytes(str(self.lifetime)))
127 self.pingstream.send(str_to_bytes(str(self.lifetime)))
127 # flush stream to force immediate socket send
128 # flush stream to force immediate socket send
128 self.pingstream.flush()
129 self.pingstream.flush()
129
130
130 def handle_new_heart(self, heart):
131 def handle_new_heart(self, heart):
131 if self._new_handlers:
132 if self._new_handlers:
132 for handler in self._new_handlers:
133 for handler in self._new_handlers:
133 handler(heart)
134 handler(heart)
134 else:
135 else:
135 self.log.info("heartbeat::yay, got new heart %s!", heart)
136 self.log.info("heartbeat::yay, got new heart %s!", heart)
136 self.hearts.add(heart)
137 self.hearts.add(heart)
137
138
138 def handle_heart_failure(self, heart):
139 def handle_heart_failure(self, heart):
139 if self._failure_handlers:
140 if self._failure_handlers:
140 for handler in self._failure_handlers:
141 for handler in self._failure_handlers:
141 try:
142 try:
142 handler(heart)
143 handler(heart)
143 except Exception as e:
144 except Exception as e:
144 self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
145 self.log.error("heartbeat::Bad Handler! %s", handler, exc_info=True)
145 pass
146 pass
146 else:
147 else:
147 self.log.info("heartbeat::Heart %s failed :(", heart)
148 self.log.info("heartbeat::Heart %s failed :(", heart)
148 self.hearts.remove(heart)
149 self.hearts.remove(heart)
149
150
150
151
151 @log_errors
152 @log_errors
152 def handle_pong(self, msg):
153 def handle_pong(self, msg):
153 "a heart just beat"
154 "a heart just beat"
154 current = asbytes(str(self.lifetime))
155 current = str_to_bytes(str(self.lifetime))
155 last = asbytes(str(self.last_ping))
156 last = str_to_bytes(str(self.last_ping))
156 if msg[1] == current:
157 if msg[1] == current:
157 delta = time.time()-self.tic
158 delta = time.time()-self.tic
158 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
159 # self.log.debug("heartbeat::heart %r took %.2f ms to respond"%(msg[0], 1000*delta))
159 self.responses.add(msg[0])
160 self.responses.add(msg[0])
160 elif msg[1] == last:
161 elif msg[1] == last:
161 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
162 delta = time.time()-self.tic + (self.lifetime-self.last_ping)
162 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
163 self.log.warn("heartbeat::heart %r missed a beat, and took %.2f ms to respond", msg[0], 1000*delta)
163 self.responses.add(msg[0])
164 self.responses.add(msg[0])
164 else:
165 else:
165 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
166 self.log.warn("heartbeat::got bad heartbeat (possibly old?): %s (current=%.3f)", msg[1], self.lifetime)
166
167
167
168
168 if __name__ == '__main__':
169 if __name__ == '__main__':
169 loop = ioloop.IOLoop.instance()
170 loop = ioloop.IOLoop.instance()
170 context = zmq.Context()
171 context = zmq.Context()
171 pub = context.socket(zmq.PUB)
172 pub = context.socket(zmq.PUB)
172 pub.bind('tcp://127.0.0.1:5555')
173 pub.bind('tcp://127.0.0.1:5555')
173 xrep = context.socket(zmq.ROUTER)
174 xrep = context.socket(zmq.ROUTER)
174 xrep.bind('tcp://127.0.0.1:5556')
175 xrep.bind('tcp://127.0.0.1:5556')
175
176
176 outstream = zmqstream.ZMQStream(pub, loop)
177 outstream = zmqstream.ZMQStream(pub, loop)
177 instream = zmqstream.ZMQStream(xrep, loop)
178 instream = zmqstream.ZMQStream(xrep, loop)
178
179
179 hb = HeartMonitor(loop, outstream, instream)
180 hb = HeartMonitor(loop, outstream, instream)
180
181
181 loop.start()
182 loop.start()
@@ -1,1300 +1,1310 b''
1 """The IPython Controller Hub with 0MQ
1 """The IPython Controller Hub with 0MQ
2 This is the master object that handles connections from engines and clients,
2 This is the master object that handles connections from engines and clients,
3 and monitors traffic through the various queues.
3 and monitors traffic through the various queues.
4
4
5 Authors:
5 Authors:
6
6
7 * Min RK
7 * Min RK
8 """
8 """
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19 from __future__ import print_function
19 from __future__ import print_function
20
20
21 import sys
21 import sys
22 import time
22 import time
23 from datetime import datetime
23 from datetime import datetime
24
24
25 import zmq
25 import zmq
26 from zmq.eventloop import ioloop
26 from zmq.eventloop import ioloop
27 from zmq.eventloop.zmqstream import ZMQStream
27 from zmq.eventloop.zmqstream import ZMQStream
28
28
29 # internal:
29 # internal:
30 from IPython.utils.importstring import import_item
30 from IPython.utils.importstring import import_item
31 from IPython.utils.py3compat import cast_bytes
31 from IPython.utils.traitlets import (
32 from IPython.utils.traitlets import (
32 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 HasTraits, Instance, Integer, Unicode, Dict, Set, Tuple, CBytes, DottedObjectName
33 )
34 )
34
35
35 from IPython.parallel import error, util
36 from IPython.parallel import error, util
36 from IPython.parallel.factory import RegistrationFactory
37 from IPython.parallel.factory import RegistrationFactory
37
38
38 from IPython.zmq.session import SessionFactory
39 from IPython.zmq.session import SessionFactory
39
40
40 from .heartmonitor import HeartMonitor
41 from .heartmonitor import HeartMonitor
41
42
42 #-----------------------------------------------------------------------------
43 #-----------------------------------------------------------------------------
43 # Code
44 # Code
44 #-----------------------------------------------------------------------------
45 #-----------------------------------------------------------------------------
45
46
46 def _passer(*args, **kwargs):
47 def _passer(*args, **kwargs):
47 return
48 return
48
49
49 def _printer(*args, **kwargs):
50 def _printer(*args, **kwargs):
50 print (args)
51 print (args)
51 print (kwargs)
52 print (kwargs)
52
53
53 def empty_record():
54 def empty_record():
54 """Return an empty dict with all record keys."""
55 """Return an empty dict with all record keys."""
55 return {
56 return {
56 'msg_id' : None,
57 'msg_id' : None,
57 'header' : None,
58 'header' : None,
58 'content': None,
59 'content': None,
59 'buffers': None,
60 'buffers': None,
60 'submitted': None,
61 'submitted': None,
61 'client_uuid' : None,
62 'client_uuid' : None,
62 'engine_uuid' : None,
63 'engine_uuid' : None,
63 'started': None,
64 'started': None,
64 'completed': None,
65 'completed': None,
65 'resubmitted': None,
66 'resubmitted': None,
66 'received': None,
67 'received': None,
67 'result_header' : None,
68 'result_header' : None,
68 'result_content' : None,
69 'result_content' : None,
69 'result_buffers' : None,
70 'result_buffers' : None,
70 'queue' : None,
71 'queue' : None,
71 'pyin' : None,
72 'pyin' : None,
72 'pyout': None,
73 'pyout': None,
73 'pyerr': None,
74 'pyerr': None,
74 'stdout': '',
75 'stdout': '',
75 'stderr': '',
76 'stderr': '',
76 }
77 }
77
78
78 def init_record(msg):
79 def init_record(msg):
79 """Initialize a TaskRecord based on a request."""
80 """Initialize a TaskRecord based on a request."""
80 header = msg['header']
81 header = msg['header']
81 return {
82 return {
82 'msg_id' : header['msg_id'],
83 'msg_id' : header['msg_id'],
83 'header' : header,
84 'header' : header,
84 'content': msg['content'],
85 'content': msg['content'],
85 'buffers': msg['buffers'],
86 'buffers': msg['buffers'],
86 'submitted': header['date'],
87 'submitted': header['date'],
87 'client_uuid' : None,
88 'client_uuid' : None,
88 'engine_uuid' : None,
89 'engine_uuid' : None,
89 'started': None,
90 'started': None,
90 'completed': None,
91 'completed': None,
91 'resubmitted': None,
92 'resubmitted': None,
92 'received': None,
93 'received': None,
93 'result_header' : None,
94 'result_header' : None,
94 'result_content' : None,
95 'result_content' : None,
95 'result_buffers' : None,
96 'result_buffers' : None,
96 'queue' : None,
97 'queue' : None,
97 'pyin' : None,
98 'pyin' : None,
98 'pyout': None,
99 'pyout': None,
99 'pyerr': None,
100 'pyerr': None,
100 'stdout': '',
101 'stdout': '',
101 'stderr': '',
102 'stderr': '',
102 }
103 }
103
104
104
105
105 class EngineConnector(HasTraits):
106 class EngineConnector(HasTraits):
106 """A simple object for accessing the various zmq connections of an object.
107 """A simple object for accessing the various zmq connections of an object.
107 Attributes are:
108 Attributes are:
108 id (int): engine ID
109 id (int): engine ID
109 uuid (str): uuid (unused?)
110 uuid (str): uuid (unused?)
110 queue (str): identity of queue's XREQ socket
111 queue (str): identity of queue's XREQ socket
111 registration (str): identity of registration XREQ socket
112 registration (str): identity of registration XREQ socket
112 heartbeat (str): identity of heartbeat XREQ socket
113 heartbeat (str): identity of heartbeat XREQ socket
113 """
114 """
114 id=Integer(0)
115 id=Integer(0)
115 queue=CBytes()
116 queue=CBytes()
116 control=CBytes()
117 control=CBytes()
117 registration=CBytes()
118 registration=CBytes()
118 heartbeat=CBytes()
119 heartbeat=CBytes()
119 pending=Set()
120 pending=Set()
120
121
121 class HubFactory(RegistrationFactory):
122 class HubFactory(RegistrationFactory):
122 """The Configurable for setting up a Hub."""
123 """The Configurable for setting up a Hub."""
123
124
124 # port-pairs for monitoredqueues:
125 # port-pairs for monitoredqueues:
125 hb = Tuple(Integer,Integer,config=True,
126 hb = Tuple(Integer,Integer,config=True,
126 help="""XREQ/SUB Port pair for Engine heartbeats""")
127 help="""XREQ/SUB Port pair for Engine heartbeats""")
127 def _hb_default(self):
128 def _hb_default(self):
128 return tuple(util.select_random_ports(2))
129 return tuple(util.select_random_ports(2))
129
130
130 mux = Tuple(Integer,Integer,config=True,
131 mux = Tuple(Integer,Integer,config=True,
131 help="""Engine/Client Port pair for MUX queue""")
132 help="""Engine/Client Port pair for MUX queue""")
132
133
133 def _mux_default(self):
134 def _mux_default(self):
134 return tuple(util.select_random_ports(2))
135 return tuple(util.select_random_ports(2))
135
136
136 task = Tuple(Integer,Integer,config=True,
137 task = Tuple(Integer,Integer,config=True,
137 help="""Engine/Client Port pair for Task queue""")
138 help="""Engine/Client Port pair for Task queue""")
138 def _task_default(self):
139 def _task_default(self):
139 return tuple(util.select_random_ports(2))
140 return tuple(util.select_random_ports(2))
140
141
141 control = Tuple(Integer,Integer,config=True,
142 control = Tuple(Integer,Integer,config=True,
142 help="""Engine/Client Port pair for Control queue""")
143 help="""Engine/Client Port pair for Control queue""")
143
144
144 def _control_default(self):
145 def _control_default(self):
145 return tuple(util.select_random_ports(2))
146 return tuple(util.select_random_ports(2))
146
147
147 iopub = Tuple(Integer,Integer,config=True,
148 iopub = Tuple(Integer,Integer,config=True,
148 help="""Engine/Client Port pair for IOPub relay""")
149 help="""Engine/Client Port pair for IOPub relay""")
149
150
150 def _iopub_default(self):
151 def _iopub_default(self):
151 return tuple(util.select_random_ports(2))
152 return tuple(util.select_random_ports(2))
152
153
153 # single ports:
154 # single ports:
154 mon_port = Integer(config=True,
155 mon_port = Integer(config=True,
155 help="""Monitor (SUB) port for queue traffic""")
156 help="""Monitor (SUB) port for queue traffic""")
156
157
157 def _mon_port_default(self):
158 def _mon_port_default(self):
158 return util.select_random_ports(1)[0]
159 return util.select_random_ports(1)[0]
159
160
160 notifier_port = Integer(config=True,
161 notifier_port = Integer(config=True,
161 help="""PUB port for sending engine status notifications""")
162 help="""PUB port for sending engine status notifications""")
162
163
163 def _notifier_port_default(self):
164 def _notifier_port_default(self):
164 return util.select_random_ports(1)[0]
165 return util.select_random_ports(1)[0]
165
166
166 engine_ip = Unicode('127.0.0.1', config=True,
167 engine_ip = Unicode('127.0.0.1', config=True,
167 help="IP on which to listen for engine connections. [default: loopback]")
168 help="IP on which to listen for engine connections. [default: loopback]")
168 engine_transport = Unicode('tcp', config=True,
169 engine_transport = Unicode('tcp', config=True,
169 help="0MQ transport for engine connections. [default: tcp]")
170 help="0MQ transport for engine connections. [default: tcp]")
170
171
171 client_ip = Unicode('127.0.0.1', config=True,
172 client_ip = Unicode('127.0.0.1', config=True,
172 help="IP on which to listen for client connections. [default: loopback]")
173 help="IP on which to listen for client connections. [default: loopback]")
173 client_transport = Unicode('tcp', config=True,
174 client_transport = Unicode('tcp', config=True,
174 help="0MQ transport for client connections. [default : tcp]")
175 help="0MQ transport for client connections. [default : tcp]")
175
176
176 monitor_ip = Unicode('127.0.0.1', config=True,
177 monitor_ip = Unicode('127.0.0.1', config=True,
177 help="IP on which to listen for monitor messages. [default: loopback]")
178 help="IP on which to listen for monitor messages. [default: loopback]")
178 monitor_transport = Unicode('tcp', config=True,
179 monitor_transport = Unicode('tcp', config=True,
179 help="0MQ transport for monitor messages. [default : tcp]")
180 help="0MQ transport for monitor messages. [default : tcp]")
180
181
181 monitor_url = Unicode('')
182 monitor_url = Unicode('')
182
183
183 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
184 db_class = DottedObjectName('IPython.parallel.controller.dictdb.DictDB',
184 config=True, help="""The class to use for the DB backend""")
185 config=True, help="""The class to use for the DB backend""")
185
186
186 # not configurable
187 # not configurable
187 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
188 db = Instance('IPython.parallel.controller.dictdb.BaseDB')
188 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
189 heartmonitor = Instance('IPython.parallel.controller.heartmonitor.HeartMonitor')
189
190
190 def _ip_changed(self, name, old, new):
191 def _ip_changed(self, name, old, new):
191 self.engine_ip = new
192 self.engine_ip = new
192 self.client_ip = new
193 self.client_ip = new
193 self.monitor_ip = new
194 self.monitor_ip = new
194 self._update_monitor_url()
195 self._update_monitor_url()
195
196
196 def _update_monitor_url(self):
197 def _update_monitor_url(self):
197 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
198 self.monitor_url = "%s://%s:%i" % (self.monitor_transport, self.monitor_ip, self.mon_port)
198
199
199 def _transport_changed(self, name, old, new):
200 def _transport_changed(self, name, old, new):
200 self.engine_transport = new
201 self.engine_transport = new
201 self.client_transport = new
202 self.client_transport = new
202 self.monitor_transport = new
203 self.monitor_transport = new
203 self._update_monitor_url()
204 self._update_monitor_url()
204
205
205 def __init__(self, **kwargs):
206 def __init__(self, **kwargs):
206 super(HubFactory, self).__init__(**kwargs)
207 super(HubFactory, self).__init__(**kwargs)
207 self._update_monitor_url()
208 self._update_monitor_url()
208
209
209
210
210 def construct(self):
211 def construct(self):
211 self.init_hub()
212 self.init_hub()
212
213
213 def start(self):
214 def start(self):
214 self.heartmonitor.start()
215 self.heartmonitor.start()
215 self.log.info("Heartmonitor started")
216 self.log.info("Heartmonitor started")
216
217
217 def init_hub(self):
218 def init_hub(self):
218 """construct"""
219 """construct"""
219 client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
220 client_iface = "%s://%s:" % (self.client_transport, self.client_ip) + "%i"
220 engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"
221 engine_iface = "%s://%s:" % (self.engine_transport, self.engine_ip) + "%i"
221
222
222 ctx = self.context
223 ctx = self.context
223 loop = self.loop
224 loop = self.loop
224
225
225 # Registrar socket
226 # Registrar socket
226 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
227 q = ZMQStream(ctx.socket(zmq.ROUTER), loop)
227 q.bind(client_iface % self.regport)
228 q.bind(client_iface % self.regport)
228 self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
229 self.log.info("Hub listening on %s for registration.", client_iface % self.regport)
229 if self.client_ip != self.engine_ip:
230 if self.client_ip != self.engine_ip:
230 q.bind(engine_iface % self.regport)
231 q.bind(engine_iface % self.regport)
231 self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)
232 self.log.info("Hub listening on %s for registration.", engine_iface % self.regport)
232
233
233 ### Engine connections ###
234 ### Engine connections ###
234
235
235 # heartbeat
236 # heartbeat
236 hpub = ctx.socket(zmq.PUB)
237 hpub = ctx.socket(zmq.PUB)
237 hpub.bind(engine_iface % self.hb[0])
238 hpub.bind(engine_iface % self.hb[0])
238 hrep = ctx.socket(zmq.ROUTER)
239 hrep = ctx.socket(zmq.ROUTER)
239 hrep.bind(engine_iface % self.hb[1])
240 hrep.bind(engine_iface % self.hb[1])
240 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
241 self.heartmonitor = HeartMonitor(loop=loop, config=self.config, log=self.log,
241 pingstream=ZMQStream(hpub,loop),
242 pingstream=ZMQStream(hpub,loop),
242 pongstream=ZMQStream(hrep,loop)
243 pongstream=ZMQStream(hrep,loop)
243 )
244 )
244
245
245 ### Client connections ###
246 ### Client connections ###
246 # Notifier socket
247 # Notifier socket
247 n = ZMQStream(ctx.socket(zmq.PUB), loop)
248 n = ZMQStream(ctx.socket(zmq.PUB), loop)
248 n.bind(client_iface%self.notifier_port)
249 n.bind(client_iface%self.notifier_port)
249
250
250 ### build and launch the queues ###
251 ### build and launch the queues ###
251
252
252 # monitor socket
253 # monitor socket
253 sub = ctx.socket(zmq.SUB)
254 sub = ctx.socket(zmq.SUB)
254 sub.setsockopt(zmq.SUBSCRIBE, b"")
255 sub.setsockopt(zmq.SUBSCRIBE, b"")
255 sub.bind(self.monitor_url)
256 sub.bind(self.monitor_url)
256 sub.bind('inproc://monitor')
257 sub.bind('inproc://monitor')
257 sub = ZMQStream(sub, loop)
258 sub = ZMQStream(sub, loop)
258
259
259 # connect the db
260 # connect the db
260 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
261 self.log.info('Hub using DB backend: %r'%(self.db_class.split()[-1]))
261 # cdir = self.config.Global.cluster_dir
262 # cdir = self.config.Global.cluster_dir
262 self.db = import_item(str(self.db_class))(session=self.session.session,
263 self.db = import_item(str(self.db_class))(session=self.session.session,
263 config=self.config, log=self.log)
264 config=self.config, log=self.log)
264 time.sleep(.25)
265 time.sleep(.25)
265 try:
266 try:
266 scheme = self.config.TaskScheduler.scheme_name
267 scheme = self.config.TaskScheduler.scheme_name
267 except AttributeError:
268 except AttributeError:
268 from .scheduler import TaskScheduler
269 from .scheduler import TaskScheduler
269 scheme = TaskScheduler.scheme_name.get_default_value()
270 scheme = TaskScheduler.scheme_name.get_default_value()
270 # build connection dicts
271 # build connection dicts
271 self.engine_info = {
272 self.engine_info = {
272 'control' : engine_iface%self.control[1],
273 'control' : engine_iface%self.control[1],
273 'mux': engine_iface%self.mux[1],
274 'mux': engine_iface%self.mux[1],
274 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
275 'heartbeat': (engine_iface%self.hb[0], engine_iface%self.hb[1]),
275 'task' : engine_iface%self.task[1],
276 'task' : engine_iface%self.task[1],
276 'iopub' : engine_iface%self.iopub[1],
277 'iopub' : engine_iface%self.iopub[1],
277 # 'monitor' : engine_iface%self.mon_port,
278 # 'monitor' : engine_iface%self.mon_port,
278 }
279 }
279
280
280 self.client_info = {
281 self.client_info = {
281 'control' : client_iface%self.control[0],
282 'control' : client_iface%self.control[0],
282 'mux': client_iface%self.mux[0],
283 'mux': client_iface%self.mux[0],
283 'task' : (scheme, client_iface%self.task[0]),
284 'task' : (scheme, client_iface%self.task[0]),
284 'iopub' : client_iface%self.iopub[0],
285 'iopub' : client_iface%self.iopub[0],
285 'notification': client_iface%self.notifier_port
286 'notification': client_iface%self.notifier_port
286 }
287 }
287 self.log.debug("Hub engine addrs: %s", self.engine_info)
288 self.log.debug("Hub engine addrs: %s", self.engine_info)
288 self.log.debug("Hub client addrs: %s", self.client_info)
289 self.log.debug("Hub client addrs: %s", self.client_info)
289
290
290 # resubmit stream
291 # resubmit stream
291 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
292 r = ZMQStream(ctx.socket(zmq.DEALER), loop)
292 url = util.disambiguate_url(self.client_info['task'][-1])
293 url = util.disambiguate_url(self.client_info['task'][-1])
293 r.setsockopt(zmq.IDENTITY, self.session.bsession)
294 r.setsockopt(zmq.IDENTITY, self.session.bsession)
294 r.connect(url)
295 r.connect(url)
295
296
296 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
297 self.hub = Hub(loop=loop, session=self.session, monitor=sub, heartmonitor=self.heartmonitor,
297 query=q, notifier=n, resubmit=r, db=self.db,
298 query=q, notifier=n, resubmit=r, db=self.db,
298 engine_info=self.engine_info, client_info=self.client_info,
299 engine_info=self.engine_info, client_info=self.client_info,
299 log=self.log)
300 log=self.log)
300
301
301
302
class Hub(SessionFactory):
    """The IPython Controller Hub with 0MQ connections

    Parameters
    ==========
    loop: zmq IOLoop instance
    session: Session object
    <removed> context: zmq context for creating new connections (?)
    queue: ZMQStream for monitoring the command queue (SUB)
    query: ZMQStream for engine registration and client queries requests (XREP)
    heartbeat: HeartMonitor object checking the pulse of the engines
    notifier: ZMQStream for broadcasting engine registration changes (PUB)
    db: connection to db for out of memory logging of commands
        NotImplemented
    engine_info: dict of zmq connection information for engines to connect
        to the queues.
    client_info: dict of zmq connection information for engines to connect
        to the queues.
    """
    # internal data structures:
    ids=Set() # set of registered engine integer IDs
    keytable=Dict() # engine id -> engine uuid
    by_ident=Dict() # engine uuid (bytes) -> engine id
    engines=Dict() # engine id -> EngineConnector-like record
    clients=Dict() # client connections
    hearts=Dict() # heart ident -> engine id
    pending=Set() # msg_ids submitted but not yet completed
    queues=Dict() # pending msg_ids keyed by engine_id
    tasks=Dict() # pending msg_ids submitted as tasks, keyed by client_id
    completed=Dict() # completed msg_ids keyed by engine_id
    all_completed=Set() # set of all completed msg_ids, regardless of engine
    dead_engines=Set() # uuids of engines that have been unregistered/died
    unassigned=Set() # set of task msg_ids not yet assigned a destination
    incoming_registrations=Dict() # hearts we expect but have not yet seen beat
    registration_timeout=Integer() # ms to wait for a first heartbeat
    _idcounter=Integer(0) # monotonically increasing source of engine ids

    # objects from constructor:
    query=Instance(ZMQStream)
    monitor=Instance(ZMQStream)
    notifier=Instance(ZMQStream)
    resubmit=Instance(ZMQStream)
    heartmonitor=Instance(HeartMonitor)
    db=Instance(object)
    client_info=Dict()
    engine_info=Dict()
348
349
349
350
    def __init__(self, **kwargs):
        """Construct the Hub; all arguments are traits passed as keywords.

        # universal:
        loop: IOLoop for creating future connections
        session: streamsession for sending serialized data
        # engine:
        queue: ZMQStream for monitoring queue messages
        query: ZMQStream for engine+client registration and client requests
        heartbeat: HeartMonitor object for tracking engines
        # extra:
        db: ZMQStream for db connection (NotImplemented)
        engine_info: zmq address/protocol dict for engine connections
        client_info: zmq address/protocol dict for client connections
        """

        super(Hub, self).__init__(**kwargs)
        # give engines at least two heartbeat periods (and never <5s) to register
        self.registration_timeout = max(5000, 2*self.heartmonitor.period)

        # validate connection dicts:
        for k,v in self.client_info.iteritems():
            if k == 'task':
                # 'task' is a (scheme, url) pair; only the url is validated
                util.validate_url_container(v[1])
            else:
                util.validate_url_container(v)
        # util.validate_url_container(self.client_info)
        util.validate_url_container(self.engine_info)

        # register our callbacks
        self.query.on_recv(self.dispatch_query)
        self.monitor.on_recv(self.dispatch_monitor_traffic)

        self.heartmonitor.add_heart_failure_handler(self.handle_heart_failure)
        self.heartmonitor.add_new_heart_handler(self.handle_new_heart)

        # monitor topics are bytes, because they arrive as raw zmq frames
        self.monitor_handlers = {b'in' : self.save_queue_request,
                                b'out': self.save_queue_result,
                                b'intask': self.save_task_request,
                                b'outtask': self.save_task_result,
                                b'tracktask': self.save_task_destination,
                                b'incontrol': _passer,
                                b'outcontrol': _passer,
                                b'iopub': self.save_iopub_message,
                                }

        # query msg_types (unicode) -> handler methods
        self.query_handlers = {'queue_request': self.queue_status,
                                'result_request': self.get_results,
                                'history_request': self.get_history,
                                'db_request': self.db_query,
                                'purge_request': self.purge_results,
                                'load_request': self.check_load,
                                'resubmit_request': self.resubmit_task,
                                'shutdown_request': self.shutdown_request,
                                'registration_request' : self.register_engine,
                                'unregistration_request' : self.unregister_engine,
                                'connection_request': self.connection_request,
                                }

        # ignore resubmit replies
        self.resubmit.on_recv(lambda msg: None, copy=False)

        self.log.info("hub::created hub")
411
412
412 @property
413 @property
413 def _next_id(self):
414 def _next_id(self):
414 """gemerate a new ID.
415 """gemerate a new ID.
415
416
416 No longer reuse old ids, just count from 0."""
417 No longer reuse old ids, just count from 0."""
417 newid = self._idcounter
418 newid = self._idcounter
418 self._idcounter += 1
419 self._idcounter += 1
419 return newid
420 return newid
420 # newid = 0
421 # newid = 0
421 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
422 # incoming = [id[0] for id in self.incoming_registrations.itervalues()]
422 # # print newid, self.ids, self.incoming_registrations
423 # # print newid, self.ids, self.incoming_registrations
423 # while newid in self.ids or newid in incoming:
424 # while newid in self.ids or newid in incoming:
424 # newid += 1
425 # newid += 1
425 # return newid
426 # return newid
426
427
427 #-----------------------------------------------------------------------------
428 #-----------------------------------------------------------------------------
428 # message validation
429 # message validation
429 #-----------------------------------------------------------------------------
430 #-----------------------------------------------------------------------------
430
431
431 def _validate_targets(self, targets):
432 def _validate_targets(self, targets):
432 """turn any valid targets argument into a list of integer ids"""
433 """turn any valid targets argument into a list of integer ids"""
433 if targets is None:
434 if targets is None:
434 # default to all
435 # default to all
435 return self.ids
436 return self.ids
436
437
437 if isinstance(targets, (int,str,unicode)):
438 if isinstance(targets, (int,str,unicode)):
438 # only one target specified
439 # only one target specified
439 targets = [targets]
440 targets = [targets]
440 _targets = []
441 _targets = []
441 for t in targets:
442 for t in targets:
442 # map raw identities to ids
443 # map raw identities to ids
443 if isinstance(t, (str,unicode)):
444 if isinstance(t, (str,unicode)):
444 t = self.by_ident.get(t, t)
445 t = self.by_ident.get(cast_bytes(t), t)
445 _targets.append(t)
446 _targets.append(t)
446 targets = _targets
447 targets = _targets
447 bad_targets = [ t for t in targets if t not in self.ids ]
448 bad_targets = [ t for t in targets if t not in self.ids ]
448 if bad_targets:
449 if bad_targets:
449 raise IndexError("No Such Engine: %r" % bad_targets)
450 raise IndexError("No Such Engine: %r" % bad_targets)
450 if not targets:
451 if not targets:
451 raise IndexError("No Engines Registered")
452 raise IndexError("No Engines Registered")
452 return targets
453 return targets
453
454
454 #-----------------------------------------------------------------------------
455 #-----------------------------------------------------------------------------
455 # dispatch methods (1 per stream)
456 # dispatch methods (1 per stream)
456 #-----------------------------------------------------------------------------
457 #-----------------------------------------------------------------------------
457
458
458
459
459 @util.log_errors
460 @util.log_errors
460 def dispatch_monitor_traffic(self, msg):
461 def dispatch_monitor_traffic(self, msg):
461 """all ME and Task queue messages come through here, as well as
462 """all ME and Task queue messages come through here, as well as
462 IOPub traffic."""
463 IOPub traffic."""
463 self.log.debug("monitor traffic: %r", msg[0])
464 self.log.debug("monitor traffic: %r", msg[0])
464 switch = msg[0]
465 switch = msg[0]
465 try:
466 try:
466 idents, msg = self.session.feed_identities(msg[1:])
467 idents, msg = self.session.feed_identities(msg[1:])
467 except ValueError:
468 except ValueError:
468 idents=[]
469 idents=[]
469 if not idents:
470 if not idents:
470 self.log.error("Bad Monitor Message: %r", msg)
471 self.log.error("Monitor message without topic: %r", msg)
471 return
472 return
472 handler = self.monitor_handlers.get(switch, None)
473 handler = self.monitor_handlers.get(switch, None)
473 if handler is not None:
474 if handler is not None:
474 handler(idents, msg)
475 handler(idents, msg)
475 else:
476 else:
476 self.log.error("Invalid monitor topic: %r", switch)
477 self.log.error("Unrecognized monitor topic: %r", switch)
477
478
478
479
479 @util.log_errors
480 @util.log_errors
480 def dispatch_query(self, msg):
481 def dispatch_query(self, msg):
481 """Route registration requests and queries from clients."""
482 """Route registration requests and queries from clients."""
482 try:
483 try:
483 idents, msg = self.session.feed_identities(msg)
484 idents, msg = self.session.feed_identities(msg)
484 except ValueError:
485 except ValueError:
485 idents = []
486 idents = []
486 if not idents:
487 if not idents:
487 self.log.error("Bad Query Message: %r", msg)
488 self.log.error("Bad Query Message: %r", msg)
488 return
489 return
489 client_id = idents[0]
490 client_id = idents[0]
490 try:
491 try:
491 msg = self.session.unserialize(msg, content=True)
492 msg = self.session.unserialize(msg, content=True)
492 except Exception:
493 except Exception:
493 content = error.wrap_exception()
494 content = error.wrap_exception()
494 self.log.error("Bad Query Message: %r", msg, exc_info=True)
495 self.log.error("Bad Query Message: %r", msg, exc_info=True)
495 self.session.send(self.query, "hub_error", ident=client_id,
496 self.session.send(self.query, "hub_error", ident=client_id,
496 content=content)
497 content=content)
497 return
498 return
498 # print client_id, header, parent, content
499 # print client_id, header, parent, content
499 #switch on message type:
500 #switch on message type:
500 msg_type = msg['header']['msg_type']
501 msg_type = msg['header']['msg_type']
501 self.log.info("client::client %r requested %r", client_id, msg_type)
502 self.log.info("client::client %r requested %r", client_id, msg_type)
502 handler = self.query_handlers.get(msg_type, None)
503 handler = self.query_handlers.get(msg_type, None)
503 try:
504 try:
504 assert handler is not None, "Bad Message Type: %r" % msg_type
505 assert handler is not None, "Bad Message Type: %r" % msg_type
505 except:
506 except:
506 content = error.wrap_exception()
507 content = error.wrap_exception()
507 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
508 self.log.error("Bad Message Type: %r", msg_type, exc_info=True)
508 self.session.send(self.query, "hub_error", ident=client_id,
509 self.session.send(self.query, "hub_error", ident=client_id,
509 content=content)
510 content=content)
510 return
511 return
511
512
512 else:
513 else:
513 handler(idents, msg)
514 handler(idents, msg)
514
515
515 def dispatch_db(self, msg):
516 def dispatch_db(self, msg):
516 """"""
517 """"""
517 raise NotImplementedError
518 raise NotImplementedError
518
519
519 #---------------------------------------------------------------------------
520 #---------------------------------------------------------------------------
520 # handler methods (1 per event)
521 # handler methods (1 per event)
521 #---------------------------------------------------------------------------
522 #---------------------------------------------------------------------------
522
523
523 #----------------------- Heartbeat --------------------------------------
524 #----------------------- Heartbeat --------------------------------------
524
525
525 def handle_new_heart(self, heart):
526 def handle_new_heart(self, heart):
526 """handler to attach to heartbeater.
527 """handler to attach to heartbeater.
527 Called when a new heart starts to beat.
528 Called when a new heart starts to beat.
528 Triggers completion of registration."""
529 Triggers completion of registration."""
529 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
530 self.log.debug("heartbeat::handle_new_heart(%r)", heart)
530 if heart not in self.incoming_registrations:
531 if heart not in self.incoming_registrations:
531 self.log.info("heartbeat::ignoring new heart: %r", heart)
532 self.log.info("heartbeat::ignoring new heart: %r", heart)
532 else:
533 else:
533 self.finish_registration(heart)
534 self.finish_registration(heart)
534
535
535
536
536 def handle_heart_failure(self, heart):
537 def handle_heart_failure(self, heart):
537 """handler to attach to heartbeater.
538 """handler to attach to heartbeater.
538 called when a previously registered heart fails to respond to beat request.
539 called when a previously registered heart fails to respond to beat request.
539 triggers unregistration"""
540 triggers unregistration"""
540 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
541 self.log.debug("heartbeat::handle_heart_failure(%r)", heart)
541 eid = self.hearts.get(heart, None)
542 eid = self.hearts.get(heart, None)
542 queue = self.engines[eid].queue
543 queue = self.engines[eid].queue
543 if eid is None or self.keytable[eid] in self.dead_engines:
544 if eid is None or self.keytable[eid] in self.dead_engines:
544 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
545 self.log.info("heartbeat::ignoring heart failure %r (not an engine or already dead)", heart)
545 else:
546 else:
546 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
547 self.unregister_engine(heart, dict(content=dict(id=eid, queue=queue)))
547
548
    #----------------------- MUX Queue Traffic ------------------------------

    def save_queue_request(self, idents, msg):
        """Record a request submitted via the MUX queue into the db.

        idents: [queue_id (engine uuid), client_id, ...] zmq identity frames.
        msg: serialized message frames; unserialized below.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return
        queue_id, client_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::client %r sent invalid message to %r: %r", client_id, queue_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::target %r not registered", queue_id)
            self.log.debug("queue:: valid are: %r", self.by_ident.keys())
            return
        record = init_record(msg)
        msg_id = record['msg_id']
        self.log.info("queue::client %r submitted request %r to %s", client_id, msg_id, eid)
        # Unicode in records
        record['engine_uuid'] = queue_id.decode('ascii')
        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'mux'

        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            # merge: warn on conflicts, fill in keys iopub already stored
            for key,evalue in existing.iteritems():
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # no existing record: this is the first we hear of this msg_id
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)


        self.pending.add(msg_id)
        self.queues[eid].append(msg_id)
596
597
    def save_queue_result(self, idents, msg):
        """Record the reply to a MUX-queue request in the db.

        idents: [client_id, queue_id (engine uuid), ...] zmq identity frames.
        """
        if len(idents) < 2:
            self.log.error("invalid identity prefix: %r", idents)
            return

        client_id, queue_id = idents[:2]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("queue::engine %r sent invalid message to %r: %r",
                    queue_id, client_id, msg, exc_info=True)
            return

        eid = self.by_ident.get(queue_id, None)
        if eid is None:
            self.log.error("queue::unknown engine %r is sending a reply: ", queue_id)
            return

        parent = msg['parent_header']
        if not parent:
            # cannot associate a result without its request's msg_id
            return
        msg_id = parent['msg_id']
        if msg_id in self.pending:
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            self.queues[eid].remove(msg_id)
            self.completed[eid].append(msg_id)
            self.log.info("queue::request %r completed on %s", msg_id, eid)
        elif msg_id not in self.all_completed:
            # it could be a result from a dead engine that died before delivering the
            # result
            self.log.warn("queue:: unknown msg finished %r", msg_id)
            return
        # NOTE: deliberate fall-through when msg_id is already in all_completed:
        # update record anyway, because the unregistration could have been premature
        rheader = msg['header']
        completed = rheader['date']
        started = rheader.get('started', None)
        result = {
            'result_header' : rheader,
            'result_content': msg['content'],
            'received': datetime.now(),
            'started' : started,
            'completed' : completed
        }

        result['result_buffers'] = msg['buffers']
        try:
            self.db.update_record(msg_id, result)
        except Exception:
            self.log.error("DB Error updating record %r", msg_id, exc_info=True)
647
648
648
649
    #--------------------- Task Queue Traffic ------------------------------

    def save_task_request(self, idents, msg):
        """Save the submission of a task.

        Marks the msg_id pending and unassigned, and adds/merges its db
        record. Resubmitted tasks keep their original 'submitted',
        'client_uuid' and 'buffers' values.
        """
        client_id = idents[0]

        try:
            msg = self.session.unserialize(msg)
        except Exception:
            self.log.error("task::client %r sent invalid task message: %r",
                    client_id, msg, exc_info=True)
            return
        record = init_record(msg)

        record['client_uuid'] = client_id.decode('ascii')
        record['queue'] = 'task'
        header = msg['header']
        msg_id = header['msg_id']
        self.pending.add(msg_id)
        self.unassigned.add(msg_id)
        try:
            # it's possible iopub arrived first:
            existing = self.db.get_record(msg_id)
            if existing['resubmitted']:
                for key in ('submitted', 'client_uuid', 'buffers'):
                    # don't clobber these keys on resubmit
                    # submitted and client_uuid should be different
                    # and buffers might be big, and shouldn't have changed
                    record.pop(key)
                    # still check content,header which should not change
                    # but are not expensive to compare as buffers

            for key,evalue in existing.iteritems():
                if key.endswith('buffers'):
                    # don't compare buffers
                    continue
                rvalue = record.get(key, None)
                if evalue and rvalue and evalue != rvalue:
                    self.log.warn("conflicting initial state for record: %r:%r <%r> %r", msg_id, rvalue, key, evalue)
                elif evalue and not rvalue:
                    record[key] = evalue
            try:
                self.db.update_record(msg_id, record)
            except Exception:
                self.log.error("DB Error updating record %r", msg_id, exc_info=True)
        except KeyError:
            # first time we see this msg_id: create its record
            try:
                self.db.add_record(msg_id, record)
            except Exception:
                self.log.error("DB Error adding record %r", msg_id, exc_info=True)
        except Exception:
            self.log.error("DB Error saving task request %r", msg_id, exc_info=True)
701
702
    def save_task_result(self, idents, msg):
        """save the result of a completed task.

        Removes the msg_id from pending/unassigned bookkeeping and stores
        the result (header, content, buffers, timestamps) in the db.
        """
        client_id = idents[0]
        try:
            msg = self.session.unserialize(msg)
        except Exception:
            # NOTE(review): "send" should read "sent" in this message
            self.log.error("task::invalid task result message send to %r: %r",
                    client_id, msg, exc_info=True)
            return

        parent = msg['parent_header']
        if not parent:
            self.log.warn("Task %r had no parent!", msg)
            return
        msg_id = parent['msg_id']
        if msg_id in self.unassigned:
            self.unassigned.remove(msg_id)

        header = msg['header']
        # engine may be missing from the header (e.g. aborted tasks)
        engine_uuid = header.get('engine', u'')
        eid = self.by_ident.get(cast_bytes(engine_uuid), None)

        status = header.get('status', None)

        if msg_id in self.pending:
            self.log.info("task::task %r finished on %s", msg_id, eid)
            self.pending.remove(msg_id)
            self.all_completed.add(msg_id)
            if eid is not None:
                # aborted tasks were never actually run by the engine,
                # so don't count them as completed work
                if status != 'aborted':
                    self.completed[eid].append(msg_id)
                if msg_id in self.tasks[eid]:
                    self.tasks[eid].remove(msg_id)
            completed = header['date']
            started = header.get('started', None)
            result = {
                'result_header' : header,
                'result_content': msg['content'],
                'started' : started,
                'completed' : completed,
                'received' : datetime.now(),
                'engine_uuid': engine_uuid,
            }

            result['result_buffers'] = msg['buffers']
            try:
                self.db.update_record(msg_id, result)
            except Exception:
                self.log.error("DB Error saving task request %r", msg_id, exc_info=True)

        else:
            self.log.debug("task::unknown task %r finished", msg_id)
752
756
753 def save_task_destination(self, idents, msg):
757 def save_task_destination(self, idents, msg):
754 try:
758 try:
755 msg = self.session.unserialize(msg, content=True)
759 msg = self.session.unserialize(msg, content=True)
756 except Exception:
760 except Exception:
757 self.log.error("task::invalid task tracking message", exc_info=True)
761 self.log.error("task::invalid task tracking message", exc_info=True)
758 return
762 return
759 content = msg['content']
763 content = msg['content']
760 # print (content)
764 # print (content)
761 msg_id = content['msg_id']
765 msg_id = content['msg_id']
762 engine_uuid = content['engine_id']
766 engine_uuid = content['engine_id']
763 eid = self.by_ident[util.asbytes(engine_uuid)]
767 eid = self.by_ident[cast_bytes(engine_uuid)]
764
768
765 self.log.info("task::task %r arrived on %r", msg_id, eid)
769 self.log.info("task::task %r arrived on %r", msg_id, eid)
766 if msg_id in self.unassigned:
770 if msg_id in self.unassigned:
767 self.unassigned.remove(msg_id)
771 self.unassigned.remove(msg_id)
768 # else:
772 # else:
769 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
773 # self.log.debug("task::task %r not listed as MIA?!"%(msg_id))
770
774
771 self.tasks[eid].append(msg_id)
775 self.tasks[eid].append(msg_id)
772 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
776 # self.pending[msg_id][1].update(received=datetime.now(),engine=(eid,engine_uuid))
773 try:
777 try:
774 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
778 self.db.update_record(msg_id, dict(engine_uuid=engine_uuid))
775 except Exception:
779 except Exception:
776 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
780 self.log.error("DB Error saving task destination %r", msg_id, exc_info=True)
777
781
778
782
779 def mia_task_request(self, idents, msg):
783 def mia_task_request(self, idents, msg):
780 raise NotImplementedError
784 raise NotImplementedError
781 client_id = idents[0]
785 client_id = idents[0]
782 # content = dict(mia=self.mia,status='ok')
786 # content = dict(mia=self.mia,status='ok')
783 # self.session.send('mia_reply', content=content, idents=client_id)
787 # self.session.send('mia_reply', content=content, idents=client_id)
784
788
785
789
786 #--------------------- IOPub Traffic ------------------------------
790 #--------------------- IOPub Traffic ------------------------------
787
791
788 def save_iopub_message(self, topics, msg):
792 def save_iopub_message(self, topics, msg):
789 """save an iopub message into the db"""
793 """save an iopub message into the db"""
790 # print (topics)
794 # print (topics)
791 try:
795 try:
792 msg = self.session.unserialize(msg, content=True)
796 msg = self.session.unserialize(msg, content=True)
793 except Exception:
797 except Exception:
794 self.log.error("iopub::invalid IOPub message", exc_info=True)
798 self.log.error("iopub::invalid IOPub message", exc_info=True)
795 return
799 return
796
800
797 parent = msg['parent_header']
801 parent = msg['parent_header']
798 if not parent:
802 if not parent:
799 self.log.error("iopub::invalid IOPub message: %r", msg)
803 self.log.warn("iopub::IOPub message lacks parent: %r", msg)
800 return
804 return
801 msg_id = parent['msg_id']
805 msg_id = parent['msg_id']
802 msg_type = msg['header']['msg_type']
806 msg_type = msg['header']['msg_type']
803 content = msg['content']
807 content = msg['content']
804
808
805 # ensure msg_id is in db
809 # ensure msg_id is in db
806 try:
810 try:
807 rec = self.db.get_record(msg_id)
811 rec = self.db.get_record(msg_id)
808 except KeyError:
812 except KeyError:
809 rec = empty_record()
813 rec = empty_record()
810 rec['msg_id'] = msg_id
814 rec['msg_id'] = msg_id
811 self.db.add_record(msg_id, rec)
815 self.db.add_record(msg_id, rec)
812 # stream
816 # stream
813 d = {}
817 d = {}
814 if msg_type == 'stream':
818 if msg_type == 'stream':
815 name = content['name']
819 name = content['name']
816 s = rec[name] or ''
820 s = rec[name] or ''
817 d[name] = s + content['data']
821 d[name] = s + content['data']
818
822
819 elif msg_type == 'pyerr':
823 elif msg_type == 'pyerr':
820 d['pyerr'] = content
824 d['pyerr'] = content
821 elif msg_type == 'pyin':
825 elif msg_type == 'pyin':
822 d['pyin'] = content['code']
826 d['pyin'] = content['code']
823 else:
827 else:
824 d[msg_type] = content.get('data', '')
828 d[msg_type] = content.get('data', '')
825
829
826 try:
830 try:
827 self.db.update_record(msg_id, d)
831 self.db.update_record(msg_id, d)
828 except Exception:
832 except Exception:
829 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
833 self.log.error("DB Error saving iopub message %r", msg_id, exc_info=True)
830
834
831
835
832
836
833 #-------------------------------------------------------------------------
837 #-------------------------------------------------------------------------
834 # Registration requests
838 # Registration requests
835 #-------------------------------------------------------------------------
839 #-------------------------------------------------------------------------
836
840
837 def connection_request(self, client_id, msg):
841 def connection_request(self, client_id, msg):
838 """Reply with connection addresses for clients."""
842 """Reply with connection addresses for clients."""
839 self.log.info("client::client %r connected", client_id)
843 self.log.info("client::client %r connected", client_id)
840 content = dict(status='ok')
844 content = dict(status='ok')
841 content.update(self.client_info)
845 content.update(self.client_info)
842 jsonable = {}
846 jsonable = {}
843 for k,v in self.keytable.iteritems():
847 for k,v in self.keytable.iteritems():
844 if v not in self.dead_engines:
848 if v not in self.dead_engines:
845 jsonable[str(k)] = v.decode('ascii')
849 jsonable[str(k)] = v.decode('ascii')
846 content['engines'] = jsonable
850 content['engines'] = jsonable
847 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
851 self.session.send(self.query, 'connection_reply', content, parent=msg, ident=client_id)
848
852
849 def register_engine(self, reg, msg):
853 def register_engine(self, reg, msg):
850 """Register a new engine."""
854 """Register a new engine."""
851 content = msg['content']
855 content = msg['content']
852 try:
856 try:
853 queue = util.asbytes(content['queue'])
857 queue = cast_bytes(content['queue'])
854 except KeyError:
858 except KeyError:
855 self.log.error("registration::queue not specified", exc_info=True)
859 self.log.error("registration::queue not specified", exc_info=True)
856 return
860 return
857 heart = content.get('heartbeat', None)
861 heart = content.get('heartbeat', None)
858 if heart:
862 if heart:
859 heart = util.asbytes(heart)
863 heart = cast_bytes(heart)
860 """register a new engine, and create the socket(s) necessary"""
864 """register a new engine, and create the socket(s) necessary"""
861 eid = self._next_id
865 eid = self._next_id
862 # print (eid, queue, reg, heart)
866 # print (eid, queue, reg, heart)
863
867
864 self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)
868 self.log.debug("registration::register_engine(%i, %r, %r, %r)", eid, queue, reg, heart)
865
869
866 content = dict(id=eid,status='ok')
870 content = dict(id=eid,status='ok')
867 content.update(self.engine_info)
871 content.update(self.engine_info)
868 # check if requesting available IDs:
872 # check if requesting available IDs:
869 if queue in self.by_ident:
873 if queue in self.by_ident:
870 try:
874 try:
871 raise KeyError("queue_id %r in use" % queue)
875 raise KeyError("queue_id %r in use" % queue)
872 except:
876 except:
873 content = error.wrap_exception()
877 content = error.wrap_exception()
874 self.log.error("queue_id %r in use", queue, exc_info=True)
878 self.log.error("queue_id %r in use", queue, exc_info=True)
875 elif heart in self.hearts: # need to check unique hearts?
879 elif heart in self.hearts: # need to check unique hearts?
876 try:
880 try:
877 raise KeyError("heart_id %r in use" % heart)
881 raise KeyError("heart_id %r in use" % heart)
878 except:
882 except:
879 self.log.error("heart_id %r in use", heart, exc_info=True)
883 self.log.error("heart_id %r in use", heart, exc_info=True)
880 content = error.wrap_exception()
884 content = error.wrap_exception()
881 else:
885 else:
882 for h, pack in self.incoming_registrations.iteritems():
886 for h, pack in self.incoming_registrations.iteritems():
883 if heart == h:
887 if heart == h:
884 try:
888 try:
885 raise KeyError("heart_id %r in use" % heart)
889 raise KeyError("heart_id %r in use" % heart)
886 except:
890 except:
887 self.log.error("heart_id %r in use", heart, exc_info=True)
891 self.log.error("heart_id %r in use", heart, exc_info=True)
888 content = error.wrap_exception()
892 content = error.wrap_exception()
889 break
893 break
890 elif queue == pack[1]:
894 elif queue == pack[1]:
891 try:
895 try:
892 raise KeyError("queue_id %r in use" % queue)
896 raise KeyError("queue_id %r in use" % queue)
893 except:
897 except:
894 self.log.error("queue_id %r in use", queue, exc_info=True)
898 self.log.error("queue_id %r in use", queue, exc_info=True)
895 content = error.wrap_exception()
899 content = error.wrap_exception()
896 break
900 break
897
901
898 msg = self.session.send(self.query, "registration_reply",
902 msg = self.session.send(self.query, "registration_reply",
899 content=content,
903 content=content,
900 ident=reg)
904 ident=reg)
901
905
902 if content['status'] == 'ok':
906 if content['status'] == 'ok':
903 if heart in self.heartmonitor.hearts:
907 if heart in self.heartmonitor.hearts:
904 # already beating
908 # already beating
905 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
909 self.incoming_registrations[heart] = (eid,queue,reg[0],None)
906 self.finish_registration(heart)
910 self.finish_registration(heart)
907 else:
911 else:
908 purge = lambda : self._purge_stalled_registration(heart)
912 purge = lambda : self._purge_stalled_registration(heart)
909 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
913 dc = ioloop.DelayedCallback(purge, self.registration_timeout, self.loop)
910 dc.start()
914 dc.start()
911 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
915 self.incoming_registrations[heart] = (eid,queue,reg[0],dc)
912 else:
916 else:
913 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
917 self.log.error("registration::registration %i failed: %r", eid, content['evalue'])
914 return eid
918 return eid
915
919
916 def unregister_engine(self, ident, msg):
920 def unregister_engine(self, ident, msg):
917 """Unregister an engine that explicitly requested to leave."""
921 """Unregister an engine that explicitly requested to leave."""
918 try:
922 try:
919 eid = msg['content']['id']
923 eid = msg['content']['id']
920 except:
924 except:
921 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
925 self.log.error("registration::bad engine id for unregistration: %r", ident, exc_info=True)
922 return
926 return
923 self.log.info("registration::unregister_engine(%r)", eid)
927 self.log.info("registration::unregister_engine(%r)", eid)
924 # print (eid)
928 # print (eid)
925 uuid = self.keytable[eid]
929 uuid = self.keytable[eid]
926 content=dict(id=eid, queue=uuid.decode('ascii'))
930 content=dict(id=eid, queue=uuid.decode('ascii'))
927 self.dead_engines.add(uuid)
931 self.dead_engines.add(uuid)
928 # self.ids.remove(eid)
932 # self.ids.remove(eid)
929 # uuid = self.keytable.pop(eid)
933 # uuid = self.keytable.pop(eid)
930 #
934 #
931 # ec = self.engines.pop(eid)
935 # ec = self.engines.pop(eid)
932 # self.hearts.pop(ec.heartbeat)
936 # self.hearts.pop(ec.heartbeat)
933 # self.by_ident.pop(ec.queue)
937 # self.by_ident.pop(ec.queue)
934 # self.completed.pop(eid)
938 # self.completed.pop(eid)
935 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
939 handleit = lambda : self._handle_stranded_msgs(eid, uuid)
936 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
940 dc = ioloop.DelayedCallback(handleit, self.registration_timeout, self.loop)
937 dc.start()
941 dc.start()
938 ############## TODO: HANDLE IT ################
942 ############## TODO: HANDLE IT ################
939
943
940 if self.notifier:
944 if self.notifier:
941 self.session.send(self.notifier, "unregistration_notification", content=content)
945 self.session.send(self.notifier, "unregistration_notification", content=content)
942
946
943 def _handle_stranded_msgs(self, eid, uuid):
947 def _handle_stranded_msgs(self, eid, uuid):
944 """Handle messages known to be on an engine when the engine unregisters.
948 """Handle messages known to be on an engine when the engine unregisters.
945
949
946 It is possible that this will fire prematurely - that is, an engine will
950 It is possible that this will fire prematurely - that is, an engine will
947 go down after completing a result, and the client will be notified
951 go down after completing a result, and the client will be notified
948 that the result failed and later receive the actual result.
952 that the result failed and later receive the actual result.
949 """
953 """
950
954
951 outstanding = self.queues[eid]
955 outstanding = self.queues[eid]
952
956
953 for msg_id in outstanding:
957 for msg_id in outstanding:
954 self.pending.remove(msg_id)
958 self.pending.remove(msg_id)
955 self.all_completed.add(msg_id)
959 self.all_completed.add(msg_id)
956 try:
960 try:
957 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
961 raise error.EngineError("Engine %r died while running task %r" % (eid, msg_id))
958 except:
962 except:
959 content = error.wrap_exception()
963 content = error.wrap_exception()
960 # build a fake header:
964 # build a fake header:
961 header = {}
965 header = {}
962 header['engine'] = uuid
966 header['engine'] = uuid
963 header['date'] = datetime.now()
967 header['date'] = datetime.now()
964 rec = dict(result_content=content, result_header=header, result_buffers=[])
968 rec = dict(result_content=content, result_header=header, result_buffers=[])
965 rec['completed'] = header['date']
969 rec['completed'] = header['date']
966 rec['engine_uuid'] = uuid
970 rec['engine_uuid'] = uuid
967 try:
971 try:
968 self.db.update_record(msg_id, rec)
972 self.db.update_record(msg_id, rec)
969 except Exception:
973 except Exception:
970 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
974 self.log.error("DB Error handling stranded msg %r", msg_id, exc_info=True)
971
975
972
976
973 def finish_registration(self, heart):
977 def finish_registration(self, heart):
974 """Second half of engine registration, called after our HeartMonitor
978 """Second half of engine registration, called after our HeartMonitor
975 has received a beat from the Engine's Heart."""
979 has received a beat from the Engine's Heart."""
976 try:
980 try:
977 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
981 (eid,queue,reg,purge) = self.incoming_registrations.pop(heart)
978 except KeyError:
982 except KeyError:
979 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
983 self.log.error("registration::tried to finish nonexistant registration", exc_info=True)
980 return
984 return
981 self.log.info("registration::finished registering engine %i:%r", eid, queue)
985 self.log.info("registration::finished registering engine %i:%r", eid, queue)
982 if purge is not None:
986 if purge is not None:
983 purge.stop()
987 purge.stop()
984 control = queue
988 control = queue
985 self.ids.add(eid)
989 self.ids.add(eid)
986 self.keytable[eid] = queue
990 self.keytable[eid] = queue
987 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
991 self.engines[eid] = EngineConnector(id=eid, queue=queue, registration=reg,
988 control=control, heartbeat=heart)
992 control=control, heartbeat=heart)
989 self.by_ident[queue] = eid
993 self.by_ident[queue] = eid
990 self.queues[eid] = list()
994 self.queues[eid] = list()
991 self.tasks[eid] = list()
995 self.tasks[eid] = list()
992 self.completed[eid] = list()
996 self.completed[eid] = list()
993 self.hearts[heart] = eid
997 self.hearts[heart] = eid
994 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
998 content = dict(id=eid, queue=self.engines[eid].queue.decode('ascii'))
995 if self.notifier:
999 if self.notifier:
996 self.session.send(self.notifier, "registration_notification", content=content)
1000 self.session.send(self.notifier, "registration_notification", content=content)
997 self.log.info("engine::Engine Connected: %i", eid)
1001 self.log.info("engine::Engine Connected: %i", eid)
998
1002
999 def _purge_stalled_registration(self, heart):
1003 def _purge_stalled_registration(self, heart):
1000 if heart in self.incoming_registrations:
1004 if heart in self.incoming_registrations:
1001 eid = self.incoming_registrations.pop(heart)[0]
1005 eid = self.incoming_registrations.pop(heart)[0]
1002 self.log.info("registration::purging stalled registration: %i", eid)
1006 self.log.info("registration::purging stalled registration: %i", eid)
1003 else:
1007 else:
1004 pass
1008 pass
1005
1009
1006 #-------------------------------------------------------------------------
1010 #-------------------------------------------------------------------------
1007 # Client Requests
1011 # Client Requests
1008 #-------------------------------------------------------------------------
1012 #-------------------------------------------------------------------------
1009
1013
1010 def shutdown_request(self, client_id, msg):
1014 def shutdown_request(self, client_id, msg):
1011 """handle shutdown request."""
1015 """handle shutdown request."""
1012 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1016 self.session.send(self.query, 'shutdown_reply', content={'status': 'ok'}, ident=client_id)
1013 # also notify other clients of shutdown
1017 # also notify other clients of shutdown
1014 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1018 self.session.send(self.notifier, 'shutdown_notice', content={'status': 'ok'})
1015 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1019 dc = ioloop.DelayedCallback(lambda : self._shutdown(), 1000, self.loop)
1016 dc.start()
1020 dc.start()
1017
1021
1018 def _shutdown(self):
1022 def _shutdown(self):
1019 self.log.info("hub::hub shutting down.")
1023 self.log.info("hub::hub shutting down.")
1020 time.sleep(0.1)
1024 time.sleep(0.1)
1021 sys.exit(0)
1025 sys.exit(0)
1022
1026
1023
1027
1024 def check_load(self, client_id, msg):
1028 def check_load(self, client_id, msg):
1025 content = msg['content']
1029 content = msg['content']
1026 try:
1030 try:
1027 targets = content['targets']
1031 targets = content['targets']
1028 targets = self._validate_targets(targets)
1032 targets = self._validate_targets(targets)
1029 except:
1033 except:
1030 content = error.wrap_exception()
1034 content = error.wrap_exception()
1031 self.session.send(self.query, "hub_error",
1035 self.session.send(self.query, "hub_error",
1032 content=content, ident=client_id)
1036 content=content, ident=client_id)
1033 return
1037 return
1034
1038
1035 content = dict(status='ok')
1039 content = dict(status='ok')
1036 # loads = {}
1040 # loads = {}
1037 for t in targets:
1041 for t in targets:
1038 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1042 content[bytes(t)] = len(self.queues[t])+len(self.tasks[t])
1039 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1043 self.session.send(self.query, "load_reply", content=content, ident=client_id)
1040
1044
1041
1045
1042 def queue_status(self, client_id, msg):
1046 def queue_status(self, client_id, msg):
1043 """Return the Queue status of one or more targets.
1047 """Return the Queue status of one or more targets.
1044 if verbose: return the msg_ids
1048 if verbose: return the msg_ids
1045 else: return len of each type.
1049 else: return len of each type.
1046 keys: queue (pending MUX jobs)
1050 keys: queue (pending MUX jobs)
1047 tasks (pending Task jobs)
1051 tasks (pending Task jobs)
1048 completed (finished jobs from both queues)"""
1052 completed (finished jobs from both queues)"""
1049 content = msg['content']
1053 content = msg['content']
1050 targets = content['targets']
1054 targets = content['targets']
1051 try:
1055 try:
1052 targets = self._validate_targets(targets)
1056 targets = self._validate_targets(targets)
1053 except:
1057 except:
1054 content = error.wrap_exception()
1058 content = error.wrap_exception()
1055 self.session.send(self.query, "hub_error",
1059 self.session.send(self.query, "hub_error",
1056 content=content, ident=client_id)
1060 content=content, ident=client_id)
1057 return
1061 return
1058 verbose = content.get('verbose', False)
1062 verbose = content.get('verbose', False)
1059 content = dict(status='ok')
1063 content = dict(status='ok')
1060 for t in targets:
1064 for t in targets:
1061 queue = self.queues[t]
1065 queue = self.queues[t]
1062 completed = self.completed[t]
1066 completed = self.completed[t]
1063 tasks = self.tasks[t]
1067 tasks = self.tasks[t]
1064 if not verbose:
1068 if not verbose:
1065 queue = len(queue)
1069 queue = len(queue)
1066 completed = len(completed)
1070 completed = len(completed)
1067 tasks = len(tasks)
1071 tasks = len(tasks)
1068 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1072 content[str(t)] = {'queue': queue, 'completed': completed , 'tasks': tasks}
1069 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1073 content['unassigned'] = list(self.unassigned) if verbose else len(self.unassigned)
1070 # print (content)
1074 # print (content)
1071 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1075 self.session.send(self.query, "queue_reply", content=content, ident=client_id)
1072
1076
1073 def purge_results(self, client_id, msg):
1077 def purge_results(self, client_id, msg):
1074 """Purge results from memory. This method is more valuable before we move
1078 """Purge results from memory. This method is more valuable before we move
1075 to a DB based message storage mechanism."""
1079 to a DB based message storage mechanism."""
1076 content = msg['content']
1080 content = msg['content']
1077 self.log.info("Dropping records with %s", content)
1081 self.log.info("Dropping records with %s", content)
1078 msg_ids = content.get('msg_ids', [])
1082 msg_ids = content.get('msg_ids', [])
1079 reply = dict(status='ok')
1083 reply = dict(status='ok')
1080 if msg_ids == 'all':
1084 if msg_ids == 'all':
1081 try:
1085 try:
1082 self.db.drop_matching_records(dict(completed={'$ne':None}))
1086 self.db.drop_matching_records(dict(completed={'$ne':None}))
1083 except Exception:
1087 except Exception:
1084 reply = error.wrap_exception()
1088 reply = error.wrap_exception()
1085 else:
1089 else:
1086 pending = filter(lambda m: m in self.pending, msg_ids)
1090 pending = filter(lambda m: m in self.pending, msg_ids)
1087 if pending:
1091 if pending:
1088 try:
1092 try:
1089 raise IndexError("msg pending: %r" % pending[0])
1093 raise IndexError("msg pending: %r" % pending[0])
1090 except:
1094 except:
1091 reply = error.wrap_exception()
1095 reply = error.wrap_exception()
1092 else:
1096 else:
1093 try:
1097 try:
1094 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1098 self.db.drop_matching_records(dict(msg_id={'$in':msg_ids}))
1095 except Exception:
1099 except Exception:
1096 reply = error.wrap_exception()
1100 reply = error.wrap_exception()
1097
1101
1098 if reply['status'] == 'ok':
1102 if reply['status'] == 'ok':
1099 eids = content.get('engine_ids', [])
1103 eids = content.get('engine_ids', [])
1100 for eid in eids:
1104 for eid in eids:
1101 if eid not in self.engines:
1105 if eid not in self.engines:
1102 try:
1106 try:
1103 raise IndexError("No such engine: %i" % eid)
1107 raise IndexError("No such engine: %i" % eid)
1104 except:
1108 except:
1105 reply = error.wrap_exception()
1109 reply = error.wrap_exception()
1106 break
1110 break
1107 uid = self.engines[eid].queue
1111 uid = self.engines[eid].queue
1108 try:
1112 try:
1109 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1113 self.db.drop_matching_records(dict(engine_uuid=uid, completed={'$ne':None}))
1110 except Exception:
1114 except Exception:
1111 reply = error.wrap_exception()
1115 reply = error.wrap_exception()
1112 break
1116 break
1113
1117
1114 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1118 self.session.send(self.query, 'purge_reply', content=reply, ident=client_id)
1115
1119
1116 def resubmit_task(self, client_id, msg):
1120 def resubmit_task(self, client_id, msg):
1117 """Resubmit one or more tasks."""
1121 """Resubmit one or more tasks."""
1118 def finish(reply):
1122 def finish(reply):
1119 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1123 self.session.send(self.query, 'resubmit_reply', content=reply, ident=client_id)
1120
1124
1121 content = msg['content']
1125 content = msg['content']
1122 msg_ids = content['msg_ids']
1126 msg_ids = content['msg_ids']
1123 reply = dict(status='ok')
1127 reply = dict(status='ok')
1124 try:
1128 try:
1125 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1129 records = self.db.find_records({'msg_id' : {'$in' : msg_ids}}, keys=[
1126 'header', 'content', 'buffers'])
1130 'header', 'content', 'buffers'])
1127 except Exception:
1131 except Exception:
1128 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1132 self.log.error('db::db error finding tasks to resubmit', exc_info=True)
1129 return finish(error.wrap_exception())
1133 return finish(error.wrap_exception())
1130
1134
1131 # validate msg_ids
1135 # validate msg_ids
1132 found_ids = [ rec['msg_id'] for rec in records ]
1136 found_ids = [ rec['msg_id'] for rec in records ]
1133 invalid_ids = filter(lambda m: m in self.pending, found_ids)
1137 pending_ids = [ msg_id for msg_id in found_ids if msg_id in self.pending ]
1134 if len(records) > len(msg_ids):
1138 if len(records) > len(msg_ids):
1135 try:
1139 try:
1136 raise RuntimeError("DB appears to be in an inconsistent state."
1140 raise RuntimeError("DB appears to be in an inconsistent state."
1137 "More matching records were found than should exist")
1141 "More matching records were found than should exist")
1138 except Exception:
1142 except Exception:
1139 return finish(error.wrap_exception())
1143 return finish(error.wrap_exception())
1140 elif len(records) < len(msg_ids):
1144 elif len(records) < len(msg_ids):
1141 missing = [ m for m in msg_ids if m not in found_ids ]
1145 missing = [ m for m in msg_ids if m not in found_ids ]
1142 try:
1146 try:
1143 raise KeyError("No such msg(s): %r" % missing)
1147 raise KeyError("No such msg(s): %r" % missing)
1144 except KeyError:
1148 except KeyError:
1145 return finish(error.wrap_exception())
1149 return finish(error.wrap_exception())
1146 elif invalid_ids:
1150 elif pending_ids:
1147 msg_id = invalid_ids[0]
1151 pass
1152 # no need to raise on resubmit of pending task, now that we
1153 # resubmit under new ID, but do we want to raise anyway?
1154 # msg_id = invalid_ids[0]
1155 # try:
1156 # raise ValueError("Task(s) %r appears to be inflight" % )
1157 # except Exception:
1158 # return finish(error.wrap_exception())
1159
1160 # mapping of original IDs to resubmitted IDs
1161 resubmitted = {}
1162
1163 # send the messages
1164 for rec in records:
1165 header = rec['header']
1166 msg = self.session.msg(header['msg_type'])
1167 msg_id = msg['msg_id']
1168 msg['content'] = rec['content']
1169 header.update(msg['header'])
1170 msg['header'] = header
1171
1172 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1173
1174 resubmitted[rec['msg_id']] = msg_id
1175 self.pending.add(msg_id)
1176 msg['buffers'] = []
1148 try:
1177 try:
1149 raise ValueError("Task %r appears to be inflight" % msg_id)
1178 self.db.add_record(msg_id, init_record(msg))
1150 except Exception:
1179 except Exception:
1151 return finish(error.wrap_exception())
1180 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1152
1181
1153 # clear the existing records
1182 finish(dict(status='ok', resubmitted=resubmitted))
1154 now = datetime.now()
1183
1155 rec = empty_record()
1184 # store the new IDs in the Task DB
1156 map(rec.pop, ['msg_id', 'header', 'content', 'buffers', 'submitted'])
1185 for msg_id, resubmit_id in resubmitted.iteritems():
1157 rec['resubmitted'] = now
1186 try:
1158 rec['queue'] = 'task'
1187 self.db.update_record(msg_id, {'resubmitted' : resubmit_id})
1159 rec['client_uuid'] = client_id[0]
1188 except Exception:
1160 try:
1189 self.log.error("db::DB Error updating record: %s", msg_id, exc_info=True)
1161 for msg_id in msg_ids:
1162 self.all_completed.discard(msg_id)
1163 self.db.update_record(msg_id, rec)
1164 except Exception:
1165 self.log.error('db::db error upating record', exc_info=True)
1166 reply = error.wrap_exception()
1167 else:
1168 # send the messages
1169 for rec in records:
1170 header = rec['header']
1171 # include resubmitted in header to prevent digest collision
1172 header['resubmitted'] = now
1173 msg = self.session.msg(header['msg_type'])
1174 msg['content'] = rec['content']
1175 msg['header'] = header
1176 msg['header']['msg_id'] = rec['msg_id']
1177 self.session.send(self.resubmit, msg, buffers=rec['buffers'])
1178
1179 finish(dict(status='ok'))
1180
1190
1181
1191
1182 def _extract_record(self, rec):
1192 def _extract_record(self, rec):
1183 """decompose a TaskRecord dict into subsection of reply for get_result"""
1193 """decompose a TaskRecord dict into subsection of reply for get_result"""
1184 io_dict = {}
1194 io_dict = {}
1185 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1195 for key in ('pyin', 'pyout', 'pyerr', 'stdout', 'stderr'):
1186 io_dict[key] = rec[key]
1196 io_dict[key] = rec[key]
1187 content = { 'result_content': rec['result_content'],
1197 content = { 'result_content': rec['result_content'],
1188 'header': rec['header'],
1198 'header': rec['header'],
1189 'result_header' : rec['result_header'],
1199 'result_header' : rec['result_header'],
1190 'received' : rec['received'],
1200 'received' : rec['received'],
1191 'io' : io_dict,
1201 'io' : io_dict,
1192 }
1202 }
1193 if rec['result_buffers']:
1203 if rec['result_buffers']:
1194 buffers = map(bytes, rec['result_buffers'])
1204 buffers = map(bytes, rec['result_buffers'])
1195 else:
1205 else:
1196 buffers = []
1206 buffers = []
1197
1207
1198 return content, buffers
1208 return content, buffers
1199
1209
1200 def get_results(self, client_id, msg):
1210 def get_results(self, client_id, msg):
1201 """Get the result of 1 or more messages."""
1211 """Get the result of 1 or more messages."""
1202 content = msg['content']
1212 content = msg['content']
1203 msg_ids = sorted(set(content['msg_ids']))
1213 msg_ids = sorted(set(content['msg_ids']))
1204 statusonly = content.get('status_only', False)
1214 statusonly = content.get('status_only', False)
1205 pending = []
1215 pending = []
1206 completed = []
1216 completed = []
1207 content = dict(status='ok')
1217 content = dict(status='ok')
1208 content['pending'] = pending
1218 content['pending'] = pending
1209 content['completed'] = completed
1219 content['completed'] = completed
1210 buffers = []
1220 buffers = []
1211 if not statusonly:
1221 if not statusonly:
1212 try:
1222 try:
1213 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1223 matches = self.db.find_records(dict(msg_id={'$in':msg_ids}))
1214 # turn match list into dict, for faster lookup
1224 # turn match list into dict, for faster lookup
1215 records = {}
1225 records = {}
1216 for rec in matches:
1226 for rec in matches:
1217 records[rec['msg_id']] = rec
1227 records[rec['msg_id']] = rec
1218 except Exception:
1228 except Exception:
1219 content = error.wrap_exception()
1229 content = error.wrap_exception()
1220 self.session.send(self.query, "result_reply", content=content,
1230 self.session.send(self.query, "result_reply", content=content,
1221 parent=msg, ident=client_id)
1231 parent=msg, ident=client_id)
1222 return
1232 return
1223 else:
1233 else:
1224 records = {}
1234 records = {}
1225 for msg_id in msg_ids:
1235 for msg_id in msg_ids:
1226 if msg_id in self.pending:
1236 if msg_id in self.pending:
1227 pending.append(msg_id)
1237 pending.append(msg_id)
1228 elif msg_id in self.all_completed:
1238 elif msg_id in self.all_completed:
1229 completed.append(msg_id)
1239 completed.append(msg_id)
1230 if not statusonly:
1240 if not statusonly:
1231 c,bufs = self._extract_record(records[msg_id])
1241 c,bufs = self._extract_record(records[msg_id])
1232 content[msg_id] = c
1242 content[msg_id] = c
1233 buffers.extend(bufs)
1243 buffers.extend(bufs)
1234 elif msg_id in records:
1244 elif msg_id in records:
1235 if rec['completed']:
1245 if rec['completed']:
1236 completed.append(msg_id)
1246 completed.append(msg_id)
1237 c,bufs = self._extract_record(records[msg_id])
1247 c,bufs = self._extract_record(records[msg_id])
1238 content[msg_id] = c
1248 content[msg_id] = c
1239 buffers.extend(bufs)
1249 buffers.extend(bufs)
1240 else:
1250 else:
1241 pending.append(msg_id)
1251 pending.append(msg_id)
1242 else:
1252 else:
1243 try:
1253 try:
1244 raise KeyError('No such message: '+msg_id)
1254 raise KeyError('No such message: '+msg_id)
1245 except:
1255 except:
1246 content = error.wrap_exception()
1256 content = error.wrap_exception()
1247 break
1257 break
1248 self.session.send(self.query, "result_reply", content=content,
1258 self.session.send(self.query, "result_reply", content=content,
1249 parent=msg, ident=client_id,
1259 parent=msg, ident=client_id,
1250 buffers=buffers)
1260 buffers=buffers)
1251
1261
1252 def get_history(self, client_id, msg):
1262 def get_history(self, client_id, msg):
1253 """Get a list of all msg_ids in our DB records"""
1263 """Get a list of all msg_ids in our DB records"""
1254 try:
1264 try:
1255 msg_ids = self.db.get_history()
1265 msg_ids = self.db.get_history()
1256 except Exception as e:
1266 except Exception as e:
1257 content = error.wrap_exception()
1267 content = error.wrap_exception()
1258 else:
1268 else:
1259 content = dict(status='ok', history=msg_ids)
1269 content = dict(status='ok', history=msg_ids)
1260
1270
1261 self.session.send(self.query, "history_reply", content=content,
1271 self.session.send(self.query, "history_reply", content=content,
1262 parent=msg, ident=client_id)
1272 parent=msg, ident=client_id)
1263
1273
1264 def db_query(self, client_id, msg):
1274 def db_query(self, client_id, msg):
1265 """Perform a raw query on the task record database."""
1275 """Perform a raw query on the task record database."""
1266 content = msg['content']
1276 content = msg['content']
1267 query = content.get('query', {})
1277 query = content.get('query', {})
1268 keys = content.get('keys', None)
1278 keys = content.get('keys', None)
1269 buffers = []
1279 buffers = []
1270 empty = list()
1280 empty = list()
1271 try:
1281 try:
1272 records = self.db.find_records(query, keys)
1282 records = self.db.find_records(query, keys)
1273 except Exception as e:
1283 except Exception as e:
1274 content = error.wrap_exception()
1284 content = error.wrap_exception()
1275 else:
1285 else:
1276 # extract buffers from reply content:
1286 # extract buffers from reply content:
1277 if keys is not None:
1287 if keys is not None:
1278 buffer_lens = [] if 'buffers' in keys else None
1288 buffer_lens = [] if 'buffers' in keys else None
1279 result_buffer_lens = [] if 'result_buffers' in keys else None
1289 result_buffer_lens = [] if 'result_buffers' in keys else None
1280 else:
1290 else:
1281 buffer_lens = None
1291 buffer_lens = None
1282 result_buffer_lens = None
1292 result_buffer_lens = None
1283
1293
1284 for rec in records:
1294 for rec in records:
1285 # buffers may be None, so double check
1295 # buffers may be None, so double check
1286 b = rec.pop('buffers', empty) or empty
1296 b = rec.pop('buffers', empty) or empty
1287 if buffer_lens is not None:
1297 if buffer_lens is not None:
1288 buffer_lens.append(len(b))
1298 buffer_lens.append(len(b))
1289 buffers.extend(b)
1299 buffers.extend(b)
1290 rb = rec.pop('result_buffers', empty) or empty
1300 rb = rec.pop('result_buffers', empty) or empty
1291 if result_buffer_lens is not None:
1301 if result_buffer_lens is not None:
1292 result_buffer_lens.append(len(rb))
1302 result_buffer_lens.append(len(rb))
1293 buffers.extend(rb)
1303 buffers.extend(rb)
1294 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1304 content = dict(status='ok', records=records, buffer_lens=buffer_lens,
1295 result_buffer_lens=result_buffer_lens)
1305 result_buffer_lens=result_buffer_lens)
1296 # self.log.debug (content)
1306 # self.log.debug (content)
1297 self.session.send(self.query, "db_reply", content=content,
1307 self.session.send(self.query, "db_reply", content=content,
1298 parent=msg, ident=client_id,
1308 parent=msg, ident=client_id,
1299 buffers=buffers)
1309 buffers=buffers)
1300
1310
@@ -1,767 +1,768 b''
1 """The Python scheduler for rich scheduling.
1 """The Python scheduler for rich scheduling.
2
2
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
3 The Pure ZMQ scheduler does not allow routing schemes other than LRU,
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
4 nor does it check msg_id DAG dependencies. For those, a slightly slower
5 Python Scheduler exists.
5 Python Scheduler exists.
6
6
7 Authors:
7 Authors:
8
8
9 * Min RK
9 * Min RK
10 """
10 """
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
13 #
14 # Distributed under the terms of the BSD License. The full license is in
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 #----------------------------------------------------------------------
18 #----------------------------------------------------------------------
19 # Imports
19 # Imports
20 #----------------------------------------------------------------------
20 #----------------------------------------------------------------------
21
21
22 from __future__ import print_function
22 from __future__ import print_function
23
23
24 import logging
24 import logging
25 import sys
25 import sys
26 import time
26 import time
27
27
28 from datetime import datetime, timedelta
28 from datetime import datetime, timedelta
29 from random import randint, random
29 from random import randint, random
30 from types import FunctionType
30 from types import FunctionType
31
31
32 try:
32 try:
33 import numpy
33 import numpy
34 except ImportError:
34 except ImportError:
35 numpy = None
35 numpy = None
36
36
37 import zmq
37 import zmq
38 from zmq.eventloop import ioloop, zmqstream
38 from zmq.eventloop import ioloop, zmqstream
39
39
40 # local imports
40 # local imports
41 from IPython.external.decorator import decorator
41 from IPython.external.decorator import decorator
42 from IPython.config.application import Application
42 from IPython.config.application import Application
43 from IPython.config.loader import Config
43 from IPython.config.loader import Config
44 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
44 from IPython.utils.traitlets import Instance, Dict, List, Set, Integer, Enum, CBytes
45 from IPython.utils.py3compat import cast_bytes
45
46
46 from IPython.parallel import error, util
47 from IPython.parallel import error, util
47 from IPython.parallel.factory import SessionFactory
48 from IPython.parallel.factory import SessionFactory
48 from IPython.parallel.util import connect_logger, local_logger, asbytes
49 from IPython.parallel.util import connect_logger, local_logger
49
50
50 from .dependency import Dependency
51 from .dependency import Dependency
51
52
52 @decorator
53 @decorator
53 def logged(f,self,*args,**kwargs):
54 def logged(f,self,*args,**kwargs):
54 # print ("#--------------------")
55 # print ("#--------------------")
55 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
56 self.log.debug("scheduler::%s(*%s,**%s)", f.func_name, args, kwargs)
56 # print ("#--")
57 # print ("#--")
57 return f(self,*args, **kwargs)
58 return f(self,*args, **kwargs)
58
59
59 #----------------------------------------------------------------------
60 #----------------------------------------------------------------------
60 # Chooser functions
61 # Chooser functions
61 #----------------------------------------------------------------------
62 #----------------------------------------------------------------------
62
63
63 def plainrandom(loads):
64 def plainrandom(loads):
64 """Plain random pick."""
65 """Plain random pick."""
65 n = len(loads)
66 n = len(loads)
66 return randint(0,n-1)
67 return randint(0,n-1)
67
68
68 def lru(loads):
69 def lru(loads):
69 """Always pick the front of the line.
70 """Always pick the front of the line.
70
71
71 The content of `loads` is ignored.
72 The content of `loads` is ignored.
72
73
73 Assumes LRU ordering of loads, with oldest first.
74 Assumes LRU ordering of loads, with oldest first.
74 """
75 """
75 return 0
76 return 0
76
77
77 def twobin(loads):
78 def twobin(loads):
78 """Pick two at random, use the LRU of the two.
79 """Pick two at random, use the LRU of the two.
79
80
80 The content of loads is ignored.
81 The content of loads is ignored.
81
82
82 Assumes LRU ordering of loads, with oldest first.
83 Assumes LRU ordering of loads, with oldest first.
83 """
84 """
84 n = len(loads)
85 n = len(loads)
85 a = randint(0,n-1)
86 a = randint(0,n-1)
86 b = randint(0,n-1)
87 b = randint(0,n-1)
87 return min(a,b)
88 return min(a,b)
88
89
89 def weighted(loads):
90 def weighted(loads):
90 """Pick two at random using inverse load as weight.
91 """Pick two at random using inverse load as weight.
91
92
92 Return the less loaded of the two.
93 Return the less loaded of the two.
93 """
94 """
94 # weight 0 a million times more than 1:
95 # weight 0 a million times more than 1:
95 weights = 1./(1e-6+numpy.array(loads))
96 weights = 1./(1e-6+numpy.array(loads))
96 sums = weights.cumsum()
97 sums = weights.cumsum()
97 t = sums[-1]
98 t = sums[-1]
98 x = random()*t
99 x = random()*t
99 y = random()*t
100 y = random()*t
100 idx = 0
101 idx = 0
101 idy = 0
102 idy = 0
102 while sums[idx] < x:
103 while sums[idx] < x:
103 idx += 1
104 idx += 1
104 while sums[idy] < y:
105 while sums[idy] < y:
105 idy += 1
106 idy += 1
106 if weights[idy] > weights[idx]:
107 if weights[idy] > weights[idx]:
107 return idy
108 return idy
108 else:
109 else:
109 return idx
110 return idx
110
111
111 def leastload(loads):
112 def leastload(loads):
112 """Always choose the lowest load.
113 """Always choose the lowest load.
113
114
114 If the lowest load occurs more than once, the first
115 If the lowest load occurs more than once, the first
115 occurance will be used. If loads has LRU ordering, this means
116 occurance will be used. If loads has LRU ordering, this means
116 the LRU of those with the lowest load is chosen.
117 the LRU of those with the lowest load is chosen.
117 """
118 """
118 return loads.index(min(loads))
119 return loads.index(min(loads))
119
120
120 #---------------------------------------------------------------------
121 #---------------------------------------------------------------------
121 # Classes
122 # Classes
122 #---------------------------------------------------------------------
123 #---------------------------------------------------------------------
123
124
124
125
125 # store empty default dependency:
126 # store empty default dependency:
126 MET = Dependency([])
127 MET = Dependency([])
127
128
128
129
129 class Job(object):
130 class Job(object):
130 """Simple container for a job"""
131 """Simple container for a job"""
131 def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
132 def __init__(self, msg_id, raw_msg, idents, msg, header, targets, after, follow, timeout):
132 self.msg_id = msg_id
133 self.msg_id = msg_id
133 self.raw_msg = raw_msg
134 self.raw_msg = raw_msg
134 self.idents = idents
135 self.idents = idents
135 self.msg = msg
136 self.msg = msg
136 self.header = header
137 self.header = header
137 self.targets = targets
138 self.targets = targets
138 self.after = after
139 self.after = after
139 self.follow = follow
140 self.follow = follow
140 self.timeout = timeout
141 self.timeout = timeout
141
142
142
143
143 self.timestamp = time.time()
144 self.timestamp = time.time()
144 self.blacklist = set()
145 self.blacklist = set()
145
146
146 @property
147 @property
147 def dependents(self):
148 def dependents(self):
148 return self.follow.union(self.after)
149 return self.follow.union(self.after)
149
150
150 class TaskScheduler(SessionFactory):
151 class TaskScheduler(SessionFactory):
151 """Python TaskScheduler object.
152 """Python TaskScheduler object.
152
153
153 This is the simplest object that supports msg_id based
154 This is the simplest object that supports msg_id based
154 DAG dependencies. *Only* task msg_ids are checked, not
155 DAG dependencies. *Only* task msg_ids are checked, not
155 msg_ids of jobs submitted via the MUX queue.
156 msg_ids of jobs submitted via the MUX queue.
156
157
157 """
158 """
158
159
159 hwm = Integer(1, config=True,
160 hwm = Integer(1, config=True,
160 help="""specify the High Water Mark (HWM) for the downstream
161 help="""specify the High Water Mark (HWM) for the downstream
161 socket in the Task scheduler. This is the maximum number
162 socket in the Task scheduler. This is the maximum number
162 of allowed outstanding tasks on each engine.
163 of allowed outstanding tasks on each engine.
163
164
164 The default (1) means that only one task can be outstanding on each
165 The default (1) means that only one task can be outstanding on each
165 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
166 engine. Setting TaskScheduler.hwm=0 means there is no limit, and the
166 engines continue to be assigned tasks while they are working,
167 engines continue to be assigned tasks while they are working,
167 effectively hiding network latency behind computation, but can result
168 effectively hiding network latency behind computation, but can result
168 in an imbalance of work when submitting many heterogenous tasks all at
169 in an imbalance of work when submitting many heterogenous tasks all at
169 once. Any positive value greater than one is a compromise between the
170 once. Any positive value greater than one is a compromise between the
170 two.
171 two.
171
172
172 """
173 """
173 )
174 )
174 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
175 scheme_name = Enum(('leastload', 'pure', 'lru', 'plainrandom', 'weighted', 'twobin'),
175 'leastload', config=True, allow_none=False,
176 'leastload', config=True, allow_none=False,
176 help="""select the task scheduler scheme [default: Python LRU]
177 help="""select the task scheduler scheme [default: Python LRU]
177 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
178 Options are: 'pure', 'lru', 'plainrandom', 'weighted', 'twobin','leastload'"""
178 )
179 )
179 def _scheme_name_changed(self, old, new):
180 def _scheme_name_changed(self, old, new):
180 self.log.debug("Using scheme %r"%new)
181 self.log.debug("Using scheme %r"%new)
181 self.scheme = globals()[new]
182 self.scheme = globals()[new]
182
183
183 # input arguments:
184 # input arguments:
184 scheme = Instance(FunctionType) # function for determining the destination
185 scheme = Instance(FunctionType) # function for determining the destination
185 def _scheme_default(self):
186 def _scheme_default(self):
186 return leastload
187 return leastload
187 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
188 client_stream = Instance(zmqstream.ZMQStream) # client-facing stream
188 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
189 engine_stream = Instance(zmqstream.ZMQStream) # engine-facing stream
189 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
190 notifier_stream = Instance(zmqstream.ZMQStream) # hub-facing sub stream
190 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
191 mon_stream = Instance(zmqstream.ZMQStream) # hub-facing pub stream
191
192
192 # internals:
193 # internals:
193 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
194 graph = Dict() # dict by msg_id of [ msg_ids that depend on key ]
194 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
195 retries = Dict() # dict by msg_id of retries remaining (non-neg ints)
195 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
196 # waiting = List() # list of msg_ids ready to run, but haven't due to HWM
196 depending = Dict() # dict by msg_id of Jobs
197 depending = Dict() # dict by msg_id of Jobs
197 pending = Dict() # dict by engine_uuid of submitted tasks
198 pending = Dict() # dict by engine_uuid of submitted tasks
198 completed = Dict() # dict by engine_uuid of completed tasks
199 completed = Dict() # dict by engine_uuid of completed tasks
199 failed = Dict() # dict by engine_uuid of failed tasks
200 failed = Dict() # dict by engine_uuid of failed tasks
200 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
201 destinations = Dict() # dict by msg_id of engine_uuids where jobs ran (reverse of completed+failed)
201 clients = Dict() # dict by msg_id for who submitted the task
202 clients = Dict() # dict by msg_id for who submitted the task
202 targets = List() # list of target IDENTs
203 targets = List() # list of target IDENTs
203 loads = List() # list of engine loads
204 loads = List() # list of engine loads
204 # full = Set() # set of IDENTs that have HWM outstanding tasks
205 # full = Set() # set of IDENTs that have HWM outstanding tasks
205 all_completed = Set() # set of all completed tasks
206 all_completed = Set() # set of all completed tasks
206 all_failed = Set() # set of all failed tasks
207 all_failed = Set() # set of all failed tasks
207 all_done = Set() # set of all finished tasks=union(completed,failed)
208 all_done = Set() # set of all finished tasks=union(completed,failed)
208 all_ids = Set() # set of all submitted task IDs
209 all_ids = Set() # set of all submitted task IDs
209
210
210 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
211 auditor = Instance('zmq.eventloop.ioloop.PeriodicCallback')
211
212
212 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 ident = CBytes() # ZMQ identity. This should just be self.session.session
213 # but ensure Bytes
214 # but ensure Bytes
214 def _ident_default(self):
215 def _ident_default(self):
215 return self.session.bsession
216 return self.session.bsession
216
217
217 def start(self):
218 def start(self):
218 self.engine_stream.on_recv(self.dispatch_result, copy=False)
219 self.engine_stream.on_recv(self.dispatch_result, copy=False)
219 self.client_stream.on_recv(self.dispatch_submission, copy=False)
220 self.client_stream.on_recv(self.dispatch_submission, copy=False)
220
221
221 self._notification_handlers = dict(
222 self._notification_handlers = dict(
222 registration_notification = self._register_engine,
223 registration_notification = self._register_engine,
223 unregistration_notification = self._unregister_engine
224 unregistration_notification = self._unregister_engine
224 )
225 )
225 self.notifier_stream.on_recv(self.dispatch_notification)
226 self.notifier_stream.on_recv(self.dispatch_notification)
226 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
227 self.auditor = ioloop.PeriodicCallback(self.audit_timeouts, 2e3, self.loop) # 1 Hz
227 self.auditor.start()
228 self.auditor.start()
228 self.log.info("Scheduler started [%s]"%self.scheme_name)
229 self.log.info("Scheduler started [%s]"%self.scheme_name)
229
230
230 def resume_receiving(self):
231 def resume_receiving(self):
231 """Resume accepting jobs."""
232 """Resume accepting jobs."""
232 self.client_stream.on_recv(self.dispatch_submission, copy=False)
233 self.client_stream.on_recv(self.dispatch_submission, copy=False)
233
234
234 def stop_receiving(self):
235 def stop_receiving(self):
235 """Stop accepting jobs while there are no engines.
236 """Stop accepting jobs while there are no engines.
236 Leave them in the ZMQ queue."""
237 Leave them in the ZMQ queue."""
237 self.client_stream.on_recv(None)
238 self.client_stream.on_recv(None)
238
239
239 #-----------------------------------------------------------------------
240 #-----------------------------------------------------------------------
240 # [Un]Registration Handling
241 # [Un]Registration Handling
241 #-----------------------------------------------------------------------
242 #-----------------------------------------------------------------------
242
243
243
244
244 @util.log_errors
245 @util.log_errors
245 def dispatch_notification(self, msg):
246 def dispatch_notification(self, msg):
246 """dispatch register/unregister events."""
247 """dispatch register/unregister events."""
247 try:
248 try:
248 idents,msg = self.session.feed_identities(msg)
249 idents,msg = self.session.feed_identities(msg)
249 except ValueError:
250 except ValueError:
250 self.log.warn("task::Invalid Message: %r",msg)
251 self.log.warn("task::Invalid Message: %r",msg)
251 return
252 return
252 try:
253 try:
253 msg = self.session.unserialize(msg)
254 msg = self.session.unserialize(msg)
254 except ValueError:
255 except ValueError:
255 self.log.warn("task::Unauthorized message from: %r"%idents)
256 self.log.warn("task::Unauthorized message from: %r"%idents)
256 return
257 return
257
258
258 msg_type = msg['header']['msg_type']
259 msg_type = msg['header']['msg_type']
259
260
260 handler = self._notification_handlers.get(msg_type, None)
261 handler = self._notification_handlers.get(msg_type, None)
261 if handler is None:
262 if handler is None:
262 self.log.error("Unhandled message type: %r"%msg_type)
263 self.log.error("Unhandled message type: %r"%msg_type)
263 else:
264 else:
264 try:
265 try:
265 handler(asbytes(msg['content']['queue']))
266 handler(cast_bytes(msg['content']['queue']))
266 except Exception:
267 except Exception:
267 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
268 self.log.error("task::Invalid notification msg: %r", msg, exc_info=True)
268
269
269 def _register_engine(self, uid):
270 def _register_engine(self, uid):
270 """New engine with ident `uid` became available."""
271 """New engine with ident `uid` became available."""
271 # head of the line:
272 # head of the line:
272 self.targets.insert(0,uid)
273 self.targets.insert(0,uid)
273 self.loads.insert(0,0)
274 self.loads.insert(0,0)
274
275
275 # initialize sets
276 # initialize sets
276 self.completed[uid] = set()
277 self.completed[uid] = set()
277 self.failed[uid] = set()
278 self.failed[uid] = set()
278 self.pending[uid] = {}
279 self.pending[uid] = {}
279
280
280 # rescan the graph:
281 # rescan the graph:
281 self.update_graph(None)
282 self.update_graph(None)
282
283
283 def _unregister_engine(self, uid):
284 def _unregister_engine(self, uid):
284 """Existing engine with ident `uid` became unavailable."""
285 """Existing engine with ident `uid` became unavailable."""
285 if len(self.targets) == 1:
286 if len(self.targets) == 1:
286 # this was our only engine
287 # this was our only engine
287 pass
288 pass
288
289
289 # handle any potentially finished tasks:
290 # handle any potentially finished tasks:
290 self.engine_stream.flush()
291 self.engine_stream.flush()
291
292
292 # don't pop destinations, because they might be used later
293 # don't pop destinations, because they might be used later
293 # map(self.destinations.pop, self.completed.pop(uid))
294 # map(self.destinations.pop, self.completed.pop(uid))
294 # map(self.destinations.pop, self.failed.pop(uid))
295 # map(self.destinations.pop, self.failed.pop(uid))
295
296
296 # prevent this engine from receiving work
297 # prevent this engine from receiving work
297 idx = self.targets.index(uid)
298 idx = self.targets.index(uid)
298 self.targets.pop(idx)
299 self.targets.pop(idx)
299 self.loads.pop(idx)
300 self.loads.pop(idx)
300
301
301 # wait 5 seconds before cleaning up pending jobs, since the results might
302 # wait 5 seconds before cleaning up pending jobs, since the results might
302 # still be incoming
303 # still be incoming
303 if self.pending[uid]:
304 if self.pending[uid]:
304 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
305 dc = ioloop.DelayedCallback(lambda : self.handle_stranded_tasks(uid), 5000, self.loop)
305 dc.start()
306 dc.start()
306 else:
307 else:
307 self.completed.pop(uid)
308 self.completed.pop(uid)
308 self.failed.pop(uid)
309 self.failed.pop(uid)
309
310
310
311
311 def handle_stranded_tasks(self, engine):
312 def handle_stranded_tasks(self, engine):
312 """Deal with jobs resident in an engine that died."""
313 """Deal with jobs resident in an engine that died."""
313 lost = self.pending[engine]
314 lost = self.pending[engine]
314 for msg_id in lost.keys():
315 for msg_id in lost.keys():
315 if msg_id not in self.pending[engine]:
316 if msg_id not in self.pending[engine]:
316 # prevent double-handling of messages
317 # prevent double-handling of messages
317 continue
318 continue
318
319
319 raw_msg = lost[msg_id][0]
320 raw_msg = lost[msg_id].raw_msg
320 idents,msg = self.session.feed_identities(raw_msg, copy=False)
321 idents,msg = self.session.feed_identities(raw_msg, copy=False)
321 parent = self.session.unpack(msg[1].bytes)
322 parent = self.session.unpack(msg[1].bytes)
322 idents = [engine, idents[0]]
323 idents = [engine, idents[0]]
323
324
324 # build fake error reply
325 # build fake error reply
325 try:
326 try:
326 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
327 raise error.EngineError("Engine %r died while running task %r"%(engine, msg_id))
327 except:
328 except:
328 content = error.wrap_exception()
329 content = error.wrap_exception()
329 # build fake header
330 # build fake header
330 header = dict(
331 header = dict(
331 status='error',
332 status='error',
332 engine=engine,
333 engine=engine,
333 date=datetime.now(),
334 date=datetime.now(),
334 )
335 )
335 msg = self.session.msg('apply_reply', content, parent=parent, subheader=header)
336 msg = self.session.msg('apply_reply', content, parent=parent, subheader=header)
336 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
337 raw_reply = map(zmq.Message, self.session.serialize(msg, ident=idents))
337 # and dispatch it
338 # and dispatch it
338 self.dispatch_result(raw_reply)
339 self.dispatch_result(raw_reply)
339
340
340 # finally scrub completed/failed lists
341 # finally scrub completed/failed lists
341 self.completed.pop(engine)
342 self.completed.pop(engine)
342 self.failed.pop(engine)
343 self.failed.pop(engine)
343
344
344
345
345 #-----------------------------------------------------------------------
346 #-----------------------------------------------------------------------
346 # Job Submission
347 # Job Submission
347 #-----------------------------------------------------------------------
348 #-----------------------------------------------------------------------
348
349
349
350
350 @util.log_errors
351 @util.log_errors
351 def dispatch_submission(self, raw_msg):
352 def dispatch_submission(self, raw_msg):
352 """Dispatch job submission to appropriate handlers."""
353 """Dispatch job submission to appropriate handlers."""
353 # ensure targets up to date:
354 # ensure targets up to date:
354 self.notifier_stream.flush()
355 self.notifier_stream.flush()
355 try:
356 try:
356 idents, msg = self.session.feed_identities(raw_msg, copy=False)
357 idents, msg = self.session.feed_identities(raw_msg, copy=False)
357 msg = self.session.unserialize(msg, content=False, copy=False)
358 msg = self.session.unserialize(msg, content=False, copy=False)
358 except Exception:
359 except Exception:
359 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
360 self.log.error("task::Invaid task msg: %r"%raw_msg, exc_info=True)
360 return
361 return
361
362
362
363
363 # send to monitor
364 # send to monitor
364 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
365 self.mon_stream.send_multipart([b'intask']+raw_msg, copy=False)
365
366
366 header = msg['header']
367 header = msg['header']
367 msg_id = header['msg_id']
368 msg_id = header['msg_id']
368 self.all_ids.add(msg_id)
369 self.all_ids.add(msg_id)
369
370
370 # get targets as a set of bytes objects
371 # get targets as a set of bytes objects
371 # from a list of unicode objects
372 # from a list of unicode objects
372 targets = header.get('targets', [])
373 targets = header.get('targets', [])
373 targets = map(asbytes, targets)
374 targets = map(cast_bytes, targets)
374 targets = set(targets)
375 targets = set(targets)
375
376
376 retries = header.get('retries', 0)
377 retries = header.get('retries', 0)
377 self.retries[msg_id] = retries
378 self.retries[msg_id] = retries
378
379
379 # time dependencies
380 # time dependencies
380 after = header.get('after', None)
381 after = header.get('after', None)
381 if after:
382 if after:
382 after = Dependency(after)
383 after = Dependency(after)
383 if after.all:
384 if after.all:
384 if after.success:
385 if after.success:
385 after = Dependency(after.difference(self.all_completed),
386 after = Dependency(after.difference(self.all_completed),
386 success=after.success,
387 success=after.success,
387 failure=after.failure,
388 failure=after.failure,
388 all=after.all,
389 all=after.all,
389 )
390 )
390 if after.failure:
391 if after.failure:
391 after = Dependency(after.difference(self.all_failed),
392 after = Dependency(after.difference(self.all_failed),
392 success=after.success,
393 success=after.success,
393 failure=after.failure,
394 failure=after.failure,
394 all=after.all,
395 all=after.all,
395 )
396 )
396 if after.check(self.all_completed, self.all_failed):
397 if after.check(self.all_completed, self.all_failed):
397 # recast as empty set, if `after` already met,
398 # recast as empty set, if `after` already met,
398 # to prevent unnecessary set comparisons
399 # to prevent unnecessary set comparisons
399 after = MET
400 after = MET
400 else:
401 else:
401 after = MET
402 after = MET
402
403
403 # location dependencies
404 # location dependencies
404 follow = Dependency(header.get('follow', []))
405 follow = Dependency(header.get('follow', []))
405
406
406 # turn timeouts into datetime objects:
407 # turn timeouts into datetime objects:
407 timeout = header.get('timeout', None)
408 timeout = header.get('timeout', None)
408 if timeout:
409 if timeout:
409 # cast to float, because jsonlib returns floats as decimal.Decimal,
410 # cast to float, because jsonlib returns floats as decimal.Decimal,
410 # which timedelta does not accept
411 # which timedelta does not accept
411 timeout = datetime.now() + timedelta(0,float(timeout),0)
412 timeout = datetime.now() + timedelta(0,float(timeout),0)
412
413
413 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
414 job = Job(msg_id=msg_id, raw_msg=raw_msg, idents=idents, msg=msg,
414 header=header, targets=targets, after=after, follow=follow,
415 header=header, targets=targets, after=after, follow=follow,
415 timeout=timeout,
416 timeout=timeout,
416 )
417 )
417
418
418 # validate and reduce dependencies:
419 # validate and reduce dependencies:
419 for dep in after,follow:
420 for dep in after,follow:
420 if not dep: # empty dependency
421 if not dep: # empty dependency
421 continue
422 continue
422 # check valid:
423 # check valid:
423 if msg_id in dep or dep.difference(self.all_ids):
424 if msg_id in dep or dep.difference(self.all_ids):
424 self.depending[msg_id] = job
425 self.depending[msg_id] = job
425 return self.fail_unreachable(msg_id, error.InvalidDependency)
426 return self.fail_unreachable(msg_id, error.InvalidDependency)
426 # check if unreachable:
427 # check if unreachable:
427 if dep.unreachable(self.all_completed, self.all_failed):
428 if dep.unreachable(self.all_completed, self.all_failed):
428 self.depending[msg_id] = job
429 self.depending[msg_id] = job
429 return self.fail_unreachable(msg_id)
430 return self.fail_unreachable(msg_id)
430
431
431 if after.check(self.all_completed, self.all_failed):
432 if after.check(self.all_completed, self.all_failed):
432 # time deps already met, try to run
433 # time deps already met, try to run
433 if not self.maybe_run(job):
434 if not self.maybe_run(job):
434 # can't run yet
435 # can't run yet
435 if msg_id not in self.all_failed:
436 if msg_id not in self.all_failed:
436 # could have failed as unreachable
437 # could have failed as unreachable
437 self.save_unmet(job)
438 self.save_unmet(job)
438 else:
439 else:
439 self.save_unmet(job)
440 self.save_unmet(job)
440
441
441 def audit_timeouts(self):
442 def audit_timeouts(self):
442 """Audit all waiting tasks for expired timeouts."""
443 """Audit all waiting tasks for expired timeouts."""
443 now = datetime.now()
444 now = datetime.now()
444 for msg_id in self.depending.keys():
445 for msg_id in self.depending.keys():
445 # must recheck, in case one failure cascaded to another:
446 # must recheck, in case one failure cascaded to another:
446 if msg_id in self.depending:
447 if msg_id in self.depending:
447 job = self.depending[msg_id]
448 job = self.depending[msg_id]
448 if job.timeout and job.timeout < now:
449 if job.timeout and job.timeout < now:
449 self.fail_unreachable(msg_id, error.TaskTimeout)
450 self.fail_unreachable(msg_id, error.TaskTimeout)
450
451
451 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
452 def fail_unreachable(self, msg_id, why=error.ImpossibleDependency):
452 """a task has become unreachable, send a reply with an ImpossibleDependency
453 """a task has become unreachable, send a reply with an ImpossibleDependency
453 error."""
454 error."""
454 if msg_id not in self.depending:
455 if msg_id not in self.depending:
455 self.log.error("msg %r already failed!", msg_id)
456 self.log.error("msg %r already failed!", msg_id)
456 return
457 return
457 job = self.depending.pop(msg_id)
458 job = self.depending.pop(msg_id)
458 for mid in job.dependents:
459 for mid in job.dependents:
459 if mid in self.graph:
460 if mid in self.graph:
460 self.graph[mid].remove(msg_id)
461 self.graph[mid].remove(msg_id)
461
462
462 try:
463 try:
463 raise why()
464 raise why()
464 except:
465 except:
465 content = error.wrap_exception()
466 content = error.wrap_exception()
466
467
467 self.all_done.add(msg_id)
468 self.all_done.add(msg_id)
468 self.all_failed.add(msg_id)
469 self.all_failed.add(msg_id)
469
470
470 msg = self.session.send(self.client_stream, 'apply_reply', content,
471 msg = self.session.send(self.client_stream, 'apply_reply', content,
471 parent=job.header, ident=job.idents)
472 parent=job.header, ident=job.idents)
472 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
473 self.session.send(self.mon_stream, msg, ident=[b'outtask']+job.idents)
473
474
474 self.update_graph(msg_id, success=False)
475 self.update_graph(msg_id, success=False)
475
476
476 def maybe_run(self, job):
477 def maybe_run(self, job):
477 """check location dependencies, and run if they are met."""
478 """check location dependencies, and run if they are met."""
478 msg_id = job.msg_id
479 msg_id = job.msg_id
479 self.log.debug("Attempting to assign task %s", msg_id)
480 self.log.debug("Attempting to assign task %s", msg_id)
480 if not self.targets:
481 if not self.targets:
481 # no engines, definitely can't run
482 # no engines, definitely can't run
482 return False
483 return False
483
484
484 if job.follow or job.targets or job.blacklist or self.hwm:
485 if job.follow or job.targets or job.blacklist or self.hwm:
485 # we need a can_run filter
486 # we need a can_run filter
486 def can_run(idx):
487 def can_run(idx):
487 # check hwm
488 # check hwm
488 if self.hwm and self.loads[idx] == self.hwm:
489 if self.hwm and self.loads[idx] == self.hwm:
489 return False
490 return False
490 target = self.targets[idx]
491 target = self.targets[idx]
491 # check blacklist
492 # check blacklist
492 if target in job.blacklist:
493 if target in job.blacklist:
493 return False
494 return False
494 # check targets
495 # check targets
495 if job.targets and target not in job.targets:
496 if job.targets and target not in job.targets:
496 return False
497 return False
497 # check follow
498 # check follow
498 return job.follow.check(self.completed[target], self.failed[target])
499 return job.follow.check(self.completed[target], self.failed[target])
499
500
500 indices = filter(can_run, range(len(self.targets)))
501 indices = filter(can_run, range(len(self.targets)))
501
502
502 if not indices:
503 if not indices:
503 # couldn't run
504 # couldn't run
504 if job.follow.all:
505 if job.follow.all:
505 # check follow for impossibility
506 # check follow for impossibility
506 dests = set()
507 dests = set()
507 relevant = set()
508 relevant = set()
508 if job.follow.success:
509 if job.follow.success:
509 relevant = self.all_completed
510 relevant = self.all_completed
510 if job.follow.failure:
511 if job.follow.failure:
511 relevant = relevant.union(self.all_failed)
512 relevant = relevant.union(self.all_failed)
512 for m in job.follow.intersection(relevant):
513 for m in job.follow.intersection(relevant):
513 dests.add(self.destinations[m])
514 dests.add(self.destinations[m])
514 if len(dests) > 1:
515 if len(dests) > 1:
515 self.depending[msg_id] = job
516 self.depending[msg_id] = job
516 self.fail_unreachable(msg_id)
517 self.fail_unreachable(msg_id)
517 return False
518 return False
518 if job.targets:
519 if job.targets:
519 # check blacklist+targets for impossibility
520 # check blacklist+targets for impossibility
520 job.targets.difference_update(job.blacklist)
521 job.targets.difference_update(job.blacklist)
521 if not job.targets or not job.targets.intersection(self.targets):
522 if not job.targets or not job.targets.intersection(self.targets):
522 self.depending[msg_id] = job
523 self.depending[msg_id] = job
523 self.fail_unreachable(msg_id)
524 self.fail_unreachable(msg_id)
524 return False
525 return False
525 return False
526 return False
526 else:
527 else:
527 indices = None
528 indices = None
528
529
529 self.submit_task(job, indices)
530 self.submit_task(job, indices)
530 return True
531 return True
531
532
532 def save_unmet(self, job):
533 def save_unmet(self, job):
533 """Save a message for later submission when its dependencies are met."""
534 """Save a message for later submission when its dependencies are met."""
534 msg_id = job.msg_id
535 msg_id = job.msg_id
535 self.depending[msg_id] = job
536 self.depending[msg_id] = job
536 # track the ids in follow or after, but not those already finished
537 # track the ids in follow or after, but not those already finished
537 for dep_id in job.after.union(job.follow).difference(self.all_done):
538 for dep_id in job.after.union(job.follow).difference(self.all_done):
538 if dep_id not in self.graph:
539 if dep_id not in self.graph:
539 self.graph[dep_id] = set()
540 self.graph[dep_id] = set()
540 self.graph[dep_id].add(msg_id)
541 self.graph[dep_id].add(msg_id)
541
542
542 def submit_task(self, job, indices=None):
543 def submit_task(self, job, indices=None):
543 """Submit a task to any of a subset of our targets."""
544 """Submit a task to any of a subset of our targets."""
544 if indices:
545 if indices:
545 loads = [self.loads[i] for i in indices]
546 loads = [self.loads[i] for i in indices]
546 else:
547 else:
547 loads = self.loads
548 loads = self.loads
548 idx = self.scheme(loads)
549 idx = self.scheme(loads)
549 if indices:
550 if indices:
550 idx = indices[idx]
551 idx = indices[idx]
551 target = self.targets[idx]
552 target = self.targets[idx]
552 # print (target, map(str, msg[:3]))
553 # print (target, map(str, msg[:3]))
553 # send job to the engine
554 # send job to the engine
554 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
555 self.engine_stream.send(target, flags=zmq.SNDMORE, copy=False)
555 self.engine_stream.send_multipart(job.raw_msg, copy=False)
556 self.engine_stream.send_multipart(job.raw_msg, copy=False)
556 # update load
557 # update load
557 self.add_job(idx)
558 self.add_job(idx)
558 self.pending[target][job.msg_id] = job
559 self.pending[target][job.msg_id] = job
559 # notify Hub
560 # notify Hub
560 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
561 content = dict(msg_id=job.msg_id, engine_id=target.decode('ascii'))
561 self.session.send(self.mon_stream, 'task_destination', content=content,
562 self.session.send(self.mon_stream, 'task_destination', content=content,
562 ident=[b'tracktask',self.ident])
563 ident=[b'tracktask',self.ident])
563
564
564
565
565 #-----------------------------------------------------------------------
566 #-----------------------------------------------------------------------
566 # Result Handling
567 # Result Handling
567 #-----------------------------------------------------------------------
568 #-----------------------------------------------------------------------
568
569
569
570
570 @util.log_errors
571 @util.log_errors
571 def dispatch_result(self, raw_msg):
572 def dispatch_result(self, raw_msg):
572 """dispatch method for result replies"""
573 """dispatch method for result replies"""
573 try:
574 try:
574 idents,msg = self.session.feed_identities(raw_msg, copy=False)
575 idents,msg = self.session.feed_identities(raw_msg, copy=False)
575 msg = self.session.unserialize(msg, content=False, copy=False)
576 msg = self.session.unserialize(msg, content=False, copy=False)
576 engine = idents[0]
577 engine = idents[0]
577 try:
578 try:
578 idx = self.targets.index(engine)
579 idx = self.targets.index(engine)
579 except ValueError:
580 except ValueError:
580 pass # skip load-update for dead engines
581 pass # skip load-update for dead engines
581 else:
582 else:
582 self.finish_job(idx)
583 self.finish_job(idx)
583 except Exception:
584 except Exception:
584 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
585 self.log.error("task::Invaid result: %r", raw_msg, exc_info=True)
585 return
586 return
586
587
587 header = msg['header']
588 header = msg['header']
588 parent = msg['parent_header']
589 parent = msg['parent_header']
589 if header.get('dependencies_met', True):
590 if header.get('dependencies_met', True):
590 success = (header['status'] == 'ok')
591 success = (header['status'] == 'ok')
591 msg_id = parent['msg_id']
592 msg_id = parent['msg_id']
592 retries = self.retries[msg_id]
593 retries = self.retries[msg_id]
593 if not success and retries > 0:
594 if not success and retries > 0:
594 # failed
595 # failed
595 self.retries[msg_id] = retries - 1
596 self.retries[msg_id] = retries - 1
596 self.handle_unmet_dependency(idents, parent)
597 self.handle_unmet_dependency(idents, parent)
597 else:
598 else:
598 del self.retries[msg_id]
599 del self.retries[msg_id]
599 # relay to client and update graph
600 # relay to client and update graph
600 self.handle_result(idents, parent, raw_msg, success)
601 self.handle_result(idents, parent, raw_msg, success)
601 # send to Hub monitor
602 # send to Hub monitor
602 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
603 self.mon_stream.send_multipart([b'outtask']+raw_msg, copy=False)
603 else:
604 else:
604 self.handle_unmet_dependency(idents, parent)
605 self.handle_unmet_dependency(idents, parent)
605
606
606 def handle_result(self, idents, parent, raw_msg, success=True):
607 def handle_result(self, idents, parent, raw_msg, success=True):
607 """handle a real task result, either success or failure"""
608 """handle a real task result, either success or failure"""
608 # first, relay result to client
609 # first, relay result to client
609 engine = idents[0]
610 engine = idents[0]
610 client = idents[1]
611 client = idents[1]
611 # swap_ids for XREP-XREP mirror
612 # swap_ids for XREP-XREP mirror
612 raw_msg[:2] = [client,engine]
613 raw_msg[:2] = [client,engine]
613 # print (map(str, raw_msg[:4]))
614 # print (map(str, raw_msg[:4]))
614 self.client_stream.send_multipart(raw_msg, copy=False)
615 self.client_stream.send_multipart(raw_msg, copy=False)
615 # now, update our data structures
616 # now, update our data structures
616 msg_id = parent['msg_id']
617 msg_id = parent['msg_id']
617 self.pending[engine].pop(msg_id)
618 self.pending[engine].pop(msg_id)
618 if success:
619 if success:
619 self.completed[engine].add(msg_id)
620 self.completed[engine].add(msg_id)
620 self.all_completed.add(msg_id)
621 self.all_completed.add(msg_id)
621 else:
622 else:
622 self.failed[engine].add(msg_id)
623 self.failed[engine].add(msg_id)
623 self.all_failed.add(msg_id)
624 self.all_failed.add(msg_id)
624 self.all_done.add(msg_id)
625 self.all_done.add(msg_id)
625 self.destinations[msg_id] = engine
626 self.destinations[msg_id] = engine
626
627
627 self.update_graph(msg_id, success)
628 self.update_graph(msg_id, success)
628
629
629 def handle_unmet_dependency(self, idents, parent):
630 def handle_unmet_dependency(self, idents, parent):
630 """handle an unmet dependency"""
631 """handle an unmet dependency"""
631 engine = idents[0]
632 engine = idents[0]
632 msg_id = parent['msg_id']
633 msg_id = parent['msg_id']
633
634
634 job = self.pending[engine].pop(msg_id)
635 job = self.pending[engine].pop(msg_id)
635 job.blacklist.add(engine)
636 job.blacklist.add(engine)
636
637
637 if job.blacklist == job.targets:
638 if job.blacklist == job.targets:
638 self.depending[msg_id] = job
639 self.depending[msg_id] = job
639 self.fail_unreachable(msg_id)
640 self.fail_unreachable(msg_id)
640 elif not self.maybe_run(job):
641 elif not self.maybe_run(job):
641 # resubmit failed
642 # resubmit failed
642 if msg_id not in self.all_failed:
643 if msg_id not in self.all_failed:
643 # put it back in our dependency tree
644 # put it back in our dependency tree
644 self.save_unmet(job)
645 self.save_unmet(job)
645
646
646 if self.hwm:
647 if self.hwm:
647 try:
648 try:
648 idx = self.targets.index(engine)
649 idx = self.targets.index(engine)
649 except ValueError:
650 except ValueError:
650 pass # skip load-update for dead engines
651 pass # skip load-update for dead engines
651 else:
652 else:
652 if self.loads[idx] == self.hwm-1:
653 if self.loads[idx] == self.hwm-1:
653 self.update_graph(None)
654 self.update_graph(None)
654
655
655
656
656
657
657 def update_graph(self, dep_id=None, success=True):
658 def update_graph(self, dep_id=None, success=True):
658 """dep_id just finished. Update our dependency
659 """dep_id just finished. Update our dependency
659 graph and submit any jobs that just became runable.
660 graph and submit any jobs that just became runable.
660
661
661 Called with dep_id=None to update entire graph for hwm, but without finishing
662 Called with dep_id=None to update entire graph for hwm, but without finishing
662 a task.
663 a task.
663 """
664 """
664 # print ("\n\n***********")
665 # print ("\n\n***********")
665 # pprint (dep_id)
666 # pprint (dep_id)
666 # pprint (self.graph)
667 # pprint (self.graph)
667 # pprint (self.depending)
668 # pprint (self.depending)
668 # pprint (self.all_completed)
669 # pprint (self.all_completed)
669 # pprint (self.all_failed)
670 # pprint (self.all_failed)
670 # print ("\n\n***********\n\n")
671 # print ("\n\n***********\n\n")
671 # update any jobs that depended on the dependency
672 # update any jobs that depended on the dependency
672 jobs = self.graph.pop(dep_id, [])
673 jobs = self.graph.pop(dep_id, [])
673
674
674 # recheck *all* jobs if
675 # recheck *all* jobs if
675 # a) we have HWM and an engine just become no longer full
676 # a) we have HWM and an engine just become no longer full
676 # or b) dep_id was given as None
677 # or b) dep_id was given as None
677
678
678 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
679 if dep_id is None or self.hwm and any( [ load==self.hwm-1 for load in self.loads ]):
679 jobs = self.depending.keys()
680 jobs = self.depending.keys()
680
681
681 for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
682 for msg_id in sorted(jobs, key=lambda msg_id: self.depending[msg_id].timestamp):
682 job = self.depending[msg_id]
683 job = self.depending[msg_id]
683
684
684 if job.after.unreachable(self.all_completed, self.all_failed)\
685 if job.after.unreachable(self.all_completed, self.all_failed)\
685 or job.follow.unreachable(self.all_completed, self.all_failed):
686 or job.follow.unreachable(self.all_completed, self.all_failed):
686 self.fail_unreachable(msg_id)
687 self.fail_unreachable(msg_id)
687
688
688 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
689 elif job.after.check(self.all_completed, self.all_failed): # time deps met, maybe run
689 if self.maybe_run(job):
690 if self.maybe_run(job):
690
691
691 self.depending.pop(msg_id)
692 self.depending.pop(msg_id)
692 for mid in job.dependents:
693 for mid in job.dependents:
693 if mid in self.graph:
694 if mid in self.graph:
694 self.graph[mid].remove(msg_id)
695 self.graph[mid].remove(msg_id)
695
696
696 #----------------------------------------------------------------------
697 #----------------------------------------------------------------------
697 # methods to be overridden by subclasses
698 # methods to be overridden by subclasses
698 #----------------------------------------------------------------------
699 #----------------------------------------------------------------------
699
700
700 def add_job(self, idx):
701 def add_job(self, idx):
701 """Called after self.targets[idx] just got the job with header.
702 """Called after self.targets[idx] just got the job with header.
702 Override with subclasses. The default ordering is simple LRU.
703 Override with subclasses. The default ordering is simple LRU.
703 The default loads are the number of outstanding jobs."""
704 The default loads are the number of outstanding jobs."""
704 self.loads[idx] += 1
705 self.loads[idx] += 1
705 for lis in (self.targets, self.loads):
706 for lis in (self.targets, self.loads):
706 lis.append(lis.pop(idx))
707 lis.append(lis.pop(idx))
707
708
708
709
709 def finish_job(self, idx):
710 def finish_job(self, idx):
710 """Called after self.targets[idx] just finished a job.
711 """Called after self.targets[idx] just finished a job.
711 Override with subclasses."""
712 Override with subclasses."""
712 self.loads[idx] -= 1
713 self.loads[idx] -= 1
713
714
714
715
715
716
716 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
717 def launch_scheduler(in_addr, out_addr, mon_addr, not_addr, config=None,
717 logname='root', log_url=None, loglevel=logging.DEBUG,
718 logname='root', log_url=None, loglevel=logging.DEBUG,
718 identity=b'task', in_thread=False):
719 identity=b'task', in_thread=False):
719
720
720 ZMQStream = zmqstream.ZMQStream
721 ZMQStream = zmqstream.ZMQStream
721
722
722 if config:
723 if config:
723 # unwrap dict back into Config
724 # unwrap dict back into Config
724 config = Config(config)
725 config = Config(config)
725
726
726 if in_thread:
727 if in_thread:
727 # use instance() to get the same Context/Loop as our parent
728 # use instance() to get the same Context/Loop as our parent
728 ctx = zmq.Context.instance()
729 ctx = zmq.Context.instance()
729 loop = ioloop.IOLoop.instance()
730 loop = ioloop.IOLoop.instance()
730 else:
731 else:
731 # in a process, don't use instance()
732 # in a process, don't use instance()
732 # for safety with multiprocessing
733 # for safety with multiprocessing
733 ctx = zmq.Context()
734 ctx = zmq.Context()
734 loop = ioloop.IOLoop()
735 loop = ioloop.IOLoop()
735 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
736 ins = ZMQStream(ctx.socket(zmq.ROUTER),loop)
736 ins.setsockopt(zmq.IDENTITY, identity)
737 ins.setsockopt(zmq.IDENTITY, identity)
737 ins.bind(in_addr)
738 ins.bind(in_addr)
738
739
739 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
740 outs = ZMQStream(ctx.socket(zmq.ROUTER),loop)
740 outs.setsockopt(zmq.IDENTITY, identity)
741 outs.setsockopt(zmq.IDENTITY, identity)
741 outs.bind(out_addr)
742 outs.bind(out_addr)
742 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
743 mons = zmqstream.ZMQStream(ctx.socket(zmq.PUB),loop)
743 mons.connect(mon_addr)
744 mons.connect(mon_addr)
744 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
745 nots = zmqstream.ZMQStream(ctx.socket(zmq.SUB),loop)
745 nots.setsockopt(zmq.SUBSCRIBE, b'')
746 nots.setsockopt(zmq.SUBSCRIBE, b'')
746 nots.connect(not_addr)
747 nots.connect(not_addr)
747
748
748 # setup logging.
749 # setup logging.
749 if in_thread:
750 if in_thread:
750 log = Application.instance().log
751 log = Application.instance().log
751 else:
752 else:
752 if log_url:
753 if log_url:
753 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
754 log = connect_logger(logname, ctx, log_url, root="scheduler", loglevel=loglevel)
754 else:
755 else:
755 log = local_logger(logname, loglevel)
756 log = local_logger(logname, loglevel)
756
757
757 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
758 scheduler = TaskScheduler(client_stream=ins, engine_stream=outs,
758 mon_stream=mons, notifier_stream=nots,
759 mon_stream=mons, notifier_stream=nots,
759 loop=loop, log=log,
760 loop=loop, log=log,
760 config=config)
761 config=config)
761 scheduler.start()
762 scheduler.start()
762 if not in_thread:
763 if not in_thread:
763 try:
764 try:
764 loop.start()
765 loop.start()
765 except KeyboardInterrupt:
766 except KeyboardInterrupt:
766 scheduler.log.critical("Interrupted, exiting...")
767 scheduler.log.critical("Interrupted, exiting...")
767
768
@@ -1,411 +1,412 b''
1 """A TaskRecord backend using sqlite3
1 """A TaskRecord backend using sqlite3
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 import json
14 import json
15 import os
15 import os
16 import cPickle as pickle
16 import cPickle as pickle
17 from datetime import datetime
17 from datetime import datetime
18
18
19 try:
19 try:
20 import sqlite3
20 import sqlite3
21 except ImportError:
21 except ImportError:
22 sqlite3 = None
22 sqlite3 = None
23
23
24 from zmq.eventloop import ioloop
24 from zmq.eventloop import ioloop
25
25
26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
26 from IPython.utils.traitlets import Unicode, Instance, List, Dict
27 from .dictdb import BaseDB
27 from .dictdb import BaseDB
28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
28 from IPython.utils.jsonutil import date_default, extract_dates, squash_dates
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # SQLite operators, adapters, and converters
31 # SQLite operators, adapters, and converters
32 #-----------------------------------------------------------------------------
32 #-----------------------------------------------------------------------------
33
33
34 try:
34 try:
35 buffer
35 buffer
36 except NameError:
36 except NameError:
37 # py3k
37 # py3k
38 buffer = memoryview
38 buffer = memoryview
39
39
40 operators = {
40 operators = {
41 '$lt' : "<",
41 '$lt' : "<",
42 '$gt' : ">",
42 '$gt' : ">",
43 # null is handled weird with ==,!=
43 # null is handled weird with ==,!=
44 '$eq' : "=",
44 '$eq' : "=",
45 '$ne' : "!=",
45 '$ne' : "!=",
46 '$lte': "<=",
46 '$lte': "<=",
47 '$gte': ">=",
47 '$gte': ">=",
48 '$in' : ('=', ' OR '),
48 '$in' : ('=', ' OR '),
49 '$nin': ('!=', ' AND '),
49 '$nin': ('!=', ' AND '),
50 # '$all': None,
50 # '$all': None,
51 # '$mod': None,
51 # '$mod': None,
52 # '$exists' : None
52 # '$exists' : None
53 }
53 }
54 null_operators = {
54 null_operators = {
55 '=' : "IS NULL",
55 '=' : "IS NULL",
56 '!=' : "IS NOT NULL",
56 '!=' : "IS NOT NULL",
57 }
57 }
58
58
59 def _adapt_dict(d):
59 def _adapt_dict(d):
60 return json.dumps(d, default=date_default)
60 return json.dumps(d, default=date_default)
61
61
62 def _convert_dict(ds):
62 def _convert_dict(ds):
63 if ds is None:
63 if ds is None:
64 return ds
64 return ds
65 else:
65 else:
66 if isinstance(ds, bytes):
66 if isinstance(ds, bytes):
67 # If I understand the sqlite doc correctly, this will always be utf8
67 # If I understand the sqlite doc correctly, this will always be utf8
68 ds = ds.decode('utf8')
68 ds = ds.decode('utf8')
69 return extract_dates(json.loads(ds))
69 return extract_dates(json.loads(ds))
70
70
71 def _adapt_bufs(bufs):
71 def _adapt_bufs(bufs):
72 # this is *horrible*
72 # this is *horrible*
73 # copy buffers into single list and pickle it:
73 # copy buffers into single list and pickle it:
74 if bufs and isinstance(bufs[0], (bytes, buffer)):
74 if bufs and isinstance(bufs[0], (bytes, buffer)):
75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
75 return sqlite3.Binary(pickle.dumps(map(bytes, bufs),-1))
76 elif bufs:
76 elif bufs:
77 return bufs
77 return bufs
78 else:
78 else:
79 return None
79 return None
80
80
81 def _convert_bufs(bs):
81 def _convert_bufs(bs):
82 if bs is None:
82 if bs is None:
83 return []
83 return []
84 else:
84 else:
85 return pickle.loads(bytes(bs))
85 return pickle.loads(bytes(bs))
86
86
87 #-----------------------------------------------------------------------------
87 #-----------------------------------------------------------------------------
88 # SQLiteDB class
88 # SQLiteDB class
89 #-----------------------------------------------------------------------------
89 #-----------------------------------------------------------------------------
90
90
91 class SQLiteDB(BaseDB):
91 class SQLiteDB(BaseDB):
92 """SQLite3 TaskRecord backend."""
92 """SQLite3 TaskRecord backend."""
93
93
94 filename = Unicode('tasks.db', config=True,
94 filename = Unicode('tasks.db', config=True,
95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
95 help="""The filename of the sqlite task database. [default: 'tasks.db']""")
96 location = Unicode('', config=True,
96 location = Unicode('', config=True,
97 help="""The directory containing the sqlite task database. The default
97 help="""The directory containing the sqlite task database. The default
98 is to use the cluster_dir location.""")
98 is to use the cluster_dir location.""")
99 table = Unicode("", config=True,
99 table = Unicode("", config=True,
100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
100 help="""The SQLite Table to use for storing tasks for this session. If unspecified,
101 a new table will be created with the Hub's IDENT. Specifying the table will result
101 a new table will be created with the Hub's IDENT. Specifying the table will result
102 in tasks from previous sessions being available via Clients' db_query and
102 in tasks from previous sessions being available via Clients' db_query and
103 get_result methods.""")
103 get_result methods.""")
104
104
105 if sqlite3 is not None:
105 if sqlite3 is not None:
106 _db = Instance('sqlite3.Connection')
106 _db = Instance('sqlite3.Connection')
107 else:
107 else:
108 _db = None
108 _db = None
109 # the ordered list of column names
109 # the ordered list of column names
110 _keys = List(['msg_id' ,
110 _keys = List(['msg_id' ,
111 'header' ,
111 'header' ,
112 'content',
112 'content',
113 'buffers',
113 'buffers',
114 'submitted',
114 'submitted',
115 'client_uuid' ,
115 'client_uuid' ,
116 'engine_uuid' ,
116 'engine_uuid' ,
117 'started',
117 'started',
118 'completed',
118 'completed',
119 'resubmitted',
119 'resubmitted',
120 'received',
120 'received',
121 'result_header' ,
121 'result_header' ,
122 'result_content' ,
122 'result_content' ,
123 'result_buffers' ,
123 'result_buffers' ,
124 'queue' ,
124 'queue' ,
125 'pyin' ,
125 'pyin' ,
126 'pyout',
126 'pyout',
127 'pyerr',
127 'pyerr',
128 'stdout',
128 'stdout',
129 'stderr',
129 'stderr',
130 ])
130 ])
131 # sqlite datatypes for checking that db is current format
131 # sqlite datatypes for checking that db is current format
132 _types = Dict({'msg_id' : 'text' ,
132 _types = Dict({'msg_id' : 'text' ,
133 'header' : 'dict text',
133 'header' : 'dict text',
134 'content' : 'dict text',
134 'content' : 'dict text',
135 'buffers' : 'bufs blob',
135 'buffers' : 'bufs blob',
136 'submitted' : 'timestamp',
136 'submitted' : 'timestamp',
137 'client_uuid' : 'text',
137 'client_uuid' : 'text',
138 'engine_uuid' : 'text',
138 'engine_uuid' : 'text',
139 'started' : 'timestamp',
139 'started' : 'timestamp',
140 'completed' : 'timestamp',
140 'completed' : 'timestamp',
141 'resubmitted' : 'timestamp',
141 'resubmitted' : 'text',
142 'received' : 'timestamp',
142 'received' : 'timestamp',
143 'result_header' : 'dict text',
143 'result_header' : 'dict text',
144 'result_content' : 'dict text',
144 'result_content' : 'dict text',
145 'result_buffers' : 'bufs blob',
145 'result_buffers' : 'bufs blob',
146 'queue' : 'text',
146 'queue' : 'text',
147 'pyin' : 'text',
147 'pyin' : 'text',
148 'pyout' : 'text',
148 'pyout' : 'text',
149 'pyerr' : 'text',
149 'pyerr' : 'text',
150 'stdout' : 'text',
150 'stdout' : 'text',
151 'stderr' : 'text',
151 'stderr' : 'text',
152 })
152 })
153
153
154 def __init__(self, **kwargs):
154 def __init__(self, **kwargs):
155 super(SQLiteDB, self).__init__(**kwargs)
155 super(SQLiteDB, self).__init__(**kwargs)
156 if sqlite3 is None:
156 if sqlite3 is None:
157 raise ImportError("SQLiteDB requires sqlite3")
157 raise ImportError("SQLiteDB requires sqlite3")
158 if not self.table:
158 if not self.table:
159 # use session, and prefix _, since starting with # is illegal
159 # use session, and prefix _, since starting with # is illegal
160 self.table = '_'+self.session.replace('-','_')
160 self.table = '_'+self.session.replace('-','_')
161 if not self.location:
161 if not self.location:
162 # get current profile
162 # get current profile
163 from IPython.core.application import BaseIPythonApplication
163 from IPython.core.application import BaseIPythonApplication
164 if BaseIPythonApplication.initialized():
164 if BaseIPythonApplication.initialized():
165 app = BaseIPythonApplication.instance()
165 app = BaseIPythonApplication.instance()
166 if app.profile_dir is not None:
166 if app.profile_dir is not None:
167 self.location = app.profile_dir.location
167 self.location = app.profile_dir.location
168 else:
168 else:
169 self.location = u'.'
169 self.location = u'.'
170 else:
170 else:
171 self.location = u'.'
171 self.location = u'.'
172 self._init_db()
172 self._init_db()
173
173
174 # register db commit as 2s periodic callback
174 # register db commit as 2s periodic callback
175 # to prevent clogging pipes
175 # to prevent clogging pipes
176 # assumes we are being run in a zmq ioloop app
176 # assumes we are being run in a zmq ioloop app
177 loop = ioloop.IOLoop.instance()
177 loop = ioloop.IOLoop.instance()
178 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
178 pc = ioloop.PeriodicCallback(self._db.commit, 2000, loop)
179 pc.start()
179 pc.start()
180
180
181 def _defaults(self, keys=None):
181 def _defaults(self, keys=None):
182 """create an empty record"""
182 """create an empty record"""
183 d = {}
183 d = {}
184 keys = self._keys if keys is None else keys
184 keys = self._keys if keys is None else keys
185 for key in keys:
185 for key in keys:
186 d[key] = None
186 d[key] = None
187 return d
187 return d
188
188
189 def _check_table(self):
189 def _check_table(self):
190 """Ensure that an incorrect table doesn't exist
190 """Ensure that an incorrect table doesn't exist
191
191
192 If a bad (old) table does exist, return False
192 If a bad (old) table does exist, return False
193 """
193 """
194 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
194 cursor = self._db.execute("PRAGMA table_info(%s)"%self.table)
195 lines = cursor.fetchall()
195 lines = cursor.fetchall()
196 if not lines:
196 if not lines:
197 # table does not exist
197 # table does not exist
198 return True
198 return True
199 types = {}
199 types = {}
200 keys = []
200 keys = []
201 for line in lines:
201 for line in lines:
202 keys.append(line[1])
202 keys.append(line[1])
203 types[line[1]] = line[2]
203 types[line[1]] = line[2]
204 if self._keys != keys:
204 if self._keys != keys:
205 # key mismatch
205 # key mismatch
206 self.log.warn('keys mismatch')
206 self.log.warn('keys mismatch')
207 return False
207 return False
208 for key in self._keys:
208 for key in self._keys:
209 if types[key] != self._types[key]:
209 if types[key] != self._types[key]:
210 self.log.warn(
210 self.log.warn(
211 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
211 'type mismatch: %s: %s != %s'%(key,types[key],self._types[key])
212 )
212 )
213 return False
213 return False
214 return True
214 return True
215
215
216 def _init_db(self):
216 def _init_db(self):
217 """Connect to the database and get new session number."""
217 """Connect to the database and get new session number."""
218 # register adapters
218 # register adapters
219 sqlite3.register_adapter(dict, _adapt_dict)
219 sqlite3.register_adapter(dict, _adapt_dict)
220 sqlite3.register_converter('dict', _convert_dict)
220 sqlite3.register_converter('dict', _convert_dict)
221 sqlite3.register_adapter(list, _adapt_bufs)
221 sqlite3.register_adapter(list, _adapt_bufs)
222 sqlite3.register_converter('bufs', _convert_bufs)
222 sqlite3.register_converter('bufs', _convert_bufs)
223 # connect to the db
223 # connect to the db
224 dbfile = os.path.join(self.location, self.filename)
224 dbfile = os.path.join(self.location, self.filename)
225 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
225 self._db = sqlite3.connect(dbfile, detect_types=sqlite3.PARSE_DECLTYPES,
226 # isolation_level = None)#,
226 # isolation_level = None)#,
227 cached_statements=64)
227 cached_statements=64)
228 # print dir(self._db)
228 # print dir(self._db)
229 first_table = self.table
229 first_table = previous_table = self.table
230 i=0
230 i=0
231 while not self._check_table():
231 while not self._check_table():
232 i+=1
232 i+=1
233 self.table = first_table+'_%i'%i
233 self.table = first_table+'_%i'%i
234 self.log.warn(
234 self.log.warn(
235 "Table %s exists and doesn't match db format, trying %s"%
235 "Table %s exists and doesn't match db format, trying %s"%
236 (first_table,self.table)
236 (previous_table, self.table)
237 )
237 )
238 previous_table = self.table
238
239
239 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
240 self._db.execute("""CREATE TABLE IF NOT EXISTS %s
240 (msg_id text PRIMARY KEY,
241 (msg_id text PRIMARY KEY,
241 header dict text,
242 header dict text,
242 content dict text,
243 content dict text,
243 buffers bufs blob,
244 buffers bufs blob,
244 submitted timestamp,
245 submitted timestamp,
245 client_uuid text,
246 client_uuid text,
246 engine_uuid text,
247 engine_uuid text,
247 started timestamp,
248 started timestamp,
248 completed timestamp,
249 completed timestamp,
249 resubmitted timestamp,
250 resubmitted text,
250 received timestamp,
251 received timestamp,
251 result_header dict text,
252 result_header dict text,
252 result_content dict text,
253 result_content dict text,
253 result_buffers bufs blob,
254 result_buffers bufs blob,
254 queue text,
255 queue text,
255 pyin text,
256 pyin text,
256 pyout text,
257 pyout text,
257 pyerr text,
258 pyerr text,
258 stdout text,
259 stdout text,
259 stderr text)
260 stderr text)
260 """%self.table)
261 """%self.table)
261 self._db.commit()
262 self._db.commit()
262
263
263 def _dict_to_list(self, d):
264 def _dict_to_list(self, d):
264 """turn a mongodb-style record dict into a list."""
265 """turn a mongodb-style record dict into a list."""
265
266
266 return [ d[key] for key in self._keys ]
267 return [ d[key] for key in self._keys ]
267
268
268 def _list_to_dict(self, line, keys=None):
269 def _list_to_dict(self, line, keys=None):
269 """Inverse of dict_to_list"""
270 """Inverse of dict_to_list"""
270 keys = self._keys if keys is None else keys
271 keys = self._keys if keys is None else keys
271 d = self._defaults(keys)
272 d = self._defaults(keys)
272 for key,value in zip(keys, line):
273 for key,value in zip(keys, line):
273 d[key] = value
274 d[key] = value
274
275
275 return d
276 return d
276
277
277 def _render_expression(self, check):
278 def _render_expression(self, check):
278 """Turn a mongodb-style search dict into an SQL query."""
279 """Turn a mongodb-style search dict into an SQL query."""
279 expressions = []
280 expressions = []
280 args = []
281 args = []
281
282
282 skeys = set(check.keys())
283 skeys = set(check.keys())
283 skeys.difference_update(set(self._keys))
284 skeys.difference_update(set(self._keys))
284 skeys.difference_update(set(['buffers', 'result_buffers']))
285 skeys.difference_update(set(['buffers', 'result_buffers']))
285 if skeys:
286 if skeys:
286 raise KeyError("Illegal testing key(s): %s"%skeys)
287 raise KeyError("Illegal testing key(s): %s"%skeys)
287
288
288 for name,sub_check in check.iteritems():
289 for name,sub_check in check.iteritems():
289 if isinstance(sub_check, dict):
290 if isinstance(sub_check, dict):
290 for test,value in sub_check.iteritems():
291 for test,value in sub_check.iteritems():
291 try:
292 try:
292 op = operators[test]
293 op = operators[test]
293 except KeyError:
294 except KeyError:
294 raise KeyError("Unsupported operator: %r"%test)
295 raise KeyError("Unsupported operator: %r"%test)
295 if isinstance(op, tuple):
296 if isinstance(op, tuple):
296 op, join = op
297 op, join = op
297
298
298 if value is None and op in null_operators:
299 if value is None and op in null_operators:
299 expr = "%s %s" % (name, null_operators[op])
300 expr = "%s %s" % (name, null_operators[op])
300 else:
301 else:
301 expr = "%s %s ?"%(name, op)
302 expr = "%s %s ?"%(name, op)
302 if isinstance(value, (tuple,list)):
303 if isinstance(value, (tuple,list)):
303 if op in null_operators and any([v is None for v in value]):
304 if op in null_operators and any([v is None for v in value]):
304 # equality tests don't work with NULL
305 # equality tests don't work with NULL
305 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
306 raise ValueError("Cannot use %r test with NULL values on SQLite backend"%test)
306 expr = '( %s )'%( join.join([expr]*len(value)) )
307 expr = '( %s )'%( join.join([expr]*len(value)) )
307 args.extend(value)
308 args.extend(value)
308 else:
309 else:
309 args.append(value)
310 args.append(value)
310 expressions.append(expr)
311 expressions.append(expr)
311 else:
312 else:
312 # it's an equality check
313 # it's an equality check
313 if sub_check is None:
314 if sub_check is None:
314 expressions.append("%s IS NULL" % name)
315 expressions.append("%s IS NULL" % name)
315 else:
316 else:
316 expressions.append("%s = ?"%name)
317 expressions.append("%s = ?"%name)
317 args.append(sub_check)
318 args.append(sub_check)
318
319
319 expr = " AND ".join(expressions)
320 expr = " AND ".join(expressions)
320 return expr, args
321 return expr, args
321
322
322 def add_record(self, msg_id, rec):
323 def add_record(self, msg_id, rec):
323 """Add a new Task Record, by msg_id."""
324 """Add a new Task Record, by msg_id."""
324 d = self._defaults()
325 d = self._defaults()
325 d.update(rec)
326 d.update(rec)
326 d['msg_id'] = msg_id
327 d['msg_id'] = msg_id
327 line = self._dict_to_list(d)
328 line = self._dict_to_list(d)
328 tups = '(%s)'%(','.join(['?']*len(line)))
329 tups = '(%s)'%(','.join(['?']*len(line)))
329 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
330 self._db.execute("INSERT INTO %s VALUES %s"%(self.table, tups), line)
330 # self._db.commit()
331 # self._db.commit()
331
332
332 def get_record(self, msg_id):
333 def get_record(self, msg_id):
333 """Get a specific Task Record, by msg_id."""
334 """Get a specific Task Record, by msg_id."""
334 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
335 cursor = self._db.execute("""SELECT * FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
335 line = cursor.fetchone()
336 line = cursor.fetchone()
336 if line is None:
337 if line is None:
337 raise KeyError("No such msg: %r"%msg_id)
338 raise KeyError("No such msg: %r"%msg_id)
338 return self._list_to_dict(line)
339 return self._list_to_dict(line)
339
340
340 def update_record(self, msg_id, rec):
341 def update_record(self, msg_id, rec):
341 """Update the data in an existing record."""
342 """Update the data in an existing record."""
342 query = "UPDATE %s SET "%self.table
343 query = "UPDATE %s SET "%self.table
343 sets = []
344 sets = []
344 keys = sorted(rec.keys())
345 keys = sorted(rec.keys())
345 values = []
346 values = []
346 for key in keys:
347 for key in keys:
347 sets.append('%s = ?'%key)
348 sets.append('%s = ?'%key)
348 values.append(rec[key])
349 values.append(rec[key])
349 query += ', '.join(sets)
350 query += ', '.join(sets)
350 query += ' WHERE msg_id == ?'
351 query += ' WHERE msg_id == ?'
351 values.append(msg_id)
352 values.append(msg_id)
352 self._db.execute(query, values)
353 self._db.execute(query, values)
353 # self._db.commit()
354 # self._db.commit()
354
355
355 def drop_record(self, msg_id):
356 def drop_record(self, msg_id):
356 """Remove a record from the DB."""
357 """Remove a record from the DB."""
357 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
358 self._db.execute("""DELETE FROM %s WHERE msg_id==?"""%self.table, (msg_id,))
358 # self._db.commit()
359 # self._db.commit()
359
360
360 def drop_matching_records(self, check):
361 def drop_matching_records(self, check):
361 """Remove a record from the DB."""
362 """Remove a record from the DB."""
362 expr,args = self._render_expression(check)
363 expr,args = self._render_expression(check)
363 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
364 query = "DELETE FROM %s WHERE %s"%(self.table, expr)
364 self._db.execute(query,args)
365 self._db.execute(query,args)
365 # self._db.commit()
366 # self._db.commit()
366
367
367 def find_records(self, check, keys=None):
368 def find_records(self, check, keys=None):
368 """Find records matching a query dict, optionally extracting subset of keys.
369 """Find records matching a query dict, optionally extracting subset of keys.
369
370
370 Returns list of matching records.
371 Returns list of matching records.
371
372
372 Parameters
373 Parameters
373 ----------
374 ----------
374
375
375 check: dict
376 check: dict
376 mongodb-style query argument
377 mongodb-style query argument
377 keys: list of strs [optional]
378 keys: list of strs [optional]
378 if specified, the subset of keys to extract. msg_id will *always* be
379 if specified, the subset of keys to extract. msg_id will *always* be
379 included.
380 included.
380 """
381 """
381 if keys:
382 if keys:
382 bad_keys = [ key for key in keys if key not in self._keys ]
383 bad_keys = [ key for key in keys if key not in self._keys ]
383 if bad_keys:
384 if bad_keys:
384 raise KeyError("Bad record key(s): %s"%bad_keys)
385 raise KeyError("Bad record key(s): %s"%bad_keys)
385
386
386 if keys:
387 if keys:
387 # ensure msg_id is present and first:
388 # ensure msg_id is present and first:
388 if 'msg_id' in keys:
389 if 'msg_id' in keys:
389 keys.remove('msg_id')
390 keys.remove('msg_id')
390 keys.insert(0, 'msg_id')
391 keys.insert(0, 'msg_id')
391 req = ', '.join(keys)
392 req = ', '.join(keys)
392 else:
393 else:
393 req = '*'
394 req = '*'
394 expr,args = self._render_expression(check)
395 expr,args = self._render_expression(check)
395 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
396 query = """SELECT %s FROM %s WHERE %s"""%(req, self.table, expr)
396 cursor = self._db.execute(query, args)
397 cursor = self._db.execute(query, args)
397 matches = cursor.fetchall()
398 matches = cursor.fetchall()
398 records = []
399 records = []
399 for line in matches:
400 for line in matches:
400 rec = self._list_to_dict(line, keys)
401 rec = self._list_to_dict(line, keys)
401 records.append(rec)
402 records.append(rec)
402 return records
403 return records
403
404
404 def get_history(self):
405 def get_history(self):
405 """get all msg_ids, ordered by time submitted."""
406 """get all msg_ids, ordered by time submitted."""
406 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
407 query = """SELECT msg_id FROM %s ORDER by submitted ASC"""%self.table
407 cursor = self._db.execute(query)
408 cursor = self._db.execute(query)
408 # will be a list of length 1 tuples
409 # will be a list of length 1 tuples
409 return [ tup[0] for tup in cursor.fetchall()]
410 return [ tup[0] for tup in cursor.fetchall()]
410
411
411 __all__ = ['SQLiteDB'] No newline at end of file
412 __all__ = ['SQLiteDB']
@@ -1,234 +1,237 b''
1 """A simple engine that talks to a controller over 0MQ.
1 """A simple engine that talks to a controller over 0MQ.
2 it handles registration, etc. and launches a kernel
2 it handles registration, etc. and launches a kernel
3 connected to the Controller's Schedulers.
3 connected to the Controller's Schedulers.
4
4
5 Authors:
5 Authors:
6
6
7 * Min RK
7 * Min RK
8 """
8 """
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 import sys
18 import sys
19 import time
19 import time
20 from getpass import getpass
20 from getpass import getpass
21
21
22 import zmq
22 import zmq
23 from zmq.eventloop import ioloop, zmqstream
23 from zmq.eventloop import ioloop, zmqstream
24
24
25 from IPython.external.ssh import tunnel
25 from IPython.external.ssh import tunnel
26 # internal
26 # internal
27 from IPython.utils.traitlets import (
27 from IPython.utils.traitlets import (
28 Instance, Dict, Integer, Type, CFloat, Unicode, CBytes, Bool
28 Instance, Dict, Integer, Type, CFloat, Unicode, CBytes, Bool
29 )
29 )
30 from IPython.utils import py3compat
30 from IPython.utils.py3compat import cast_bytes
31
31
32 from IPython.parallel.controller.heartmonitor import Heart
32 from IPython.parallel.controller.heartmonitor import Heart
33 from IPython.parallel.factory import RegistrationFactory
33 from IPython.parallel.factory import RegistrationFactory
34 from IPython.parallel.util import disambiguate_url, asbytes
34 from IPython.parallel.util import disambiguate_url
35
35
36 from IPython.zmq.session import Message
36 from IPython.zmq.session import Message
37
37 from IPython.zmq.ipkernel import Kernel
38 from .streamkernel import Kernel
39
38
40 class EngineFactory(RegistrationFactory):
39 class EngineFactory(RegistrationFactory):
41 """IPython engine"""
40 """IPython engine"""
42
41
43 # configurables:
42 # configurables:
44 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
43 out_stream_factory=Type('IPython.zmq.iostream.OutStream', config=True,
45 help="""The OutStream for handling stdout/err.
44 help="""The OutStream for handling stdout/err.
46 Typically 'IPython.zmq.iostream.OutStream'""")
45 Typically 'IPython.zmq.iostream.OutStream'""")
47 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
46 display_hook_factory=Type('IPython.zmq.displayhook.ZMQDisplayHook', config=True,
48 help="""The class for handling displayhook.
47 help="""The class for handling displayhook.
49 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
48 Typically 'IPython.zmq.displayhook.ZMQDisplayHook'""")
50 location=Unicode(config=True,
49 location=Unicode(config=True,
51 help="""The location (an IP address) of the controller. This is
50 help="""The location (an IP address) of the controller. This is
52 used for disambiguating URLs, to determine whether
51 used for disambiguating URLs, to determine whether
53 loopback should be used to connect or the public address.""")
52 loopback should be used to connect or the public address.""")
54 timeout=CFloat(2,config=True,
53 timeout=CFloat(2,config=True,
55 help="""The time (in seconds) to wait for the Controller to respond
54 help="""The time (in seconds) to wait for the Controller to respond
56 to registration requests before giving up.""")
55 to registration requests before giving up.""")
57 sshserver=Unicode(config=True,
56 sshserver=Unicode(config=True,
58 help="""The SSH server to use for tunneling connections to the Controller.""")
57 help="""The SSH server to use for tunneling connections to the Controller.""")
59 sshkey=Unicode(config=True,
58 sshkey=Unicode(config=True,
60 help="""The SSH private key file to use when tunneling connections to the Controller.""")
59 help="""The SSH private key file to use when tunneling connections to the Controller.""")
61 paramiko=Bool(sys.platform == 'win32', config=True,
60 paramiko=Bool(sys.platform == 'win32', config=True,
62 help="""Whether to use paramiko instead of openssh for tunnels.""")
61 help="""Whether to use paramiko instead of openssh for tunnels.""")
63
62
64 # not configurable:
63 # not configurable:
65 user_ns=Dict()
64 user_ns=Dict()
66 id=Integer(allow_none=True)
65 id=Integer(allow_none=True)
67 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
66 registrar=Instance('zmq.eventloop.zmqstream.ZMQStream')
68 kernel=Instance(Kernel)
67 kernel=Instance(Kernel)
69
68
70 bident = CBytes()
69 bident = CBytes()
71 ident = Unicode()
70 ident = Unicode()
72 def _ident_changed(self, name, old, new):
71 def _ident_changed(self, name, old, new):
73 self.bident = asbytes(new)
72 self.bident = cast_bytes(new)
74 using_ssh=Bool(False)
73 using_ssh=Bool(False)
75
74
76
75
77 def __init__(self, **kwargs):
76 def __init__(self, **kwargs):
78 super(EngineFactory, self).__init__(**kwargs)
77 super(EngineFactory, self).__init__(**kwargs)
79 self.ident = self.session.session
78 self.ident = self.session.session
80
79
81 def init_connector(self):
80 def init_connector(self):
82 """construct connection function, which handles tunnels."""
81 """construct connection function, which handles tunnels."""
83 self.using_ssh = bool(self.sshkey or self.sshserver)
82 self.using_ssh = bool(self.sshkey or self.sshserver)
84
83
85 if self.sshkey and not self.sshserver:
84 if self.sshkey and not self.sshserver:
86 # We are using ssh directly to the controller, tunneling localhost to localhost
85 # We are using ssh directly to the controller, tunneling localhost to localhost
87 self.sshserver = self.url.split('://')[1].split(':')[0]
86 self.sshserver = self.url.split('://')[1].split(':')[0]
88
87
89 if self.using_ssh:
88 if self.using_ssh:
90 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
89 if tunnel.try_passwordless_ssh(self.sshserver, self.sshkey, self.paramiko):
91 password=False
90 password=False
92 else:
91 else:
93 password = getpass("SSH Password for %s: "%self.sshserver)
92 password = getpass("SSH Password for %s: "%self.sshserver)
94 else:
93 else:
95 password = False
94 password = False
96
95
97 def connect(s, url):
96 def connect(s, url):
98 url = disambiguate_url(url, self.location)
97 url = disambiguate_url(url, self.location)
99 if self.using_ssh:
98 if self.using_ssh:
100 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
99 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
101 return tunnel.tunnel_connection(s, url, self.sshserver,
100 return tunnel.tunnel_connection(s, url, self.sshserver,
102 keyfile=self.sshkey, paramiko=self.paramiko,
101 keyfile=self.sshkey, paramiko=self.paramiko,
103 password=password,
102 password=password,
104 )
103 )
105 else:
104 else:
106 return s.connect(url)
105 return s.connect(url)
107
106
108 def maybe_tunnel(url):
107 def maybe_tunnel(url):
109 """like connect, but don't complete the connection (for use by heartbeat)"""
108 """like connect, but don't complete the connection (for use by heartbeat)"""
110 url = disambiguate_url(url, self.location)
109 url = disambiguate_url(url, self.location)
111 if self.using_ssh:
110 if self.using_ssh:
112 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
111 self.log.debug("Tunneling connection to %s via %s"%(url, self.sshserver))
113 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
112 url,tunnelobj = tunnel.open_tunnel(url, self.sshserver,
114 keyfile=self.sshkey, paramiko=self.paramiko,
113 keyfile=self.sshkey, paramiko=self.paramiko,
115 password=password,
114 password=password,
116 )
115 )
117 return url
116 return url
118 return connect, maybe_tunnel
117 return connect, maybe_tunnel
119
118
120 def register(self):
119 def register(self):
121 """send the registration_request"""
120 """send the registration_request"""
122
121
123 self.log.info("Registering with controller at %s"%self.url)
122 self.log.info("Registering with controller at %s"%self.url)
124 ctx = self.context
123 ctx = self.context
125 connect,maybe_tunnel = self.init_connector()
124 connect,maybe_tunnel = self.init_connector()
126 reg = ctx.socket(zmq.DEALER)
125 reg = ctx.socket(zmq.DEALER)
127 reg.setsockopt(zmq.IDENTITY, self.bident)
126 reg.setsockopt(zmq.IDENTITY, self.bident)
128 connect(reg, self.url)
127 connect(reg, self.url)
129 self.registrar = zmqstream.ZMQStream(reg, self.loop)
128 self.registrar = zmqstream.ZMQStream(reg, self.loop)
130
129
131
130
132 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
131 content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident)
133 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
132 self.registrar.on_recv(lambda msg: self.complete_registration(msg, connect, maybe_tunnel))
134 # print (self.session.key)
133 # print (self.session.key)
135 self.session.send(self.registrar, "registration_request",content=content)
134 self.session.send(self.registrar, "registration_request",content=content)
136
135
137 def complete_registration(self, msg, connect, maybe_tunnel):
136 def complete_registration(self, msg, connect, maybe_tunnel):
138 # print msg
137 # print msg
139 self._abort_dc.stop()
138 self._abort_dc.stop()
140 ctx = self.context
139 ctx = self.context
141 loop = self.loop
140 loop = self.loop
142 identity = self.bident
141 identity = self.bident
143 idents,msg = self.session.feed_identities(msg)
142 idents,msg = self.session.feed_identities(msg)
144 msg = Message(self.session.unserialize(msg))
143 msg = Message(self.session.unserialize(msg))
145
144
146 if msg.content.status == 'ok':
145 if msg.content.status == 'ok':
147 self.id = int(msg.content.id)
146 self.id = int(msg.content.id)
148
147
149 # launch heartbeat
148 # launch heartbeat
150 hb_addrs = msg.content.heartbeat
149 hb_addrs = msg.content.heartbeat
151
150
152 # possibly forward hb ports with tunnels
151 # possibly forward hb ports with tunnels
153 hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
152 hb_addrs = [ maybe_tunnel(addr) for addr in hb_addrs ]
154 heart = Heart(*map(str, hb_addrs), heart_id=identity)
153 heart = Heart(*map(str, hb_addrs), heart_id=identity)
155 heart.start()
154 heart.start()
156
155
157 # create Shell Streams (MUX, Task, etc.):
156 # create Shell Streams (MUX, Task, etc.):
158 queue_addr = msg.content.mux
157 queue_addr = msg.content.mux
159 shell_addrs = [ str(queue_addr) ]
158 shell_addrs = [ str(queue_addr) ]
160 task_addr = msg.content.task
159 task_addr = msg.content.task
161 if task_addr:
160 if task_addr:
162 shell_addrs.append(str(task_addr))
161 shell_addrs.append(str(task_addr))
163
162
164 # Uncomment this to go back to two-socket model
163 # Uncomment this to go back to two-socket model
165 # shell_streams = []
164 # shell_streams = []
166 # for addr in shell_addrs:
165 # for addr in shell_addrs:
167 # stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
166 # stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
168 # stream.setsockopt(zmq.IDENTITY, identity)
167 # stream.setsockopt(zmq.IDENTITY, identity)
169 # stream.connect(disambiguate_url(addr, self.location))
168 # stream.connect(disambiguate_url(addr, self.location))
170 # shell_streams.append(stream)
169 # shell_streams.append(stream)
171
170
172 # Now use only one shell stream for mux and tasks
171 # Now use only one shell stream for mux and tasks
173 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
172 stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
174 stream.setsockopt(zmq.IDENTITY, identity)
173 stream.setsockopt(zmq.IDENTITY, identity)
175 shell_streams = [stream]
174 shell_streams = [stream]
176 for addr in shell_addrs:
175 for addr in shell_addrs:
177 connect(stream, addr)
176 connect(stream, addr)
178 # end single stream-socket
177 # end single stream-socket
179
178
180 # control stream:
179 # control stream:
181 control_addr = str(msg.content.control)
180 control_addr = str(msg.content.control)
182 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
181 control_stream = zmqstream.ZMQStream(ctx.socket(zmq.ROUTER), loop)
183 control_stream.setsockopt(zmq.IDENTITY, identity)
182 control_stream.setsockopt(zmq.IDENTITY, identity)
184 connect(control_stream, control_addr)
183 connect(control_stream, control_addr)
185
184
186 # create iopub stream:
185 # create iopub stream:
187 iopub_addr = msg.content.iopub
186 iopub_addr = msg.content.iopub
188 iopub_stream = zmqstream.ZMQStream(ctx.socket(zmq.PUB), loop)
187 iopub_socket = ctx.socket(zmq.PUB)
189 iopub_stream.setsockopt(zmq.IDENTITY, identity)
188 iopub_socket.setsockopt(zmq.IDENTITY, identity)
190 connect(iopub_stream, iopub_addr)
189 connect(iopub_socket, iopub_addr)
191
190
192 # # Redirect input streams and set a display hook.
191 # disable history:
192 self.config.HistoryManager.hist_file = ':memory:'
193
194 # Redirect input streams and set a display hook.
193 if self.out_stream_factory:
195 if self.out_stream_factory:
194 sys.stdout = self.out_stream_factory(self.session, iopub_stream, u'stdout')
196 sys.stdout = self.out_stream_factory(self.session, iopub_socket, u'stdout')
195 sys.stdout.topic = py3compat.cast_bytes('engine.%i.stdout' % self.id)
197 sys.stdout.topic = cast_bytes('engine.%i.stdout' % self.id)
196 sys.stderr = self.out_stream_factory(self.session, iopub_stream, u'stderr')
198 sys.stderr = self.out_stream_factory(self.session, iopub_socket, u'stderr')
197 sys.stderr.topic = py3compat.cast_bytes('engine.%i.stderr' % self.id)
199 sys.stderr.topic = cast_bytes('engine.%i.stderr' % self.id)
198 if self.display_hook_factory:
200 if self.display_hook_factory:
199 sys.displayhook = self.display_hook_factory(self.session, iopub_stream)
201 sys.displayhook = self.display_hook_factory(self.session, iopub_socket)
200 sys.displayhook.topic = py3compat.cast_bytes('engine.%i.pyout' % self.id)
202 sys.displayhook.topic = cast_bytes('engine.%i.pyout' % self.id)
201
203
202 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
204 self.kernel = Kernel(config=self.config, int_id=self.id, ident=self.ident, session=self.session,
203 control_stream=control_stream, shell_streams=shell_streams, iopub_stream=iopub_stream,
205 control_stream=control_stream, shell_streams=shell_streams, iopub_socket=iopub_socket,
204 loop=loop, user_ns = self.user_ns, log=self.log)
206 loop=loop, user_ns=self.user_ns, log=self.log)
207 self.kernel.shell.display_pub.topic = cast_bytes('engine.%i.displaypub' % self.id)
205 self.kernel.start()
208 self.kernel.start()
206
209
207
210
208 else:
211 else:
209 self.log.fatal("Registration Failed: %s"%msg)
212 self.log.fatal("Registration Failed: %s"%msg)
210 raise Exception("Registration Failed: %s"%msg)
213 raise Exception("Registration Failed: %s"%msg)
211
214
212 self.log.info("Completed registration with id %i"%self.id)
215 self.log.info("Completed registration with id %i"%self.id)
213
216
214
217
215 def abort(self):
218 def abort(self):
216 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
219 self.log.fatal("Registration timed out after %.1f seconds"%self.timeout)
217 if self.url.startswith('127.'):
220 if self.url.startswith('127.'):
218 self.log.fatal("""
221 self.log.fatal("""
219 If the controller and engines are not on the same machine,
222 If the controller and engines are not on the same machine,
220 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
223 you will have to instruct the controller to listen on an external IP (in ipcontroller_config.py):
221 c.HubFactory.ip='*' # for all interfaces, internal and external
224 c.HubFactory.ip='*' # for all interfaces, internal and external
222 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
225 c.HubFactory.ip='192.168.1.101' # or any interface that the engines can see
223 or tunnel connections via ssh.
226 or tunnel connections via ssh.
224 """)
227 """)
225 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
228 self.session.send(self.registrar, "unregistration_request", content=dict(id=self.id))
226 time.sleep(1)
229 time.sleep(1)
227 sys.exit(255)
230 sys.exit(255)
228
231
229 def start(self):
232 def start(self):
230 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
233 dc = ioloop.DelayedCallback(self.register, 0, self.loop)
231 dc.start()
234 dc.start()
232 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
235 self._abort_dc = ioloop.DelayedCallback(self.abort, self.timeout*1000, self.loop)
233 self._abort_dc.start()
236 self._abort_dc.start()
234
237
@@ -1,122 +1,122 b''
1 """toplevel setup/teardown for parallel tests."""
1 """toplevel setup/teardown for parallel tests."""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2011 The IPython Development Team
4 # Copyright (C) 2011 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import tempfile
15 import tempfile
16 import time
16 import time
17 from subprocess import Popen
17 from subprocess import Popen
18
18
19 from IPython.utils.path import get_ipython_dir
19 from IPython.utils.path import get_ipython_dir
20 from IPython.parallel import Client
20 from IPython.parallel import Client
21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
21 from IPython.parallel.apps.launcher import (LocalProcessLauncher,
22 ipengine_cmd_argv,
22 ipengine_cmd_argv,
23 ipcontroller_cmd_argv,
23 ipcontroller_cmd_argv,
24 SIGKILL,
24 SIGKILL,
25 ProcessStateError,
25 ProcessStateError,
26 )
26 )
27
27
28 # globals
28 # globals
29 launchers = []
29 launchers = []
30 blackhole = open(os.devnull, 'w')
30 blackhole = open(os.devnull, 'w')
31
31
32 # Launcher class
32 # Launcher class
33 class TestProcessLauncher(LocalProcessLauncher):
33 class TestProcessLauncher(LocalProcessLauncher):
34 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
34 """subclass LocalProcessLauncher, to prevent extra sockets and threads being created on Windows"""
35 def start(self):
35 def start(self):
36 if self.state == 'before':
36 if self.state == 'before':
37 self.process = Popen(self.args,
37 self.process = Popen(self.args,
38 stdout=blackhole, stderr=blackhole,
38 stdout=blackhole, stderr=blackhole,
39 env=os.environ,
39 env=os.environ,
40 cwd=self.work_dir
40 cwd=self.work_dir
41 )
41 )
42 self.notify_start(self.process.pid)
42 self.notify_start(self.process.pid)
43 self.poll = self.process.poll
43 self.poll = self.process.poll
44 else:
44 else:
45 s = 'The process was already started and has state: %r' % self.state
45 s = 'The process was already started and has state: %r' % self.state
46 raise ProcessStateError(s)
46 raise ProcessStateError(s)
47
47
48 # nose setup/teardown
48 # nose setup/teardown
49
49
50 def setup():
50 def setup():
51 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
51 cluster_dir = os.path.join(get_ipython_dir(), 'profile_iptest')
52 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
52 engine_json = os.path.join(cluster_dir, 'security', 'ipcontroller-engine.json')
53 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
53 client_json = os.path.join(cluster_dir, 'security', 'ipcontroller-client.json')
54 for json in (engine_json, client_json):
54 for json in (engine_json, client_json):
55 if os.path.exists(json):
55 if os.path.exists(json):
56 os.remove(json)
56 os.remove(json)
57
57
58 cp = TestProcessLauncher()
58 cp = TestProcessLauncher()
59 cp.cmd_and_args = ipcontroller_cmd_argv + \
59 cp.cmd_and_args = ipcontroller_cmd_argv + \
60 ['--profile=iptest', '--log-level=50', '--ping=250']
60 ['--profile=iptest', '--log-level=50', '--ping=250']
61 cp.start()
61 cp.start()
62 launchers.append(cp)
62 launchers.append(cp)
63 tic = time.time()
63 tic = time.time()
64 while not os.path.exists(engine_json) or not os.path.exists(client_json):
64 while not os.path.exists(engine_json) or not os.path.exists(client_json):
65 if cp.poll() is not None:
65 if cp.poll() is not None:
66 print cp.poll()
66 print cp.poll()
67 raise RuntimeError("The test controller failed to start.")
67 raise RuntimeError("The test controller failed to start.")
68 elif time.time()-tic > 10:
68 elif time.time()-tic > 15:
69 raise RuntimeError("Timeout waiting for the test controller to start.")
69 raise RuntimeError("Timeout waiting for the test controller to start.")
70 time.sleep(0.1)
70 time.sleep(0.1)
71 add_engines(1)
71 add_engines(1)
72
72
73 def add_engines(n=1, profile='iptest', total=False):
73 def add_engines(n=1, profile='iptest', total=False):
74 """add a number of engines to a given profile.
74 """add a number of engines to a given profile.
75
75
76 If total is True, then already running engines are counted, and only
76 If total is True, then already running engines are counted, and only
77 the additional engines necessary (if any) are started.
77 the additional engines necessary (if any) are started.
78 """
78 """
79 rc = Client(profile=profile)
79 rc = Client(profile=profile)
80 base = len(rc)
80 base = len(rc)
81
81
82 if total:
82 if total:
83 n = max(n - base, 0)
83 n = max(n - base, 0)
84
84
85 eps = []
85 eps = []
86 for i in range(n):
86 for i in range(n):
87 ep = TestProcessLauncher()
87 ep = TestProcessLauncher()
88 ep.cmd_and_args = ipengine_cmd_argv + ['--profile=%s'%profile, '--log-level=50']
88 ep.cmd_and_args = ipengine_cmd_argv + ['--profile=%s'%profile, '--log-level=50']
89 ep.start()
89 ep.start()
90 launchers.append(ep)
90 launchers.append(ep)
91 eps.append(ep)
91 eps.append(ep)
92 tic = time.time()
92 tic = time.time()
93 while len(rc) < base+n:
93 while len(rc) < base+n:
94 if any([ ep.poll() is not None for ep in eps ]):
94 if any([ ep.poll() is not None for ep in eps ]):
95 raise RuntimeError("A test engine failed to start.")
95 raise RuntimeError("A test engine failed to start.")
96 elif time.time()-tic > 10:
96 elif time.time()-tic > 15:
97 raise RuntimeError("Timeout waiting for engines to connect.")
97 raise RuntimeError("Timeout waiting for engines to connect.")
98 time.sleep(.1)
98 time.sleep(.1)
99 rc.spin()
99 rc.spin()
100 rc.close()
100 rc.close()
101 return eps
101 return eps
102
102
103 def teardown():
103 def teardown():
104 time.sleep(1)
104 time.sleep(1)
105 while launchers:
105 while launchers:
106 p = launchers.pop()
106 p = launchers.pop()
107 if p.poll() is None:
107 if p.poll() is None:
108 try:
108 try:
109 p.stop()
109 p.stop()
110 except Exception, e:
110 except Exception, e:
111 print e
111 print e
112 pass
112 pass
113 if p.poll() is None:
113 if p.poll() is None:
114 time.sleep(.25)
114 time.sleep(.25)
115 if p.poll() is None:
115 if p.poll() is None:
116 try:
116 try:
117 print 'cleaning up test process...'
117 print 'cleaning up test process...'
118 p.signal(SIGKILL)
118 p.signal(SIGKILL)
119 except:
119 except:
120 print "couldn't shutdown process: ", p
120 print "couldn't shutdown process: ", p
121 blackhole.close()
121 blackhole.close()
122
122
@@ -1,205 +1,205 b''
1 """Tests for asyncresult.py
1 """Tests for asyncresult.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import time
19 import time
20
20
21 from IPython.parallel.error import TimeoutError
21 from IPython.parallel.error import TimeoutError
22
22
23 from IPython.parallel import error, Client
23 from IPython.parallel import error, Client
24 from IPython.parallel.tests import add_engines
24 from IPython.parallel.tests import add_engines
25 from .clienttest import ClusterTestCase
25 from .clienttest import ClusterTestCase
26
26
def setup():
    """Module-level setup: ensure engines are running before the tests start."""
    n_engines = 2
    add_engines(n_engines, total=True)
29
29
def wait(n):
    """Sleep for ``n`` seconds and return ``n``.

    The import is local so the function stays self-contained when it is
    shipped to a remote engine via ``apply_async``.
    """
    import time as _time
    _time.sleep(n)
    return n
34
34
class AsyncResultTest(ClusterTestCase):
    """Behavioral tests for AsyncResult / AsyncMapResult objects
    returned by IPython.parallel views."""

    def test_single_result_view(self):
        """various one-target views get the right value for single_result"""
        eid = self.client.ids[-1]
        # scalar index -> scalar result
        ar = self.client[eid].apply_async(lambda : 42)
        self.assertEquals(ar.get(), 42)
        # list-of-one and slice indices -> one-element list results
        ar = self.client[[eid]].apply_async(lambda : 42)
        self.assertEquals(ar.get(), [42])
        ar = self.client[-1:].apply_async(lambda : 42)
        self.assertEquals(ar.get(), [42])

    def test_get_after_done(self):
        # get() may be called repeatedly once the result is ready
        ar = self.client[-1].apply_async(lambda : 42)
        ar.wait()
        self.assertTrue(ar.ready())
        self.assertEquals(ar.get(), 42)
        self.assertEquals(ar.get(), 42)

    def test_get_before_done(self):
        ar = self.client[-1].apply_async(wait, 0.1)
        # zero-timeout get on an unfinished result raises TimeoutError
        self.assertRaises(TimeoutError, ar.get, 0)
        ar.wait(0)
        self.assertFalse(ar.ready())
        self.assertEquals(ar.get(), 0.1)

    def test_get_after_error(self):
        ar = self.client[-1].apply_async(lambda : 1/0)
        ar.wait(10)
        # the remote exception is re-raised on every get()/get_dict() call
        self.assertRaisesRemote(ZeroDivisionError, ar.get)
        self.assertRaisesRemote(ZeroDivisionError, ar.get)
        self.assertRaisesRemote(ZeroDivisionError, ar.get_dict)

    def test_get_dict(self):
        n = len(self.client)
        ar = self.client[:].apply_async(lambda : 5)
        self.assertEquals(ar.get(), [5]*n)
        # get_dict() keys the results by engine id
        d = ar.get_dict()
        self.assertEquals(sorted(d.keys()), sorted(self.client.ids))
        for eid,r in d.iteritems():
            self.assertEquals(r, 5)

    def test_list_amr(self):
        # iterating an AsyncMapResult must not raise
        ar = self.client.load_balanced_view().map_async(wait, [0.1]*5)
        rlist = list(ar)

    def test_getattr(self):
        ar = self.client[:].apply_async(wait, 0.5)
        # before completion, metadata attribute access raises AttributeError
        self.assertRaises(AttributeError, lambda : ar._foo)
        self.assertRaises(AttributeError, lambda : ar.__length_hint__())
        self.assertRaises(AttributeError, lambda : ar.foo)
        self.assertRaises(AttributeError, lambda : ar.engine_id)
        self.assertFalse(hasattr(ar, '__length_hint__'))
        self.assertFalse(hasattr(ar, 'foo'))
        self.assertFalse(hasattr(ar, 'engine_id'))
        ar.get(5)
        # after completion, metadata keys become attributes;
        # unknown names still raise AttributeError
        self.assertRaises(AttributeError, lambda : ar._foo)
        self.assertRaises(AttributeError, lambda : ar.__length_hint__())
        self.assertRaises(AttributeError, lambda : ar.foo)
        self.assertTrue(isinstance(ar.engine_id, list))
        self.assertEquals(ar.engine_id, ar['engine_id'])
        self.assertFalse(hasattr(ar, '__length_hint__'))
        self.assertFalse(hasattr(ar, 'foo'))
        self.assertTrue(hasattr(ar, 'engine_id'))

    def test_getitem(self):
        ar = self.client[:].apply_async(wait, 0.5)
        # item access on an unfinished result raises TimeoutError, not KeyError
        self.assertRaises(TimeoutError, lambda : ar['foo'])
        self.assertRaises(TimeoutError, lambda : ar['engine_id'])
        ar.get(5)
        self.assertRaises(KeyError, lambda : ar['foo'])
        self.assertTrue(isinstance(ar['engine_id'], list))
        self.assertEquals(ar.engine_id, ar['engine_id'])

    def test_single_result(self):
        ar = self.client[-1].apply_async(wait, 0.5)
        self.assertRaises(TimeoutError, lambda : ar['foo'])
        self.assertRaises(TimeoutError, lambda : ar['engine_id'])
        self.assertTrue(ar.get(5) == 0.5)
        # single-target results expose scalar (not list) metadata
        self.assertTrue(isinstance(ar['engine_id'], int))
        self.assertTrue(isinstance(ar.engine_id, int))
        self.assertEquals(ar.engine_id, ar['engine_id'])

    def test_abort(self):
        e = self.client[-1]
        # keep the engine busy so the second task can still be aborted
        ar = e.execute('import time; time.sleep(1)', block=False)
        ar2 = e.apply_async(lambda : 2)
        ar2.abort()
        self.assertRaises(error.TaskAborted, ar2.get)
        ar.get()

    def test_len(self):
        v = self.client.load_balanced_view()
        # len(amr) == number of map tasks
        ar = v.map_async(lambda x: x, range(10))
        self.assertEquals(len(ar), 10)
        # a single apply counts as one message, whatever its argument
        ar = v.apply_async(lambda x: x, range(10))
        self.assertEquals(len(ar), 1)
        # a direct-view apply fans out to every engine
        ar = self.client[:].apply_async(lambda x: x, range(10))
        self.assertEquals(len(ar), len(self.client.ids))

    def test_wall_time_single(self):
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep, 0.25)
        # timing metadata is unavailable until the result is ready
        self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
        ar.get(2)
        self.assertTrue(ar.wall_time < 1.)
        self.assertTrue(ar.wall_time > 0.2)

    def test_wall_time_multi(self):
        self.minimum_engines(4)
        v = self.client[:]
        ar = v.apply_async(time.sleep, 0.25)
        self.assertRaises(TimeoutError, getattr, ar, 'wall_time')
        ar.get(2)
        self.assertTrue(ar.wall_time < 1.)
        self.assertTrue(ar.wall_time > 0.2)

    def test_serial_time_single(self):
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep, 0.25)
        self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
        ar.get(2)
        self.assertTrue(ar.serial_time < 1.)
        self.assertTrue(ar.serial_time > 0.2)

    def test_serial_time_multi(self):
        self.minimum_engines(4)
        v = self.client[:]
        ar = v.apply_async(time.sleep, 0.25)
        self.assertRaises(TimeoutError, getattr, ar, 'serial_time')
        ar.get(2)
        # serial_time sums the per-engine times: >= 4 engines * 0.25s each
        self.assertTrue(ar.serial_time < 2.)
        self.assertTrue(ar.serial_time > 0.8)

    def test_elapsed_single(self):
        v = self.client.load_balanced_view()
        ar = v.apply_async(time.sleep, 0.25)
        while not ar.ready():
            time.sleep(0.01)
        self.assertTrue(ar.elapsed < 1)
        self.assertTrue(ar.elapsed < 1)
        ar.get(2)

    def test_elapsed_multi(self):
        v = self.client[:]
        ar = v.apply_async(time.sleep, 0.25)
        while not ar.ready():
            time.sleep(0.01)
        self.assertTrue(ar.elapsed < 1)
        self.assertTrue(ar.elapsed < 1)
        ar.get(2)

    def test_hubresult_timestamps(self):
        self.minimum_engines(4)
        v = self.client[:]
        ar = v.apply_async(time.sleep, 0.25)
        ar.get(2)
        rc2 = Client(profile='iptest')
        # must have try/finally to close second Client, otherwise
        # will have dangling sockets causing problems
        try:
            time.sleep(0.25)
            hr = rc2.get_result(ar.msg_ids)
            self.assertTrue(hr.elapsed > 0., "got bad elapsed: %s" % hr.elapsed)
            hr.get(1)
            self.assertTrue(hr.wall_time < ar.wall_time + 0.2, "got bad wall_time: %s > %s" % (hr.wall_time, ar.wall_time))
            self.assertEquals(hr.serial_time, ar.serial_time)
        finally:
            rc2.close()
@@ -1,349 +1,387 b''
1 """Tests for parallel client.py
1 """Tests for parallel client.py
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import time
21 import time
22 from datetime import datetime
22 from datetime import datetime
23 from tempfile import mktemp
23 from tempfile import mktemp
24
24
25 import zmq
25 import zmq
26
26
27 from IPython.parallel.client import client as clientmod
27 from IPython.parallel.client import client as clientmod
28 from IPython.parallel import error
28 from IPython.parallel import error
29 from IPython.parallel import AsyncResult, AsyncHubResult
29 from IPython.parallel import AsyncResult, AsyncHubResult
30 from IPython.parallel import LoadBalancedView, DirectView
30 from IPython.parallel import LoadBalancedView, DirectView
31
31
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
32 from clienttest import ClusterTestCase, segfault, wait, add_engines
33
33
def setup():
    """Module-level setup: ensure engines are running before the tests start."""
    n_engines = 4
    add_engines(n_engines, total=True)
36
36
37 class TestClient(ClusterTestCase):
37 class TestClient(ClusterTestCase):
38
38
39 def test_ids(self):
39 def test_ids(self):
40 n = len(self.client.ids)
40 n = len(self.client.ids)
41 self.add_engines(2)
41 self.add_engines(2)
42 self.assertEquals(len(self.client.ids), n+2)
42 self.assertEquals(len(self.client.ids), n+2)
43
43
44 def test_view_indexing(self):
44 def test_view_indexing(self):
45 """test index access for views"""
45 """test index access for views"""
46 self.minimum_engines(4)
46 self.minimum_engines(4)
47 targets = self.client._build_targets('all')[-1]
47 targets = self.client._build_targets('all')[-1]
48 v = self.client[:]
48 v = self.client[:]
49 self.assertEquals(v.targets, targets)
49 self.assertEquals(v.targets, targets)
50 t = self.client.ids[2]
50 t = self.client.ids[2]
51 v = self.client[t]
51 v = self.client[t]
52 self.assert_(isinstance(v, DirectView))
52 self.assert_(isinstance(v, DirectView))
53 self.assertEquals(v.targets, t)
53 self.assertEquals(v.targets, t)
54 t = self.client.ids[2:4]
54 t = self.client.ids[2:4]
55 v = self.client[t]
55 v = self.client[t]
56 self.assert_(isinstance(v, DirectView))
56 self.assert_(isinstance(v, DirectView))
57 self.assertEquals(v.targets, t)
57 self.assertEquals(v.targets, t)
58 v = self.client[::2]
58 v = self.client[::2]
59 self.assert_(isinstance(v, DirectView))
59 self.assert_(isinstance(v, DirectView))
60 self.assertEquals(v.targets, targets[::2])
60 self.assertEquals(v.targets, targets[::2])
61 v = self.client[1::3]
61 v = self.client[1::3]
62 self.assert_(isinstance(v, DirectView))
62 self.assert_(isinstance(v, DirectView))
63 self.assertEquals(v.targets, targets[1::3])
63 self.assertEquals(v.targets, targets[1::3])
64 v = self.client[:-3]
64 v = self.client[:-3]
65 self.assert_(isinstance(v, DirectView))
65 self.assert_(isinstance(v, DirectView))
66 self.assertEquals(v.targets, targets[:-3])
66 self.assertEquals(v.targets, targets[:-3])
67 v = self.client[-1]
67 v = self.client[-1]
68 self.assert_(isinstance(v, DirectView))
68 self.assert_(isinstance(v, DirectView))
69 self.assertEquals(v.targets, targets[-1])
69 self.assertEquals(v.targets, targets[-1])
70 self.assertRaises(TypeError, lambda : self.client[None])
70 self.assertRaises(TypeError, lambda : self.client[None])
71
71
72 def test_lbview_targets(self):
72 def test_lbview_targets(self):
73 """test load_balanced_view targets"""
73 """test load_balanced_view targets"""
74 v = self.client.load_balanced_view()
74 v = self.client.load_balanced_view()
75 self.assertEquals(v.targets, None)
75 self.assertEquals(v.targets, None)
76 v = self.client.load_balanced_view(-1)
76 v = self.client.load_balanced_view(-1)
77 self.assertEquals(v.targets, [self.client.ids[-1]])
77 self.assertEquals(v.targets, [self.client.ids[-1]])
78 v = self.client.load_balanced_view('all')
78 v = self.client.load_balanced_view('all')
79 self.assertEquals(v.targets, None)
79 self.assertEquals(v.targets, None)
80
80
81 def test_dview_targets(self):
81 def test_dview_targets(self):
82 """test direct_view targets"""
82 """test direct_view targets"""
83 v = self.client.direct_view()
83 v = self.client.direct_view()
84 self.assertEquals(v.targets, 'all')
84 self.assertEquals(v.targets, 'all')
85 v = self.client.direct_view('all')
85 v = self.client.direct_view('all')
86 self.assertEquals(v.targets, 'all')
86 self.assertEquals(v.targets, 'all')
87 v = self.client.direct_view(-1)
87 v = self.client.direct_view(-1)
88 self.assertEquals(v.targets, self.client.ids[-1])
88 self.assertEquals(v.targets, self.client.ids[-1])
89
89
90 def test_lazy_all_targets(self):
90 def test_lazy_all_targets(self):
91 """test lazy evaluation of rc.direct_view('all')"""
91 """test lazy evaluation of rc.direct_view('all')"""
92 v = self.client.direct_view()
92 v = self.client.direct_view()
93 self.assertEquals(v.targets, 'all')
93 self.assertEquals(v.targets, 'all')
94
94
95 def double(x):
95 def double(x):
96 return x*2
96 return x*2
97 seq = range(100)
97 seq = range(100)
98 ref = [ double(x) for x in seq ]
98 ref = [ double(x) for x in seq ]
99
99
100 # add some engines, which should be used
100 # add some engines, which should be used
101 self.add_engines(1)
101 self.add_engines(1)
102 n1 = len(self.client.ids)
102 n1 = len(self.client.ids)
103
103
104 # simple apply
104 # simple apply
105 r = v.apply_sync(lambda : 1)
105 r = v.apply_sync(lambda : 1)
106 self.assertEquals(r, [1] * n1)
106 self.assertEquals(r, [1] * n1)
107
107
108 # map goes through remotefunction
108 # map goes through remotefunction
109 r = v.map_sync(double, seq)
109 r = v.map_sync(double, seq)
110 self.assertEquals(r, ref)
110 self.assertEquals(r, ref)
111
111
112 # add a couple more engines, and try again
112 # add a couple more engines, and try again
113 self.add_engines(2)
113 self.add_engines(2)
114 n2 = len(self.client.ids)
114 n2 = len(self.client.ids)
115 self.assertNotEquals(n2, n1)
115 self.assertNotEquals(n2, n1)
116
116
117 # apply
117 # apply
118 r = v.apply_sync(lambda : 1)
118 r = v.apply_sync(lambda : 1)
119 self.assertEquals(r, [1] * n2)
119 self.assertEquals(r, [1] * n2)
120
120
121 # map
121 # map
122 r = v.map_sync(double, seq)
122 r = v.map_sync(double, seq)
123 self.assertEquals(r, ref)
123 self.assertEquals(r, ref)
124
124
125 def test_targets(self):
125 def test_targets(self):
126 """test various valid targets arguments"""
126 """test various valid targets arguments"""
127 build = self.client._build_targets
127 build = self.client._build_targets
128 ids = self.client.ids
128 ids = self.client.ids
129 idents,targets = build(None)
129 idents,targets = build(None)
130 self.assertEquals(ids, targets)
130 self.assertEquals(ids, targets)
131
131
132 def test_clear(self):
132 def test_clear(self):
133 """test clear behavior"""
133 """test clear behavior"""
134 self.minimum_engines(2)
134 self.minimum_engines(2)
135 v = self.client[:]
135 v = self.client[:]
136 v.block=True
136 v.block=True
137 v.push(dict(a=5))
137 v.push(dict(a=5))
138 v.pull('a')
138 v.pull('a')
139 id0 = self.client.ids[-1]
139 id0 = self.client.ids[-1]
140 self.client.clear(targets=id0, block=True)
140 self.client.clear(targets=id0, block=True)
141 a = self.client[:-1].get('a')
141 a = self.client[:-1].get('a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
142 self.assertRaisesRemote(NameError, self.client[id0].get, 'a')
143 self.client.clear(block=True)
143 self.client.clear(block=True)
144 for i in self.client.ids:
144 for i in self.client.ids:
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
145 self.assertRaisesRemote(NameError, self.client[i].get, 'a')
146
146
147 def test_get_result(self):
147 def test_get_result(self):
148 """test getting results from the Hub."""
148 """test getting results from the Hub."""
149 c = clientmod.Client(profile='iptest')
149 c = clientmod.Client(profile='iptest')
150 t = c.ids[-1]
150 t = c.ids[-1]
151 ar = c[t].apply_async(wait, 1)
151 ar = c[t].apply_async(wait, 1)
152 # give the monitor time to notice the message
152 # give the monitor time to notice the message
153 time.sleep(.25)
153 time.sleep(.25)
154 ahr = self.client.get_result(ar.msg_ids)
154 ahr = self.client.get_result(ar.msg_ids)
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
155 self.assertTrue(isinstance(ahr, AsyncHubResult))
156 self.assertEquals(ahr.get(), ar.get())
156 self.assertEquals(ahr.get(), ar.get())
157 ar2 = self.client.get_result(ar.msg_ids)
157 ar2 = self.client.get_result(ar.msg_ids)
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
158 self.assertFalse(isinstance(ar2, AsyncHubResult))
159 c.close()
159 c.close()
160
160
161 def test_ids_list(self):
161 def test_ids_list(self):
162 """test client.ids"""
162 """test client.ids"""
163 ids = self.client.ids
163 ids = self.client.ids
164 self.assertEquals(ids, self.client._ids)
164 self.assertEquals(ids, self.client._ids)
165 self.assertFalse(ids is self.client._ids)
165 self.assertFalse(ids is self.client._ids)
166 ids.remove(ids[-1])
166 ids.remove(ids[-1])
167 self.assertNotEquals(ids, self.client._ids)
167 self.assertNotEquals(ids, self.client._ids)
168
168
169 def test_queue_status(self):
169 def test_queue_status(self):
170 ids = self.client.ids
170 ids = self.client.ids
171 id0 = ids[0]
171 id0 = ids[0]
172 qs = self.client.queue_status(targets=id0)
172 qs = self.client.queue_status(targets=id0)
173 self.assertTrue(isinstance(qs, dict))
173 self.assertTrue(isinstance(qs, dict))
174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
174 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
175 allqs = self.client.queue_status()
175 allqs = self.client.queue_status()
176 self.assertTrue(isinstance(allqs, dict))
176 self.assertTrue(isinstance(allqs, dict))
177 intkeys = list(allqs.keys())
177 intkeys = list(allqs.keys())
178 intkeys.remove('unassigned')
178 intkeys.remove('unassigned')
179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
179 self.assertEquals(sorted(intkeys), sorted(self.client.ids))
180 unassigned = allqs.pop('unassigned')
180 unassigned = allqs.pop('unassigned')
181 for eid,qs in allqs.items():
181 for eid,qs in allqs.items():
182 self.assertTrue(isinstance(qs, dict))
182 self.assertTrue(isinstance(qs, dict))
183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
183 self.assertEquals(sorted(qs.keys()), ['completed', 'queue', 'tasks'])
184
184
185 def test_shutdown(self):
185 def test_shutdown(self):
186 ids = self.client.ids
186 ids = self.client.ids
187 id0 = ids[0]
187 id0 = ids[0]
188 self.client.shutdown(id0, block=True)
188 self.client.shutdown(id0, block=True)
189 while id0 in self.client.ids:
189 while id0 in self.client.ids:
190 time.sleep(0.1)
190 time.sleep(0.1)
191 self.client.spin()
191 self.client.spin()
192
192
193 self.assertRaises(IndexError, lambda : self.client[id0])
193 self.assertRaises(IndexError, lambda : self.client[id0])
194
194
195 def test_result_status(self):
195 def test_result_status(self):
196 pass
196 pass
197 # to be written
197 # to be written
198
198
199 def test_db_query_dt(self):
199 def test_db_query_dt(self):
200 """test db query by date"""
200 """test db query by date"""
201 hist = self.client.hub_history()
201 hist = self.client.hub_history()
202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
202 middle = self.client.db_query({'msg_id' : hist[len(hist)//2]})[0]
203 tic = middle['submitted']
203 tic = middle['submitted']
204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
204 before = self.client.db_query({'submitted' : {'$lt' : tic}})
205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
205 after = self.client.db_query({'submitted' : {'$gte' : tic}})
206 self.assertEquals(len(before)+len(after),len(hist))
206 self.assertEquals(len(before)+len(after),len(hist))
207 for b in before:
207 for b in before:
208 self.assertTrue(b['submitted'] < tic)
208 self.assertTrue(b['submitted'] < tic)
209 for a in after:
209 for a in after:
210 self.assertTrue(a['submitted'] >= tic)
210 self.assertTrue(a['submitted'] >= tic)
211 same = self.client.db_query({'submitted' : tic})
211 same = self.client.db_query({'submitted' : tic})
212 for s in same:
212 for s in same:
213 self.assertTrue(s['submitted'] == tic)
213 self.assertTrue(s['submitted'] == tic)
214
214
215 def test_db_query_keys(self):
215 def test_db_query_keys(self):
216 """test extracting subset of record keys"""
216 """test extracting subset of record keys"""
217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
217 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
218 for rec in found:
218 for rec in found:
219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
219 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
220
220
221 def test_db_query_default_keys(self):
221 def test_db_query_default_keys(self):
222 """default db_query excludes buffers"""
222 """default db_query excludes buffers"""
223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
223 found = self.client.db_query({'msg_id': {'$ne' : ''}})
224 for rec in found:
224 for rec in found:
225 keys = set(rec.keys())
225 keys = set(rec.keys())
226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
226 self.assertFalse('buffers' in keys, "'buffers' should not be in: %s" % keys)
227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
227 self.assertFalse('result_buffers' in keys, "'result_buffers' should not be in: %s" % keys)
228
228
229 def test_db_query_msg_id(self):
229 def test_db_query_msg_id(self):
230 """ensure msg_id is always in db queries"""
230 """ensure msg_id is always in db queries"""
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
231 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
232 for rec in found:
232 for rec in found:
233 self.assertTrue('msg_id' in rec.keys())
233 self.assertTrue('msg_id' in rec.keys())
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
234 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['submitted'])
235 for rec in found:
235 for rec in found:
236 self.assertTrue('msg_id' in rec.keys())
236 self.assertTrue('msg_id' in rec.keys())
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
237 found = self.client.db_query({'msg_id': {'$ne' : ''}},keys=['msg_id'])
238 for rec in found:
238 for rec in found:
239 self.assertTrue('msg_id' in rec.keys())
239 self.assertTrue('msg_id' in rec.keys())
240
240
241 def test_db_query_get_result(self):
241 def test_db_query_get_result(self):
242 """pop in db_query shouldn't pop from result itself"""
242 """pop in db_query shouldn't pop from result itself"""
243 self.client[:].apply_sync(lambda : 1)
243 self.client[:].apply_sync(lambda : 1)
244 found = self.client.db_query({'msg_id': {'$ne' : ''}})
244 found = self.client.db_query({'msg_id': {'$ne' : ''}})
245 rc2 = clientmod.Client(profile='iptest')
245 rc2 = clientmod.Client(profile='iptest')
246 # If this bug is not fixed, this call will hang:
246 # If this bug is not fixed, this call will hang:
247 ar = rc2.get_result(self.client.history[-1])
247 ar = rc2.get_result(self.client.history[-1])
248 ar.wait(2)
248 ar.wait(2)
249 self.assertTrue(ar.ready())
249 self.assertTrue(ar.ready())
250 ar.get()
250 ar.get()
251 rc2.close()
251 rc2.close()
252
252
253 def test_db_query_in(self):
253 def test_db_query_in(self):
254 """test db query with '$in','$nin' operators"""
254 """test db query with '$in','$nin' operators"""
255 hist = self.client.hub_history()
255 hist = self.client.hub_history()
256 even = hist[::2]
256 even = hist[::2]
257 odd = hist[1::2]
257 odd = hist[1::2]
258 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
258 recs = self.client.db_query({ 'msg_id' : {'$in' : even}})
259 found = [ r['msg_id'] for r in recs ]
259 found = [ r['msg_id'] for r in recs ]
260 self.assertEquals(set(even), set(found))
260 self.assertEquals(set(even), set(found))
261 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
261 recs = self.client.db_query({ 'msg_id' : {'$nin' : even}})
262 found = [ r['msg_id'] for r in recs ]
262 found = [ r['msg_id'] for r in recs ]
263 self.assertEquals(set(odd), set(found))
263 self.assertEquals(set(odd), set(found))
264
264
265 def test_hub_history(self):
265 def test_hub_history(self):
266 hist = self.client.hub_history()
266 hist = self.client.hub_history()
267 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
267 recs = self.client.db_query({ 'msg_id' : {"$ne":''}})
268 recdict = {}
268 recdict = {}
269 for rec in recs:
269 for rec in recs:
270 recdict[rec['msg_id']] = rec
270 recdict[rec['msg_id']] = rec
271
271
272 latest = datetime(1984,1,1)
272 latest = datetime(1984,1,1)
273 for msg_id in hist:
273 for msg_id in hist:
274 rec = recdict[msg_id]
274 rec = recdict[msg_id]
275 newt = rec['submitted']
275 newt = rec['submitted']
276 self.assertTrue(newt >= latest)
276 self.assertTrue(newt >= latest)
277 latest = newt
277 latest = newt
278 ar = self.client[-1].apply_async(lambda : 1)
278 ar = self.client[-1].apply_async(lambda : 1)
279 ar.get()
279 ar.get()
280 time.sleep(0.25)
280 time.sleep(0.25)
281 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
281 self.assertEquals(self.client.hub_history()[-1:],ar.msg_ids)
282
282
283 def _wait_for_idle(self):
284 """wait for an engine to become idle, according to the Hub"""
285 rc = self.client
286
287 # timeout 2s, polling every 100ms
288 for i in range(20):
289 qs = rc.queue_status()
290 if qs['unassigned'] or any(qs[eid]['tasks'] for eid in rc.ids):
291 time.sleep(0.1)
292 else:
293 break
294
295 # ensure Hub up to date:
296 qs = rc.queue_status()
297 self.assertEquals(qs['unassigned'], 0)
298 for eid in rc.ids:
299 self.assertEquals(qs[eid]['tasks'], 0)
300
301
283 def test_resubmit(self):
302 def test_resubmit(self):
284 def f():
303 def f():
285 import random
304 import random
286 return random.random()
305 return random.random()
287 v = self.client.load_balanced_view()
306 v = self.client.load_balanced_view()
288 ar = v.apply_async(f)
307 ar = v.apply_async(f)
289 r1 = ar.get(1)
308 r1 = ar.get(1)
290 # give the Hub a chance to notice:
309 # give the Hub a chance to notice:
291 time.sleep(0.5)
310 self._wait_for_idle()
292 ahr = self.client.resubmit(ar.msg_ids)
311 ahr = self.client.resubmit(ar.msg_ids)
293 r2 = ahr.get(1)
312 r2 = ahr.get(1)
294 self.assertFalse(r1 == r2)
313 self.assertFalse(r1 == r2)
295
314
315 def test_resubmit_aborted(self):
316 def f():
317 import random
318 return random.random()
319 v = self.client.load_balanced_view()
320 # restrict to one engine, so we can put a sleep
321 # ahead of the task, so it will get aborted
322 eid = self.client.ids[-1]
323 v.targets = [eid]
324 sleep = v.apply_async(time.sleep, 0.5)
325 ar = v.apply_async(f)
326 ar.abort()
327 self.assertRaises(error.TaskAborted, ar.get)
328 # Give the Hub a chance to get up to date:
329 self._wait_for_idle()
330 ahr = self.client.resubmit(ar.msg_ids)
331 r2 = ahr.get(1)
332
296 def test_resubmit_inflight(self):
333 def test_resubmit_inflight(self):
297 """ensure ValueError on resubmit of inflight task"""
334 """resubmit of inflight task"""
298 v = self.client.load_balanced_view()
335 v = self.client.load_balanced_view()
299 ar = v.apply_async(time.sleep,1)
336 ar = v.apply_async(time.sleep,1)
300 # give the message a chance to arrive
337 # give the message a chance to arrive
301 time.sleep(0.2)
338 time.sleep(0.2)
302 self.assertRaisesRemote(ValueError, self.client.resubmit, ar.msg_ids)
339 ahr = self.client.resubmit(ar.msg_ids)
303 ar.get(2)
340 ar.get(2)
341 ahr.get(2)
304
342
305 def test_resubmit_badkey(self):
343 def test_resubmit_badkey(self):
306 """ensure KeyError on resubmit of nonexistant task"""
344 """ensure KeyError on resubmit of nonexistant task"""
307 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
345 self.assertRaisesRemote(KeyError, self.client.resubmit, ['invalid'])
308
346
309 def test_purge_results(self):
347 def test_purge_results(self):
310 # ensure there are some tasks
348 # ensure there are some tasks
311 for i in range(5):
349 for i in range(5):
312 self.client[:].apply_sync(lambda : 1)
350 self.client[:].apply_sync(lambda : 1)
313 # Wait for the Hub to realise the result is done:
351 # Wait for the Hub to realise the result is done:
314 # This prevents a race condition, where we
352 # This prevents a race condition, where we
315 # might purge a result the Hub still thinks is pending.
353 # might purge a result the Hub still thinks is pending.
316 time.sleep(0.1)
354 time.sleep(0.1)
317 rc2 = clientmod.Client(profile='iptest')
355 rc2 = clientmod.Client(profile='iptest')
318 hist = self.client.hub_history()
356 hist = self.client.hub_history()
319 ahr = rc2.get_result([hist[-1]])
357 ahr = rc2.get_result([hist[-1]])
320 ahr.wait(10)
358 ahr.wait(10)
321 self.client.purge_results(hist[-1])
359 self.client.purge_results(hist[-1])
322 newhist = self.client.hub_history()
360 newhist = self.client.hub_history()
323 self.assertEquals(len(newhist)+1,len(hist))
361 self.assertEquals(len(newhist)+1,len(hist))
324 rc2.spin()
362 rc2.spin()
325 rc2.close()
363 rc2.close()
326
364
327 def test_purge_all_results(self):
365 def test_purge_all_results(self):
328 self.client.purge_results('all')
366 self.client.purge_results('all')
329 hist = self.client.hub_history()
367 hist = self.client.hub_history()
330 self.assertEquals(len(hist), 0)
368 self.assertEquals(len(hist), 0)
331
369
332 def test_spin_thread(self):
370 def test_spin_thread(self):
333 self.client.spin_thread(0.01)
371 self.client.spin_thread(0.01)
334 ar = self.client[-1].apply_async(lambda : 1)
372 ar = self.client[-1].apply_async(lambda : 1)
335 time.sleep(0.1)
373 time.sleep(0.1)
336 self.assertTrue(ar.wall_time < 0.1,
374 self.assertTrue(ar.wall_time < 0.1,
337 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
375 "spin should have kept wall_time < 0.1, but got %f" % ar.wall_time
338 )
376 )
339
377
340 def test_stop_spin_thread(self):
378 def test_stop_spin_thread(self):
341 self.client.spin_thread(0.01)
379 self.client.spin_thread(0.01)
342 self.client.stop_spin_thread()
380 self.client.stop_spin_thread()
343 ar = self.client[-1].apply_async(lambda : 1)
381 ar = self.client[-1].apply_async(lambda : 1)
344 time.sleep(0.15)
382 time.sleep(0.15)
345 self.assertTrue(ar.wall_time > 0.1,
383 self.assertTrue(ar.wall_time > 0.1,
346 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
384 "Shouldn't be spinning, but got wall_time=%f" % ar.wall_time
347 )
385 )
348
386
349
387
@@ -1,240 +1,243 b''
1 """Tests for db backends
1 """Tests for db backends
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7
7
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 from __future__ import division
19 from __future__ import division
20
20
21 import logging
21 import os
22 import os
22 import tempfile
23 import tempfile
23 import time
24 import time
24
25
25 from datetime import datetime, timedelta
26 from datetime import datetime, timedelta
26 from unittest import TestCase
27 from unittest import TestCase
27
28
28 from IPython.parallel import error
29 from IPython.parallel import error
29 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.dictdb import DictDB
30 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.sqlitedb import SQLiteDB
31 from IPython.parallel.controller.hub import init_record, empty_record
32 from IPython.parallel.controller.hub import init_record, empty_record
32
33
33 from IPython.testing import decorators as dec
34 from IPython.testing import decorators as dec
34 from IPython.zmq.session import Session
35 from IPython.zmq.session import Session
35
36
36
37
37 #-------------------------------------------------------------------------------
38 #-------------------------------------------------------------------------------
38 # TestCases
39 # TestCases
39 #-------------------------------------------------------------------------------
40 #-------------------------------------------------------------------------------
40
41
41
42
42 def setup():
43 def setup():
43 global temp_db
44 global temp_db
44 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45 temp_db = tempfile.NamedTemporaryFile(suffix='.db').name
45
46
46
47
47 class TestDictBackend(TestCase):
48 class TestDictBackend(TestCase):
48 def setUp(self):
49 def setUp(self):
49 self.session = Session()
50 self.session = Session()
50 self.db = self.create_db()
51 self.db = self.create_db()
51 self.load_records(16)
52 self.load_records(16)
52
53
53 def create_db(self):
54 def create_db(self):
54 return DictDB()
55 return DictDB()
55
56
56 def load_records(self, n=1):
57 def load_records(self, n=1):
57 """load n records for testing"""
58 """load n records for testing"""
58 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 #sleep 1/10 s, to ensure timestamp is different to previous calls
59 time.sleep(0.1)
60 time.sleep(0.1)
60 msg_ids = []
61 msg_ids = []
61 for i in range(n):
62 for i in range(n):
62 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg = self.session.msg('apply_request', content=dict(a=5))
63 msg['buffers'] = []
64 msg['buffers'] = []
64 rec = init_record(msg)
65 rec = init_record(msg)
65 msg_id = msg['header']['msg_id']
66 msg_id = msg['header']['msg_id']
66 msg_ids.append(msg_id)
67 msg_ids.append(msg_id)
67 self.db.add_record(msg_id, rec)
68 self.db.add_record(msg_id, rec)
68 return msg_ids
69 return msg_ids
69
70
70 def test_add_record(self):
71 def test_add_record(self):
71 before = self.db.get_history()
72 before = self.db.get_history()
72 self.load_records(5)
73 self.load_records(5)
73 after = self.db.get_history()
74 after = self.db.get_history()
74 self.assertEquals(len(after), len(before)+5)
75 self.assertEquals(len(after), len(before)+5)
75 self.assertEquals(after[:-5],before)
76 self.assertEquals(after[:-5],before)
76
77
77 def test_drop_record(self):
78 def test_drop_record(self):
78 msg_id = self.load_records()[-1]
79 msg_id = self.load_records()[-1]
79 rec = self.db.get_record(msg_id)
80 rec = self.db.get_record(msg_id)
80 self.db.drop_record(msg_id)
81 self.db.drop_record(msg_id)
81 self.assertRaises(KeyError,self.db.get_record, msg_id)
82 self.assertRaises(KeyError,self.db.get_record, msg_id)
82
83
83 def _round_to_millisecond(self, dt):
84 def _round_to_millisecond(self, dt):
84 """necessary because mongodb rounds microseconds"""
85 """necessary because mongodb rounds microseconds"""
85 micro = dt.microsecond
86 micro = dt.microsecond
86 extra = int(str(micro)[-3:])
87 extra = int(str(micro)[-3:])
87 return dt - timedelta(microseconds=extra)
88 return dt - timedelta(microseconds=extra)
88
89
89 def test_update_record(self):
90 def test_update_record(self):
90 now = self._round_to_millisecond(datetime.now())
91 now = self._round_to_millisecond(datetime.now())
91 #
92 #
92 msg_id = self.db.get_history()[-1]
93 msg_id = self.db.get_history()[-1]
93 rec1 = self.db.get_record(msg_id)
94 rec1 = self.db.get_record(msg_id)
94 data = {'stdout': 'hello there', 'completed' : now}
95 data = {'stdout': 'hello there', 'completed' : now}
95 self.db.update_record(msg_id, data)
96 self.db.update_record(msg_id, data)
96 rec2 = self.db.get_record(msg_id)
97 rec2 = self.db.get_record(msg_id)
97 self.assertEquals(rec2['stdout'], 'hello there')
98 self.assertEquals(rec2['stdout'], 'hello there')
98 self.assertEquals(rec2['completed'], now)
99 self.assertEquals(rec2['completed'], now)
99 rec1.update(data)
100 rec1.update(data)
100 self.assertEquals(rec1, rec2)
101 self.assertEquals(rec1, rec2)
101
102
102 # def test_update_record_bad(self):
103 # def test_update_record_bad(self):
103 # """test updating nonexistant records"""
104 # """test updating nonexistant records"""
104 # msg_id = str(uuid.uuid4())
105 # msg_id = str(uuid.uuid4())
105 # data = {'stdout': 'hello there'}
106 # data = {'stdout': 'hello there'}
106 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107 # self.assertRaises(KeyError, self.db.update_record, msg_id, data)
107
108
108 def test_find_records_dt(self):
109 def test_find_records_dt(self):
109 """test finding records by date"""
110 """test finding records by date"""
110 hist = self.db.get_history()
111 hist = self.db.get_history()
111 middle = self.db.get_record(hist[len(hist)//2])
112 middle = self.db.get_record(hist[len(hist)//2])
112 tic = middle['submitted']
113 tic = middle['submitted']
113 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 before = self.db.find_records({'submitted' : {'$lt' : tic}})
114 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 after = self.db.find_records({'submitted' : {'$gte' : tic}})
115 self.assertEquals(len(before)+len(after),len(hist))
116 self.assertEquals(len(before)+len(after),len(hist))
116 for b in before:
117 for b in before:
117 self.assertTrue(b['submitted'] < tic)
118 self.assertTrue(b['submitted'] < tic)
118 for a in after:
119 for a in after:
119 self.assertTrue(a['submitted'] >= tic)
120 self.assertTrue(a['submitted'] >= tic)
120 same = self.db.find_records({'submitted' : tic})
121 same = self.db.find_records({'submitted' : tic})
121 for s in same:
122 for s in same:
122 self.assertTrue(s['submitted'] == tic)
123 self.assertTrue(s['submitted'] == tic)
123
124
124 def test_find_records_keys(self):
125 def test_find_records_keys(self):
125 """test extracting subset of record keys"""
126 """test extracting subset of record keys"""
126 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
127 for rec in found:
128 for rec in found:
128 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129 self.assertEquals(set(rec.keys()), set(['msg_id', 'submitted', 'completed']))
129
130
130 def test_find_records_msg_id(self):
131 def test_find_records_msg_id(self):
131 """ensure msg_id is always in found records"""
132 """ensure msg_id is always in found records"""
132 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted', 'completed'])
133 for rec in found:
134 for rec in found:
134 self.assertTrue('msg_id' in rec.keys())
135 self.assertTrue('msg_id' in rec.keys())
135 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['submitted'])
136 for rec in found:
137 for rec in found:
137 self.assertTrue('msg_id' in rec.keys())
138 self.assertTrue('msg_id' in rec.keys())
138 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 found = self.db.find_records({'msg_id': {'$ne' : ''}},keys=['msg_id'])
139 for rec in found:
140 for rec in found:
140 self.assertTrue('msg_id' in rec.keys())
141 self.assertTrue('msg_id' in rec.keys())
141
142
142 def test_find_records_in(self):
143 def test_find_records_in(self):
143 """test finding records with '$in','$nin' operators"""
144 """test finding records with '$in','$nin' operators"""
144 hist = self.db.get_history()
145 hist = self.db.get_history()
145 even = hist[::2]
146 even = hist[::2]
146 odd = hist[1::2]
147 odd = hist[1::2]
147 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 recs = self.db.find_records({ 'msg_id' : {'$in' : even}})
148 found = [ r['msg_id'] for r in recs ]
149 found = [ r['msg_id'] for r in recs ]
149 self.assertEquals(set(even), set(found))
150 self.assertEquals(set(even), set(found))
150 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 recs = self.db.find_records({ 'msg_id' : {'$nin' : even}})
151 found = [ r['msg_id'] for r in recs ]
152 found = [ r['msg_id'] for r in recs ]
152 self.assertEquals(set(odd), set(found))
153 self.assertEquals(set(odd), set(found))
153
154
154 def test_get_history(self):
155 def test_get_history(self):
155 msg_ids = self.db.get_history()
156 msg_ids = self.db.get_history()
156 latest = datetime(1984,1,1)
157 latest = datetime(1984,1,1)
157 for msg_id in msg_ids:
158 for msg_id in msg_ids:
158 rec = self.db.get_record(msg_id)
159 rec = self.db.get_record(msg_id)
159 newt = rec['submitted']
160 newt = rec['submitted']
160 self.assertTrue(newt >= latest)
161 self.assertTrue(newt >= latest)
161 latest = newt
162 latest = newt
162 msg_id = self.load_records(1)[-1]
163 msg_id = self.load_records(1)[-1]
163 self.assertEquals(self.db.get_history()[-1],msg_id)
164 self.assertEquals(self.db.get_history()[-1],msg_id)
164
165
165 def test_datetime(self):
166 def test_datetime(self):
166 """get/set timestamps with datetime objects"""
167 """get/set timestamps with datetime objects"""
167 msg_id = self.db.get_history()[-1]
168 msg_id = self.db.get_history()[-1]
168 rec = self.db.get_record(msg_id)
169 rec = self.db.get_record(msg_id)
169 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.assertTrue(isinstance(rec['submitted'], datetime))
170 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 self.db.update_record(msg_id, dict(completed=datetime.now()))
171 rec = self.db.get_record(msg_id)
172 rec = self.db.get_record(msg_id)
172 self.assertTrue(isinstance(rec['completed'], datetime))
173 self.assertTrue(isinstance(rec['completed'], datetime))
173
174
174 def test_drop_matching(self):
175 def test_drop_matching(self):
175 msg_ids = self.load_records(10)
176 msg_ids = self.load_records(10)
176 query = {'msg_id' : {'$in':msg_ids}}
177 query = {'msg_id' : {'$in':msg_ids}}
177 self.db.drop_matching_records(query)
178 self.db.drop_matching_records(query)
178 recs = self.db.find_records(query)
179 recs = self.db.find_records(query)
179 self.assertEquals(len(recs), 0)
180 self.assertEquals(len(recs), 0)
180
181
181 def test_null(self):
182 def test_null(self):
182 """test None comparison queries"""
183 """test None comparison queries"""
183 msg_ids = self.load_records(10)
184 msg_ids = self.load_records(10)
184
185
185 query = {'msg_id' : None}
186 query = {'msg_id' : None}
186 recs = self.db.find_records(query)
187 recs = self.db.find_records(query)
187 self.assertEquals(len(recs), 0)
188 self.assertEquals(len(recs), 0)
188
189
189 query = {'msg_id' : {'$ne' : None}}
190 query = {'msg_id' : {'$ne' : None}}
190 recs = self.db.find_records(query)
191 recs = self.db.find_records(query)
191 self.assertTrue(len(recs) >= 10)
192 self.assertTrue(len(recs) >= 10)
192
193
193 def test_pop_safe_get(self):
194 def test_pop_safe_get(self):
194 """editing query results shouldn't affect record [get]"""
195 """editing query results shouldn't affect record [get]"""
195 msg_id = self.db.get_history()[-1]
196 msg_id = self.db.get_history()[-1]
196 rec = self.db.get_record(msg_id)
197 rec = self.db.get_record(msg_id)
197 rec.pop('buffers')
198 rec.pop('buffers')
198 rec['garbage'] = 'hello'
199 rec['garbage'] = 'hello'
199 rec2 = self.db.get_record(msg_id)
200 rec2 = self.db.get_record(msg_id)
200 self.assertTrue('buffers' in rec2)
201 self.assertTrue('buffers' in rec2)
201 self.assertFalse('garbage' in rec2)
202 self.assertFalse('garbage' in rec2)
202
203
203 def test_pop_safe_find(self):
204 def test_pop_safe_find(self):
204 """editing query results shouldn't affect record [find]"""
205 """editing query results shouldn't affect record [find]"""
205 msg_id = self.db.get_history()[-1]
206 msg_id = self.db.get_history()[-1]
206 rec = self.db.find_records({'msg_id' : msg_id})[0]
207 rec = self.db.find_records({'msg_id' : msg_id})[0]
207 rec.pop('buffers')
208 rec.pop('buffers')
208 rec['garbage'] = 'hello'
209 rec['garbage'] = 'hello'
209 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
210 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
210 self.assertTrue('buffers' in rec2)
211 self.assertTrue('buffers' in rec2)
211 self.assertFalse('garbage' in rec2)
212 self.assertFalse('garbage' in rec2)
212
213
213 def test_pop_safe_find_keys(self):
214 def test_pop_safe_find_keys(self):
214 """editing query results shouldn't affect record [find+keys]"""
215 """editing query results shouldn't affect record [find+keys]"""
215 msg_id = self.db.get_history()[-1]
216 msg_id = self.db.get_history()[-1]
216 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers'])[0]
217 rec = self.db.find_records({'msg_id' : msg_id}, keys=['buffers'])[0]
217 rec.pop('buffers')
218 rec.pop('buffers')
218 rec['garbage'] = 'hello'
219 rec['garbage'] = 'hello'
219 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
220 rec2 = self.db.find_records({'msg_id' : msg_id})[0]
220 self.assertTrue('buffers' in rec2)
221 self.assertTrue('buffers' in rec2)
221 self.assertFalse('garbage' in rec2)
222 self.assertFalse('garbage' in rec2)
222
223
223
224
224 class TestSQLiteBackend(TestDictBackend):
225 class TestSQLiteBackend(TestDictBackend):
225
226
226 @dec.skip_without('sqlite3')
227 @dec.skip_without('sqlite3')
227 def create_db(self):
228 def create_db(self):
228 location, fname = os.path.split(temp_db)
229 location, fname = os.path.split(temp_db)
229 return SQLiteDB(location=location, fname=fname)
230 log = logging.getLogger('test')
231 log.setLevel(logging.CRITICAL)
232 return SQLiteDB(location=location, fname=fname, log=log)
230
233
231 def tearDown(self):
234 def tearDown(self):
232 self.db._db.close()
235 self.db._db.close()
233
236
234
237
235 def teardown():
238 def teardown():
236 """cleanup task db file after all tests have run"""
239 """cleanup task db file after all tests have run"""
237 try:
240 try:
238 os.remove(temp_db)
241 os.remove(temp_db)
239 except:
242 except:
240 pass
243 pass
@@ -1,553 +1,694 b''
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """test View objects
2 """test View objects
3
3
4 Authors:
4 Authors:
5
5
6 * Min RK
6 * Min RK
7 """
7 """
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9 # Copyright (C) 2011 The IPython Development Team
9 # Copyright (C) 2011 The IPython Development Team
10 #
10 #
11 # Distributed under the terms of the BSD License. The full license is in
11 # Distributed under the terms of the BSD License. The full license is in
12 # the file COPYING, distributed as part of this software.
12 # the file COPYING, distributed as part of this software.
13 #-------------------------------------------------------------------------------
13 #-------------------------------------------------------------------------------
14
14
15 #-------------------------------------------------------------------------------
15 #-------------------------------------------------------------------------------
16 # Imports
16 # Imports
17 #-------------------------------------------------------------------------------
17 #-------------------------------------------------------------------------------
18
18
19 import sys
19 import sys
20 import time
20 import time
21 from tempfile import mktemp
21 from tempfile import mktemp
22 from StringIO import StringIO
22 from StringIO import StringIO
23
23
24 import zmq
24 import zmq
25 from nose import SkipTest
25 from nose import SkipTest
26
26
27 from IPython.testing import decorators as dec
27 from IPython.testing import decorators as dec
28 from IPython.testing.ipunittest import ParametricTestCase
28
29
29 from IPython import parallel as pmod
30 from IPython import parallel as pmod
30 from IPython.parallel import error
31 from IPython.parallel import error
31 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import AsyncResult, AsyncHubResult, AsyncMapResult
32 from IPython.parallel import DirectView
33 from IPython.parallel import DirectView
33 from IPython.parallel.util import interactive
34 from IPython.parallel.util import interactive
34
35
35 from IPython.parallel.tests import add_engines
36 from IPython.parallel.tests import add_engines
36
37
37 from .clienttest import ClusterTestCase, crash, wait, skip_without
38 from .clienttest import ClusterTestCase, crash, wait, skip_without
38
39
39 def setup():
40 def setup():
40 add_engines(3, total=True)
41 add_engines(3, total=True)
41
42
42 class TestView(ClusterTestCase):
43 class TestView(ClusterTestCase, ParametricTestCase):
43
44
44 def test_z_crash_mux(self):
45 def test_z_crash_mux(self):
45 """test graceful handling of engine death (direct)"""
46 """test graceful handling of engine death (direct)"""
46 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 raise SkipTest("crash tests disabled, due to undesirable crash reports")
47 # self.add_engines(1)
48 # self.add_engines(1)
48 eid = self.client.ids[-1]
49 eid = self.client.ids[-1]
49 ar = self.client[eid].apply_async(crash)
50 ar = self.client[eid].apply_async(crash)
50 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 self.assertRaisesRemote(error.EngineError, ar.get, 10)
51 eid = ar.engine_id
52 eid = ar.engine_id
52 tic = time.time()
53 tic = time.time()
53 while eid in self.client.ids and time.time()-tic < 5:
54 while eid in self.client.ids and time.time()-tic < 5:
54 time.sleep(.01)
55 time.sleep(.01)
55 self.client.spin()
56 self.client.spin()
56 self.assertFalse(eid in self.client.ids, "Engine should have died")
57 self.assertFalse(eid in self.client.ids, "Engine should have died")
57
58
58 def test_push_pull(self):
59 def test_push_pull(self):
59 """test pushing and pulling"""
60 """test pushing and pulling"""
60 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'})
61 t = self.client.ids[-1]
62 t = self.client.ids[-1]
62 v = self.client[t]
63 v = self.client[t]
63 push = v.push
64 push = v.push
64 pull = v.pull
65 pull = v.pull
65 v.block=True
66 v.block=True
66 nengines = len(self.client)
67 nengines = len(self.client)
67 push({'data':data})
68 push({'data':data})
68 d = pull('data')
69 d = pull('data')
69 self.assertEquals(d, data)
70 self.assertEquals(d, data)
70 self.client[:].push({'data':data})
71 self.client[:].push({'data':data})
71 d = self.client[:].pull('data', block=True)
72 d = self.client[:].pull('data', block=True)
72 self.assertEquals(d, nengines*[data])
73 self.assertEquals(d, nengines*[data])
73 ar = push({'data':data}, block=False)
74 ar = push({'data':data}, block=False)
74 self.assertTrue(isinstance(ar, AsyncResult))
75 self.assertTrue(isinstance(ar, AsyncResult))
75 r = ar.get()
76 r = ar.get()
76 ar = self.client[:].pull('data', block=False)
77 ar = self.client[:].pull('data', block=False)
77 self.assertTrue(isinstance(ar, AsyncResult))
78 self.assertTrue(isinstance(ar, AsyncResult))
78 r = ar.get()
79 r = ar.get()
79 self.assertEquals(r, nengines*[data])
80 self.assertEquals(r, nengines*[data])
80 self.client[:].push(dict(a=10,b=20))
81 self.client[:].push(dict(a=10,b=20))
81 r = self.client[:].pull(('a','b'), block=True)
82 r = self.client[:].pull(('a','b'), block=True)
82 self.assertEquals(r, nengines*[[10,20]])
83 self.assertEquals(r, nengines*[[10,20]])
83
84
84 def test_push_pull_function(self):
85 def test_push_pull_function(self):
85 "test pushing and pulling functions"
86 "test pushing and pulling functions"
86 def testf(x):
87 def testf(x):
87 return 2.0*x
88 return 2.0*x
88
89
89 t = self.client.ids[-1]
90 t = self.client.ids[-1]
90 v = self.client[t]
91 v = self.client[t]
91 v.block=True
92 v.block=True
92 push = v.push
93 push = v.push
93 pull = v.pull
94 pull = v.pull
94 execute = v.execute
95 execute = v.execute
95 push({'testf':testf})
96 push({'testf':testf})
96 r = pull('testf')
97 r = pull('testf')
97 self.assertEqual(r(1.0), testf(1.0))
98 self.assertEqual(r(1.0), testf(1.0))
98 execute('r = testf(10)')
99 execute('r = testf(10)')
99 r = pull('r')
100 r = pull('r')
100 self.assertEquals(r, testf(10))
101 self.assertEquals(r, testf(10))
101 ar = self.client[:].push({'testf':testf}, block=False)
102 ar = self.client[:].push({'testf':testf}, block=False)
102 ar.get()
103 ar.get()
103 ar = self.client[:].pull('testf', block=False)
104 ar = self.client[:].pull('testf', block=False)
104 rlist = ar.get()
105 rlist = ar.get()
105 for r in rlist:
106 for r in rlist:
106 self.assertEqual(r(1.0), testf(1.0))
107 self.assertEqual(r(1.0), testf(1.0))
107 execute("def g(x): return x*x")
108 execute("def g(x): return x*x")
108 r = pull(('testf','g'))
109 r = pull(('testf','g'))
109 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110 self.assertEquals((r[0](10),r[1](10)), (testf(10), 100))
110
111
111 def test_push_function_globals(self):
112 def test_push_function_globals(self):
112 """test that pushed functions have access to globals"""
113 """test that pushed functions have access to globals"""
113 @interactive
114 @interactive
114 def geta():
115 def geta():
115 return a
116 return a
116 # self.add_engines(1)
117 # self.add_engines(1)
117 v = self.client[-1]
118 v = self.client[-1]
118 v.block=True
119 v.block=True
119 v['f'] = geta
120 v['f'] = geta
120 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 self.assertRaisesRemote(NameError, v.execute, 'b=f()')
121 v.execute('a=5')
122 v.execute('a=5')
122 v.execute('b=f()')
123 v.execute('b=f()')
123 self.assertEquals(v['b'], 5)
124 self.assertEquals(v['b'], 5)
124
125
125 def test_push_function_defaults(self):
126 def test_push_function_defaults(self):
126 """test that pushed functions preserve default args"""
127 """test that pushed functions preserve default args"""
127 def echo(a=10):
128 def echo(a=10):
128 return a
129 return a
129 v = self.client[-1]
130 v = self.client[-1]
130 v.block=True
131 v.block=True
131 v['f'] = echo
132 v['f'] = echo
132 v.execute('b=f()')
133 v.execute('b=f()')
133 self.assertEquals(v['b'], 10)
134 self.assertEquals(v['b'], 10)
134
135
135 def test_get_result(self):
136 def test_get_result(self):
136 """test getting results from the Hub."""
137 """test getting results from the Hub."""
137 c = pmod.Client(profile='iptest')
138 c = pmod.Client(profile='iptest')
138 # self.add_engines(1)
139 # self.add_engines(1)
139 t = c.ids[-1]
140 t = c.ids[-1]
140 v = c[t]
141 v = c[t]
141 v2 = self.client[t]
142 v2 = self.client[t]
142 ar = v.apply_async(wait, 1)
143 ar = v.apply_async(wait, 1)
143 # give the monitor time to notice the message
144 # give the monitor time to notice the message
144 time.sleep(.25)
145 time.sleep(.25)
145 ahr = v2.get_result(ar.msg_ids)
146 ahr = v2.get_result(ar.msg_ids)
146 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertTrue(isinstance(ahr, AsyncHubResult))
147 self.assertEquals(ahr.get(), ar.get())
148 self.assertEquals(ahr.get(), ar.get())
148 ar2 = v2.get_result(ar.msg_ids)
149 ar2 = v2.get_result(ar.msg_ids)
149 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 self.assertFalse(isinstance(ar2, AsyncHubResult))
150 c.spin()
151 c.spin()
151 c.close()
152 c.close()
152
153
153 def test_run_newline(self):
154 def test_run_newline(self):
154 """test that run appends newline to files"""
155 """test that run appends newline to files"""
155 tmpfile = mktemp()
156 tmpfile = mktemp()
156 with open(tmpfile, 'w') as f:
157 with open(tmpfile, 'w') as f:
157 f.write("""def g():
158 f.write("""def g():
158 return 5
159 return 5
159 """)
160 """)
160 v = self.client[-1]
161 v = self.client[-1]
161 v.run(tmpfile, block=True)
162 v.run(tmpfile, block=True)
162 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163 self.assertEquals(v.apply_sync(lambda f: f(), pmod.Reference('g')), 5)
163
164
164 def test_apply_tracked(self):
165 def test_apply_tracked(self):
165 """test tracking for apply"""
166 """test tracking for apply"""
166 # self.add_engines(1)
167 # self.add_engines(1)
167 t = self.client.ids[-1]
168 t = self.client.ids[-1]
168 v = self.client[t]
169 v = self.client[t]
169 v.block=False
170 v.block=False
170 def echo(n=1024*1024, **kwargs):
171 def echo(n=1024*1024, **kwargs):
171 with v.temp_flags(**kwargs):
172 with v.temp_flags(**kwargs):
172 return v.apply(lambda x: x, 'x'*n)
173 return v.apply(lambda x: x, 'x'*n)
173 ar = echo(1, track=False)
174 ar = echo(1, track=False)
174 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
175 self.assertTrue(ar.sent)
176 self.assertTrue(ar.sent)
176 ar = echo(track=True)
177 ar = echo(track=True)
177 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
178 self.assertEquals(ar.sent, ar._tracker.done)
179 self.assertEquals(ar.sent, ar._tracker.done)
179 ar._tracker.wait()
180 ar._tracker.wait()
180 self.assertTrue(ar.sent)
181 self.assertTrue(ar.sent)
181
182
182 def test_push_tracked(self):
183 def test_push_tracked(self):
183 t = self.client.ids[-1]
184 t = self.client.ids[-1]
184 ns = dict(x='x'*1024*1024)
185 ns = dict(x='x'*1024*1024)
185 v = self.client[t]
186 v = self.client[t]
186 ar = v.push(ns, block=False, track=False)
187 ar = v.push(ns, block=False, track=False)
187 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
188 self.assertTrue(ar.sent)
189 self.assertTrue(ar.sent)
189
190
190 ar = v.push(ns, block=False, track=True)
191 ar = v.push(ns, block=False, track=True)
191 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
192 ar._tracker.wait()
193 ar._tracker.wait()
193 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertEquals(ar.sent, ar._tracker.done)
194 self.assertTrue(ar.sent)
195 self.assertTrue(ar.sent)
195 ar.get()
196 ar.get()
196
197
197 def test_scatter_tracked(self):
198 def test_scatter_tracked(self):
198 t = self.client.ids
199 t = self.client.ids
199 x='x'*1024*1024
200 x='x'*1024*1024
200 ar = self.client[t].scatter('x', x, block=False, track=False)
201 ar = self.client[t].scatter('x', x, block=False, track=False)
201 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
202 self.assertTrue(ar.sent)
203 self.assertTrue(ar.sent)
203
204
204 ar = self.client[t].scatter('x', x, block=False, track=True)
205 ar = self.client[t].scatter('x', x, block=False, track=True)
205 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertTrue(isinstance(ar._tracker, zmq.MessageTracker))
206 self.assertEquals(ar.sent, ar._tracker.done)
207 self.assertEquals(ar.sent, ar._tracker.done)
207 ar._tracker.wait()
208 ar._tracker.wait()
208 self.assertTrue(ar.sent)
209 self.assertTrue(ar.sent)
209 ar.get()
210 ar.get()
210
211
211 def test_remote_reference(self):
212 def test_remote_reference(self):
212 v = self.client[-1]
213 v = self.client[-1]
213 v['a'] = 123
214 v['a'] = 123
214 ra = pmod.Reference('a')
215 ra = pmod.Reference('a')
215 b = v.apply_sync(lambda x: x, ra)
216 b = v.apply_sync(lambda x: x, ra)
216 self.assertEquals(b, 123)
217 self.assertEquals(b, 123)
217
218
218
219
219 def test_scatter_gather(self):
220 def test_scatter_gather(self):
220 view = self.client[:]
221 view = self.client[:]
221 seq1 = range(16)
222 seq1 = range(16)
222 view.scatter('a', seq1)
223 view.scatter('a', seq1)
223 seq2 = view.gather('a', block=True)
224 seq2 = view.gather('a', block=True)
224 self.assertEquals(seq2, seq1)
225 self.assertEquals(seq2, seq1)
225 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226 self.assertRaisesRemote(NameError, view.gather, 'asdf', block=True)
226
227
227 @skip_without('numpy')
228 @skip_without('numpy')
228 def test_scatter_gather_numpy(self):
229 def test_scatter_gather_numpy(self):
229 import numpy
230 import numpy
230 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
231 view = self.client[:]
232 view = self.client[:]
232 a = numpy.arange(64)
233 a = numpy.arange(64)
233 view.scatter('a', a)
234 view.scatter('a', a)
234 b = view.gather('a', block=True)
235 b = view.gather('a', block=True)
235 assert_array_equal(b, a)
236 assert_array_equal(b, a)
236
237
237 def test_scatter_gather_lazy(self):
238 def test_scatter_gather_lazy(self):
238 """scatter/gather with targets='all'"""
239 """scatter/gather with targets='all'"""
239 view = self.client.direct_view(targets='all')
240 view = self.client.direct_view(targets='all')
240 x = range(64)
241 x = range(64)
241 view.scatter('x', x)
242 view.scatter('x', x)
242 gathered = view.gather('x', block=True)
243 gathered = view.gather('x', block=True)
243 self.assertEquals(gathered, x)
244 self.assertEquals(gathered, x)
244
245
245
246
246 @dec.known_failure_py3
247 @dec.known_failure_py3
247 @skip_without('numpy')
248 @skip_without('numpy')
248 def test_push_numpy_nocopy(self):
249 def test_push_numpy_nocopy(self):
249 import numpy
250 import numpy
250 view = self.client[:]
251 view = self.client[:]
251 a = numpy.arange(64)
252 a = numpy.arange(64)
252 view['A'] = a
253 view['A'] = a
253 @interactive
254 @interactive
254 def check_writeable(x):
255 def check_writeable(x):
255 return x.flags.writeable
256 return x.flags.writeable
256
257
257 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
258 for flag in view.apply_sync(check_writeable, pmod.Reference('A')):
258 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
259 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
259
260
260 view.push(dict(B=a))
261 view.push(dict(B=a))
261 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
262 for flag in view.apply_sync(check_writeable, pmod.Reference('B')):
262 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
263 self.assertFalse(flag, "array is writeable, push shouldn't have pickled it")
263
264
264 @skip_without('numpy')
265 @skip_without('numpy')
265 def test_apply_numpy(self):
266 def test_apply_numpy(self):
266 """view.apply(f, ndarray)"""
267 """view.apply(f, ndarray)"""
267 import numpy
268 import numpy
268 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
269 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
269
270
270 A = numpy.random.random((100,100))
271 A = numpy.random.random((100,100))
271 view = self.client[-1]
272 view = self.client[-1]
272 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
273 for dt in [ 'int32', 'uint8', 'float32', 'float64' ]:
273 B = A.astype(dt)
274 B = A.astype(dt)
274 C = view.apply_sync(lambda x:x, B)
275 C = view.apply_sync(lambda x:x, B)
275 assert_array_equal(B,C)
276 assert_array_equal(B,C)
276
277
277 def test_map(self):
278 def test_map(self):
278 view = self.client[:]
279 view = self.client[:]
279 def f(x):
280 def f(x):
280 return x**2
281 return x**2
281 data = range(16)
282 data = range(16)
282 r = view.map_sync(f, data)
283 r = view.map_sync(f, data)
283 self.assertEquals(r, map(f, data))
284 self.assertEquals(r, map(f, data))
284
285
285 def test_map_iterable(self):
286 def test_map_iterable(self):
286 """test map on iterables (direct)"""
287 """test map on iterables (direct)"""
287 view = self.client[:]
288 view = self.client[:]
288 # 101 is prime, so it won't be evenly distributed
289 # 101 is prime, so it won't be evenly distributed
289 arr = range(101)
290 arr = range(101)
290 # ensure it will be an iterator, even in Python 3
291 # ensure it will be an iterator, even in Python 3
291 it = iter(arr)
292 it = iter(arr)
292 r = view.map_sync(lambda x:x, arr)
293 r = view.map_sync(lambda x:x, arr)
293 self.assertEquals(r, list(arr))
294 self.assertEquals(r, list(arr))
294
295
295 def test_scatterGatherNonblocking(self):
296 def test_scatterGatherNonblocking(self):
296 data = range(16)
297 data = range(16)
297 view = self.client[:]
298 view = self.client[:]
298 view.scatter('a', data, block=False)
299 view.scatter('a', data, block=False)
299 ar = view.gather('a', block=False)
300 ar = view.gather('a', block=False)
300 self.assertEquals(ar.get(), data)
301 self.assertEquals(ar.get(), data)
301
302
302 @skip_without('numpy')
303 @skip_without('numpy')
303 def test_scatter_gather_numpy_nonblocking(self):
304 def test_scatter_gather_numpy_nonblocking(self):
304 import numpy
305 import numpy
305 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
306 from numpy.testing.utils import assert_array_equal, assert_array_almost_equal
306 a = numpy.arange(64)
307 a = numpy.arange(64)
307 view = self.client[:]
308 view = self.client[:]
308 ar = view.scatter('a', a, block=False)
309 ar = view.scatter('a', a, block=False)
309 self.assertTrue(isinstance(ar, AsyncResult))
310 self.assertTrue(isinstance(ar, AsyncResult))
310 amr = view.gather('a', block=False)
311 amr = view.gather('a', block=False)
311 self.assertTrue(isinstance(amr, AsyncMapResult))
312 self.assertTrue(isinstance(amr, AsyncMapResult))
312 assert_array_equal(amr.get(), a)
313 assert_array_equal(amr.get(), a)
313
314
314 def test_execute(self):
315 def test_execute(self):
315 view = self.client[:]
316 view = self.client[:]
316 # self.client.debug=True
317 # self.client.debug=True
317 execute = view.execute
318 execute = view.execute
318 ar = execute('c=30', block=False)
319 ar = execute('c=30', block=False)
319 self.assertTrue(isinstance(ar, AsyncResult))
320 self.assertTrue(isinstance(ar, AsyncResult))
320 ar = execute('d=[0,1,2]', block=False)
321 ar = execute('d=[0,1,2]', block=False)
321 self.client.wait(ar, 1)
322 self.client.wait(ar, 1)
322 self.assertEquals(len(ar.get()), len(self.client))
323 self.assertEquals(len(ar.get()), len(self.client))
323 for c in view['c']:
324 for c in view['c']:
324 self.assertEquals(c, 30)
325 self.assertEquals(c, 30)
325
326
326 def test_abort(self):
327 def test_abort(self):
327 view = self.client[-1]
328 view = self.client[-1]
328 ar = view.execute('import time; time.sleep(1)', block=False)
329 ar = view.execute('import time; time.sleep(1)', block=False)
329 ar2 = view.apply_async(lambda : 2)
330 ar2 = view.apply_async(lambda : 2)
330 ar3 = view.apply_async(lambda : 3)
331 ar3 = view.apply_async(lambda : 3)
331 view.abort(ar2)
332 view.abort(ar2)
332 view.abort(ar3.msg_ids)
333 view.abort(ar3.msg_ids)
333 self.assertRaises(error.TaskAborted, ar2.get)
334 self.assertRaises(error.TaskAborted, ar2.get)
334 self.assertRaises(error.TaskAborted, ar3.get)
335 self.assertRaises(error.TaskAborted, ar3.get)
335
336
336 def test_abort_all(self):
337 def test_abort_all(self):
337 """view.abort() aborts all outstanding tasks"""
338 """view.abort() aborts all outstanding tasks"""
338 view = self.client[-1]
339 view = self.client[-1]
339 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
340 ars = [ view.apply_async(time.sleep, 0.25) for i in range(10) ]
340 view.abort()
341 view.abort()
341 view.wait(timeout=5)
342 view.wait(timeout=5)
342 for ar in ars[5:]:
343 for ar in ars[5:]:
343 self.assertRaises(error.TaskAborted, ar.get)
344 self.assertRaises(error.TaskAborted, ar.get)
344
345
345 def test_temp_flags(self):
346 def test_temp_flags(self):
346 view = self.client[-1]
347 view = self.client[-1]
347 view.block=True
348 view.block=True
348 with view.temp_flags(block=False):
349 with view.temp_flags(block=False):
349 self.assertFalse(view.block)
350 self.assertFalse(view.block)
350 self.assertTrue(view.block)
351 self.assertTrue(view.block)
351
352
352 @dec.known_failure_py3
353 @dec.known_failure_py3
353 def test_importer(self):
354 def test_importer(self):
354 view = self.client[-1]
355 view = self.client[-1]
355 view.clear(block=True)
356 view.clear(block=True)
356 with view.importer:
357 with view.importer:
357 import re
358 import re
358
359
359 @interactive
360 @interactive
360 def findall(pat, s):
361 def findall(pat, s):
361 # this globals() step isn't necessary in real code
362 # this globals() step isn't necessary in real code
362 # only to prevent a closure in the test
363 # only to prevent a closure in the test
363 re = globals()['re']
364 re = globals()['re']
364 return re.findall(pat, s)
365 return re.findall(pat, s)
365
366
366 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
367 self.assertEquals(view.apply_sync(findall, '\w+', 'hello world'), 'hello world'.split())
367
368
368 # parallel magic tests
369 # parallel magic tests
369
370
370 def test_magic_px_blocking(self):
371 def test_magic_px_blocking(self):
371 ip = get_ipython()
372 ip = get_ipython()
372 v = self.client[-1]
373 v = self.client[-1]
373 v.activate()
374 v.activate()
374 v.block=True
375 v.block=True
375
376
376 ip.magic_px('a=5')
377 ip.magic_px('a=5')
377 self.assertEquals(v['a'], 5)
378 self.assertEquals(v['a'], 5)
378 ip.magic_px('a=10')
379 ip.magic_px('a=10')
379 self.assertEquals(v['a'], 10)
380 self.assertEquals(v['a'], 10)
380 sio = StringIO()
381 sio = StringIO()
381 savestdout = sys.stdout
382 savestdout = sys.stdout
382 sys.stdout = sio
383 sys.stdout = sio
383 # just 'print a' worst ~99% of the time, but this ensures that
384 # just 'print a' worst ~99% of the time, but this ensures that
384 # the stdout message has arrived when the result is finished:
385 # the stdout message has arrived when the result is finished:
385 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
386 ip.magic_px('import sys,time;print (a); sys.stdout.flush();time.sleep(0.2)')
386 sys.stdout = savestdout
387 sys.stdout = savestdout
387 buf = sio.getvalue()
388 buf = sio.getvalue()
388 self.assertTrue('[stdout:' in buf, buf)
389 self.assertTrue('[stdout:' in buf, buf)
389 self.assertTrue(buf.rstrip().endswith('10'))
390 self.assertTrue(buf.rstrip().endswith('10'))
390 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
391 self.assertRaisesRemote(ZeroDivisionError, ip.magic_px, '1/0')
391
392
392 def test_magic_px_nonblocking(self):
393 def test_magic_px_nonblocking(self):
393 ip = get_ipython()
394 ip = get_ipython()
394 v = self.client[-1]
395 v = self.client[-1]
395 v.activate()
396 v.activate()
396 v.block=False
397 v.block=False
397
398
398 ip.magic_px('a=5')
399 ip.magic_px('a=5')
399 self.assertEquals(v['a'], 5)
400 self.assertEquals(v['a'], 5)
400 ip.magic_px('a=10')
401 ip.magic_px('a=10')
401 self.assertEquals(v['a'], 10)
402 self.assertEquals(v['a'], 10)
402 sio = StringIO()
403 sio = StringIO()
403 savestdout = sys.stdout
404 savestdout = sys.stdout
404 sys.stdout = sio
405 sys.stdout = sio
405 ip.magic_px('print a')
406 ip.magic_px('print a')
406 sys.stdout = savestdout
407 sys.stdout = savestdout
407 buf = sio.getvalue()
408 buf = sio.getvalue()
408 self.assertFalse('[stdout:%i]'%v.targets in buf)
409 self.assertFalse('[stdout:%i]'%v.targets in buf)
409 ip.magic_px('1/0')
410 ip.magic_px('1/0')
410 ar = v.get_result(-1)
411 ar = v.get_result(-1)
411 self.assertRaisesRemote(ZeroDivisionError, ar.get)
412 self.assertRaisesRemote(ZeroDivisionError, ar.get)
412
413
413 def test_magic_autopx_blocking(self):
414 def test_magic_autopx_blocking(self):
414 ip = get_ipython()
415 ip = get_ipython()
415 v = self.client[-1]
416 v = self.client[-1]
416 v.activate()
417 v.activate()
417 v.block=True
418 v.block=True
418
419
419 sio = StringIO()
420 sio = StringIO()
420 savestdout = sys.stdout
421 savestdout = sys.stdout
421 sys.stdout = sio
422 sys.stdout = sio
422 ip.magic_autopx()
423 ip.magic_autopx()
423 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
424 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
424 ip.run_cell('print b')
425 ip.run_cell('b*=2')
426 ip.run_cell('print (b)')
425 ip.run_cell("b/c")
427 ip.run_cell("b/c")
426 ip.run_code(compile('b*=2', '', 'single'))
427 ip.magic_autopx()
428 ip.magic_autopx()
428 sys.stdout = savestdout
429 sys.stdout = savestdout
429 output = sio.getvalue().strip()
430 output = sio.getvalue().strip()
430 self.assertTrue(output.startswith('%autopx enabled'))
431 self.assertTrue(output.startswith('%autopx enabled'))
431 self.assertTrue(output.endswith('%autopx disabled'))
432 self.assertTrue(output.endswith('%autopx disabled'))
432 self.assertTrue('RemoteError: ZeroDivisionError' in output)
433 self.assertTrue('RemoteError: ZeroDivisionError' in output)
433 ar = v.get_result(-2)
434 ar = v.get_result(-1)
434 self.assertEquals(v['a'], 5)
435 self.assertEquals(v['a'], 5)
435 self.assertEquals(v['b'], 20)
436 self.assertEquals(v['b'], 20)
436 self.assertRaisesRemote(ZeroDivisionError, ar.get)
437 self.assertRaisesRemote(ZeroDivisionError, ar.get)
437
438
438 def test_magic_autopx_nonblocking(self):
439 def test_magic_autopx_nonblocking(self):
439 ip = get_ipython()
440 ip = get_ipython()
440 v = self.client[-1]
441 v = self.client[-1]
441 v.activate()
442 v.activate()
442 v.block=False
443 v.block=False
443
444
444 sio = StringIO()
445 sio = StringIO()
445 savestdout = sys.stdout
446 savestdout = sys.stdout
446 sys.stdout = sio
447 sys.stdout = sio
447 ip.magic_autopx()
448 ip.magic_autopx()
448 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
449 ip.run_cell('\n'.join(('a=5','b=10','c=0')))
449 ip.run_cell('print b')
450 ip.run_cell('print (b)')
451 ip.run_cell('import time; time.sleep(0.1)')
450 ip.run_cell("b/c")
452 ip.run_cell("b/c")
451 ip.run_code(compile('b*=2', '', 'single'))
453 ip.run_cell('b*=2')
452 ip.magic_autopx()
454 ip.magic_autopx()
453 sys.stdout = savestdout
455 sys.stdout = savestdout
454 output = sio.getvalue().strip()
456 output = sio.getvalue().strip()
455 self.assertTrue(output.startswith('%autopx enabled'))
457 self.assertTrue(output.startswith('%autopx enabled'))
456 self.assertTrue(output.endswith('%autopx disabled'))
458 self.assertTrue(output.endswith('%autopx disabled'))
457 self.assertFalse('ZeroDivisionError' in output)
459 self.assertFalse('ZeroDivisionError' in output)
458 ar = v.get_result(-2)
460 ar = v.get_result(-2)
459 self.assertEquals(v['a'], 5)
460 self.assertEquals(v['b'], 20)
461 self.assertRaisesRemote(ZeroDivisionError, ar.get)
461 self.assertRaisesRemote(ZeroDivisionError, ar.get)
462 # prevent TaskAborted on pulls, due to ZeroDivisionError
463 time.sleep(0.5)
464 self.assertEquals(v['a'], 5)
465 # b*=2 will not fire, due to abort
466 self.assertEquals(v['b'], 10)
462
467
463 def test_magic_result(self):
468 def test_magic_result(self):
464 ip = get_ipython()
469 ip = get_ipython()
465 v = self.client[-1]
470 v = self.client[-1]
466 v.activate()
471 v.activate()
467 v['a'] = 111
472 v['a'] = 111
468 ra = v['a']
473 ra = v['a']
469
474
470 ar = ip.magic_result()
475 ar = ip.magic_result()
471 self.assertEquals(ar.msg_ids, [v.history[-1]])
476 self.assertEquals(ar.msg_ids, [v.history[-1]])
472 self.assertEquals(ar.get(), 111)
477 self.assertEquals(ar.get(), 111)
473 ar = ip.magic_result('-2')
478 ar = ip.magic_result('-2')
474 self.assertEquals(ar.msg_ids, [v.history[-2]])
479 self.assertEquals(ar.msg_ids, [v.history[-2]])
475
480
476 def test_unicode_execute(self):
481 def test_unicode_execute(self):
477 """test executing unicode strings"""
482 """test executing unicode strings"""
478 v = self.client[-1]
483 v = self.client[-1]
479 v.block=True
484 v.block=True
480 if sys.version_info[0] >= 3:
485 if sys.version_info[0] >= 3:
481 code="a='é'"
486 code="a='é'"
482 else:
487 else:
483 code=u"a=u'é'"
488 code=u"a=u'é'"
484 v.execute(code)
489 v.execute(code)
485 self.assertEquals(v['a'], u'é')
490 self.assertEquals(v['a'], u'é')
486
491
487 def test_unicode_apply_result(self):
492 def test_unicode_apply_result(self):
488 """test unicode apply results"""
493 """test unicode apply results"""
489 v = self.client[-1]
494 v = self.client[-1]
490 r = v.apply_sync(lambda : u'é')
495 r = v.apply_sync(lambda : u'é')
491 self.assertEquals(r, u'é')
496 self.assertEquals(r, u'é')
492
497
493 def test_unicode_apply_arg(self):
498 def test_unicode_apply_arg(self):
494 """test passing unicode arguments to apply"""
499 """test passing unicode arguments to apply"""
495 v = self.client[-1]
500 v = self.client[-1]
496
501
497 @interactive
502 @interactive
498 def check_unicode(a, check):
503 def check_unicode(a, check):
499 assert isinstance(a, unicode), "%r is not unicode"%a
504 assert isinstance(a, unicode), "%r is not unicode"%a
500 assert isinstance(check, bytes), "%r is not bytes"%check
505 assert isinstance(check, bytes), "%r is not bytes"%check
501 assert a.encode('utf8') == check, "%s != %s"%(a,check)
506 assert a.encode('utf8') == check, "%s != %s"%(a,check)
502
507
503 for s in [ u'é', u'ßø®∫',u'asdf' ]:
508 for s in [ u'é', u'ßø®∫',u'asdf' ]:
504 try:
509 try:
505 v.apply_sync(check_unicode, s, s.encode('utf8'))
510 v.apply_sync(check_unicode, s, s.encode('utf8'))
506 except error.RemoteError as e:
511 except error.RemoteError as e:
507 if e.ename == 'AssertionError':
512 if e.ename == 'AssertionError':
508 self.fail(e.evalue)
513 self.fail(e.evalue)
509 else:
514 else:
510 raise e
515 raise e
511
516
512 def test_map_reference(self):
517 def test_map_reference(self):
513 """view.map(<Reference>, *seqs) should work"""
518 """view.map(<Reference>, *seqs) should work"""
514 v = self.client[:]
519 v = self.client[:]
515 v.scatter('n', self.client.ids, flatten=True)
520 v.scatter('n', self.client.ids, flatten=True)
516 v.execute("f = lambda x,y: x*y")
521 v.execute("f = lambda x,y: x*y")
517 rf = pmod.Reference('f')
522 rf = pmod.Reference('f')
518 nlist = list(range(10))
523 nlist = list(range(10))
519 mlist = nlist[::-1]
524 mlist = nlist[::-1]
520 expected = [ m*n for m,n in zip(mlist, nlist) ]
525 expected = [ m*n for m,n in zip(mlist, nlist) ]
521 result = v.map_sync(rf, mlist, nlist)
526 result = v.map_sync(rf, mlist, nlist)
522 self.assertEquals(result, expected)
527 self.assertEquals(result, expected)
523
528
524 def test_apply_reference(self):
529 def test_apply_reference(self):
525 """view.apply(<Reference>, *args) should work"""
530 """view.apply(<Reference>, *args) should work"""
526 v = self.client[:]
531 v = self.client[:]
527 v.scatter('n', self.client.ids, flatten=True)
532 v.scatter('n', self.client.ids, flatten=True)
528 v.execute("f = lambda x: n*x")
533 v.execute("f = lambda x: n*x")
529 rf = pmod.Reference('f')
534 rf = pmod.Reference('f')
530 result = v.apply_sync(rf, 5)
535 result = v.apply_sync(rf, 5)
531 expected = [ 5*id for id in self.client.ids ]
536 expected = [ 5*id for id in self.client.ids ]
532 self.assertEquals(result, expected)
537 self.assertEquals(result, expected)
533
538
534 def test_eval_reference(self):
539 def test_eval_reference(self):
535 v = self.client[self.client.ids[0]]
540 v = self.client[self.client.ids[0]]
536 v['g'] = range(5)
541 v['g'] = range(5)
537 rg = pmod.Reference('g[0]')
542 rg = pmod.Reference('g[0]')
538 echo = lambda x:x
543 echo = lambda x:x
539 self.assertEquals(v.apply_sync(echo, rg), 0)
544 self.assertEquals(v.apply_sync(echo, rg), 0)
540
545
541 def test_reference_nameerror(self):
546 def test_reference_nameerror(self):
542 v = self.client[self.client.ids[0]]
547 v = self.client[self.client.ids[0]]
543 r = pmod.Reference('elvis_has_left')
548 r = pmod.Reference('elvis_has_left')
544 echo = lambda x:x
549 echo = lambda x:x
545 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
550 self.assertRaisesRemote(NameError, v.apply_sync, echo, r)
546
551
547 def test_single_engine_map(self):
552 def test_single_engine_map(self):
548 e0 = self.client[self.client.ids[0]]
553 e0 = self.client[self.client.ids[0]]
549 r = range(5)
554 r = range(5)
550 check = [ -1*i for i in r ]
555 check = [ -1*i for i in r ]
551 result = e0.map_sync(lambda x: -1*x, r)
556 result = e0.map_sync(lambda x: -1*x, r)
552 self.assertEquals(result, check)
557 self.assertEquals(result, check)
558
559 def test_len(self):
560 """len(view) makes sense"""
561 e0 = self.client[self.client.ids[0]]
562 yield self.assertEquals(len(e0), 1)
563 v = self.client[:]
564 yield self.assertEquals(len(v), len(self.client.ids))
565 v = self.client.direct_view('all')
566 yield self.assertEquals(len(v), len(self.client.ids))
567 v = self.client[:2]
568 yield self.assertEquals(len(v), 2)
569 v = self.client[:1]
570 yield self.assertEquals(len(v), 1)
571 v = self.client.load_balanced_view()
572 yield self.assertEquals(len(v), len(self.client.ids))
573 # parametric tests seem to require manual closing?
574 self.client.close()
575
576
577 # begin execute tests
578 def _wait_for(self, f, timeout=10):
579 tic = time.time()
580 while time.time() <= tic + timeout:
581 if f():
582 return
583 time.sleep(0.1)
584 self.client.spin()
585 if not f():
586 print "Warning: Awaited condition never arrived"
587
588
589 def test_execute_reply(self):
590 e0 = self.client[self.client.ids[0]]
591 e0.block = True
592 ar = e0.execute("5", silent=False)
593 er = ar.get()
594 self._wait_for(lambda : bool(er.pyout))
595 self.assertEquals(str(er), "<ExecuteReply[%i]: 5>" % er.execution_count)
596 self.assertEquals(er.pyout['data']['text/plain'], '5')
597
598 def test_execute_reply_stdout(self):
599 e0 = self.client[self.client.ids[0]]
600 e0.block = True
601 ar = e0.execute("print (5)", silent=False)
602 er = ar.get()
603 self._wait_for(lambda : bool(er.stdout))
604 self.assertEquals(er.stdout.strip(), '5')
605
606 def test_execute_pyout(self):
607 """execute triggers pyout with silent=False"""
608 view = self.client[:]
609 ar = view.execute("5", silent=False, block=True)
610 self._wait_for(lambda : all(ar.pyout))
611
612 expected = [{'text/plain' : '5'}] * len(view)
613 mimes = [ out['data'] for out in ar.pyout ]
614 self.assertEquals(mimes, expected)
615
616 def test_execute_silent(self):
617 """execute does not trigger pyout with silent=True"""
618 view = self.client[:]
619 ar = view.execute("5", block=True)
620 expected = [None] * len(view)
621 self.assertEquals(ar.pyout, expected)
622
623 def test_execute_magic(self):
624 """execute accepts IPython commands"""
625 view = self.client[:]
626 view.execute("a = 5")
627 ar = view.execute("%whos", block=True)
628 # this will raise, if that failed
629 ar.get(5)
630 self._wait_for(lambda : all(ar.stdout))
631 for stdout in ar.stdout:
632 lines = stdout.splitlines()
633 self.assertEquals(lines[0].split(), ['Variable', 'Type', 'Data/Info'])
634 found = False
635 for line in lines[2:]:
636 split = line.split()
637 if split == ['a', 'int', '5']:
638 found = True
639 break
640 self.assertTrue(found, "whos output wrong: %s" % stdout)
641
642 def test_execute_displaypub(self):
643 """execute tracks display_pub output"""
644 view = self.client[:]
645 view.execute("from IPython.core.display import *")
646 ar = view.execute("[ display(i) for i in range(5) ]", block=True)
647
648 self._wait_for(lambda : all(len(er.outputs) >= 5 for er in ar))
649 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
650 for outputs in ar.outputs:
651 mimes = [ out['data'] for out in outputs ]
652 self.assertEquals(mimes, expected)
653
654 def test_apply_displaypub(self):
655 """apply tracks display_pub output"""
656 view = self.client[:]
657 view.execute("from IPython.core.display import *")
658
659 @interactive
660 def publish():
661 [ display(i) for i in range(5) ]
662
663 ar = view.apply_async(publish)
664 ar.get(5)
665 self._wait_for(lambda : all(len(out) >= 5 for out in ar.outputs))
666 expected = [ {u'text/plain' : unicode(j)} for j in range(5) ]
667 for outputs in ar.outputs:
668 mimes = [ out['data'] for out in outputs ]
669 self.assertEquals(mimes, expected)
670
671 def test_execute_raises(self):
672 """exceptions in execute requests raise appropriately"""
673 view = self.client[-1]
674 ar = view.execute("1/0")
675 self.assertRaisesRemote(ZeroDivisionError, ar.get, 2)
676
677 @dec.skipif_not_matplotlib
678 def test_magic_pylab(self):
679 """%pylab works on engines"""
680 view = self.client[-1]
681 ar = view.execute("%pylab inline")
682 # at least check if this raised:
683 reply = ar.get(5)
684 # include imports, in case user config
685 ar = view.execute("plot(rand(100))", silent=False)
686 reply = ar.get(5)
687 self._wait_for(lambda : all(ar.outputs))
688 self.assertEquals(len(reply.outputs), 1)
689 output = reply.outputs[0]
690 self.assertTrue("data" in output)
691 data = output['data']
692 self.assertTrue("image/png" in data)
693
553
694
@@ -1,495 +1,358 b''
1 """some generic utilities for dealing with classes, urls, and serialization
1 """some generic utilities for dealing with classes, urls, and serialization
2
2
3 Authors:
3 Authors:
4
4
5 * Min RK
5 * Min RK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2010-2011 The IPython Development Team
8 # Copyright (C) 2010-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import logging
19 import logging
20 import os
20 import os
21 import re
21 import re
22 import stat
22 import stat
23 import socket
23 import socket
24 import sys
24 import sys
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 try:
26 try:
27 from signal import SIGKILL
27 from signal import SIGKILL
28 except ImportError:
28 except ImportError:
29 SIGKILL=None
29 SIGKILL=None
30
30
31 try:
31 try:
32 import cPickle
32 import cPickle
33 pickle = cPickle
33 pickle = cPickle
34 except:
34 except:
35 cPickle = None
35 cPickle = None
36 import pickle
36 import pickle
37
37
38 # System library imports
38 # System library imports
39 import zmq
39 import zmq
40 from zmq.log import handlers
40 from zmq.log import handlers
41
41
42 from IPython.external.decorator import decorator
42 from IPython.external.decorator import decorator
43
43
44 # IPython imports
44 # IPython imports
45 from IPython.config.application import Application
45 from IPython.config.application import Application
46 from IPython.utils import py3compat
46 from IPython.utils import py3compat
47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
47 from IPython.utils.pickleutil import can, uncan, canSequence, uncanSequence
48 from IPython.utils.newserialized import serialize, unserialize
48 from IPython.utils.newserialized import serialize, unserialize
49 from IPython.zmq.log import EnginePUBHandler
49 from IPython.zmq.log import EnginePUBHandler
50 from IPython.zmq.serialize import (
51 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
52 )
50
53
51 if py3compat.PY3:
54 if py3compat.PY3:
52 buffer = memoryview
55 buffer = memoryview
53
56
54 #-----------------------------------------------------------------------------
57 #-----------------------------------------------------------------------------
55 # Classes
58 # Classes
56 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
57
60
58 class Namespace(dict):
61 class Namespace(dict):
59 """Subclass of dict for attribute access to keys."""
62 """Subclass of dict for attribute access to keys."""
60
63
61 def __getattr__(self, key):
64 def __getattr__(self, key):
62 """getattr aliased to getitem"""
65 """getattr aliased to getitem"""
63 if key in self.iterkeys():
66 if key in self.iterkeys():
64 return self[key]
67 return self[key]
65 else:
68 else:
66 raise NameError(key)
69 raise NameError(key)
67
70
68 def __setattr__(self, key, value):
71 def __setattr__(self, key, value):
69 """setattr aliased to setitem, with strict"""
72 """setattr aliased to setitem, with strict"""
70 if hasattr(dict, key):
73 if hasattr(dict, key):
71 raise KeyError("Cannot override dict keys %r"%key)
74 raise KeyError("Cannot override dict keys %r"%key)
72 self[key] = value
75 self[key] = value
73
76
74
77
75 class ReverseDict(dict):
78 class ReverseDict(dict):
76 """simple double-keyed subset of dict methods."""
79 """simple double-keyed subset of dict methods."""
77
80
78 def __init__(self, *args, **kwargs):
81 def __init__(self, *args, **kwargs):
79 dict.__init__(self, *args, **kwargs)
82 dict.__init__(self, *args, **kwargs)
80 self._reverse = dict()
83 self._reverse = dict()
81 for key, value in self.iteritems():
84 for key, value in self.iteritems():
82 self._reverse[value] = key
85 self._reverse[value] = key
83
86
84 def __getitem__(self, key):
87 def __getitem__(self, key):
85 try:
88 try:
86 return dict.__getitem__(self, key)
89 return dict.__getitem__(self, key)
87 except KeyError:
90 except KeyError:
88 return self._reverse[key]
91 return self._reverse[key]
89
92
90 def __setitem__(self, key, value):
93 def __setitem__(self, key, value):
91 if key in self._reverse:
94 if key in self._reverse:
92 raise KeyError("Can't have key %r on both sides!"%key)
95 raise KeyError("Can't have key %r on both sides!"%key)
93 dict.__setitem__(self, key, value)
96 dict.__setitem__(self, key, value)
94 self._reverse[value] = key
97 self._reverse[value] = key
95
98
96 def pop(self, key):
99 def pop(self, key):
97 value = dict.pop(self, key)
100 value = dict.pop(self, key)
98 self._reverse.pop(value)
101 self._reverse.pop(value)
99 return value
102 return value
100
103
101 def get(self, key, default=None):
104 def get(self, key, default=None):
102 try:
105 try:
103 return self[key]
106 return self[key]
104 except KeyError:
107 except KeyError:
105 return default
108 return default
106
109
107 #-----------------------------------------------------------------------------
110 #-----------------------------------------------------------------------------
108 # Functions
111 # Functions
109 #-----------------------------------------------------------------------------
112 #-----------------------------------------------------------------------------
110
113
111 @decorator
114 @decorator
112 def log_errors(f, self, *args, **kwargs):
115 def log_errors(f, self, *args, **kwargs):
113 """decorator to log unhandled exceptions raised in a method.
116 """decorator to log unhandled exceptions raised in a method.
114
117
115 For use wrapping on_recv callbacks, so that exceptions
118 For use wrapping on_recv callbacks, so that exceptions
116 do not cause the stream to be closed.
119 do not cause the stream to be closed.
117 """
120 """
118 try:
121 try:
119 return f(self, *args, **kwargs)
122 return f(self, *args, **kwargs)
120 except Exception:
123 except Exception:
121 self.log.error("Uncaught exception in %r" % f, exc_info=True)
124 self.log.error("Uncaught exception in %r" % f, exc_info=True)
122
125
123
126
124 def asbytes(s):
125 """ensure that an object is ascii bytes"""
126 if isinstance(s, unicode):
127 s = s.encode('ascii')
128 return s
129
130 def is_url(url):
127 def is_url(url):
131 """boolean check for whether a string is a zmq url"""
128 """boolean check for whether a string is a zmq url"""
132 if '://' not in url:
129 if '://' not in url:
133 return False
130 return False
134 proto, addr = url.split('://', 1)
131 proto, addr = url.split('://', 1)
135 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
132 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
136 return False
133 return False
137 return True
134 return True
138
135
139 def validate_url(url):
136 def validate_url(url):
140 """validate a url for zeromq"""
137 """validate a url for zeromq"""
141 if not isinstance(url, basestring):
138 if not isinstance(url, basestring):
142 raise TypeError("url must be a string, not %r"%type(url))
139 raise TypeError("url must be a string, not %r"%type(url))
143 url = url.lower()
140 url = url.lower()
144
141
145 proto_addr = url.split('://')
142 proto_addr = url.split('://')
146 assert len(proto_addr) == 2, 'Invalid url: %r'%url
143 assert len(proto_addr) == 2, 'Invalid url: %r'%url
147 proto, addr = proto_addr
144 proto, addr = proto_addr
148 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
145 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
149
146
150 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
147 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
151 # author: Remi Sabourin
148 # author: Remi Sabourin
152 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
149 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
153
150
154 if proto == 'tcp':
151 if proto == 'tcp':
155 lis = addr.split(':')
152 lis = addr.split(':')
156 assert len(lis) == 2, 'Invalid url: %r'%url
153 assert len(lis) == 2, 'Invalid url: %r'%url
157 addr,s_port = lis
154 addr,s_port = lis
158 try:
155 try:
159 port = int(s_port)
156 port = int(s_port)
160 except ValueError:
157 except ValueError:
161 raise AssertionError("Invalid port %r in url: %r"%(port, url))
158 raise AssertionError("Invalid port %r in url: %r"%(port, url))
162
159
163 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
160 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
164
161
165 else:
162 else:
166 # only validate tcp urls currently
163 # only validate tcp urls currently
167 pass
164 pass
168
165
169 return True
166 return True
170
167
171
168
172 def validate_url_container(container):
169 def validate_url_container(container):
173 """validate a potentially nested collection of urls."""
170 """validate a potentially nested collection of urls."""
174 if isinstance(container, basestring):
171 if isinstance(container, basestring):
175 url = container
172 url = container
176 return validate_url(url)
173 return validate_url(url)
177 elif isinstance(container, dict):
174 elif isinstance(container, dict):
178 container = container.itervalues()
175 container = container.itervalues()
179
176
180 for element in container:
177 for element in container:
181 validate_url_container(element)
178 validate_url_container(element)
182
179
183
180
184 def split_url(url):
181 def split_url(url):
185 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
182 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
186 proto_addr = url.split('://')
183 proto_addr = url.split('://')
187 assert len(proto_addr) == 2, 'Invalid url: %r'%url
184 assert len(proto_addr) == 2, 'Invalid url: %r'%url
188 proto, addr = proto_addr
185 proto, addr = proto_addr
189 lis = addr.split(':')
186 lis = addr.split(':')
190 assert len(lis) == 2, 'Invalid url: %r'%url
187 assert len(lis) == 2, 'Invalid url: %r'%url
191 addr,s_port = lis
188 addr,s_port = lis
192 return proto,addr,s_port
189 return proto,addr,s_port
193
190
194 def disambiguate_ip_address(ip, location=None):
191 def disambiguate_ip_address(ip, location=None):
195 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
192 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
196 ones, based on the location (default interpretation of location is localhost)."""
193 ones, based on the location (default interpretation of location is localhost)."""
197 if ip in ('0.0.0.0', '*'):
194 if ip in ('0.0.0.0', '*'):
198 try:
195 try:
199 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
196 external_ips = socket.gethostbyname_ex(socket.gethostname())[2]
200 except (socket.gaierror, IndexError):
197 except (socket.gaierror, IndexError):
201 # couldn't identify this machine, assume localhost
198 # couldn't identify this machine, assume localhost
202 external_ips = []
199 external_ips = []
203 if location is None or location in external_ips or not external_ips:
200 if location is None or location in external_ips or not external_ips:
204 # If location is unspecified or cannot be determined, assume local
201 # If location is unspecified or cannot be determined, assume local
205 ip='127.0.0.1'
202 ip='127.0.0.1'
206 elif location:
203 elif location:
207 return location
204 return location
208 return ip
205 return ip
209
206
210 def disambiguate_url(url, location=None):
207 def disambiguate_url(url, location=None):
211 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
208 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
212 ones, based on the location (default interpretation is localhost).
209 ones, based on the location (default interpretation is localhost).
213
210
214 This is for zeromq urls, such as tcp://*:10101."""
211 This is for zeromq urls, such as tcp://*:10101."""
215 try:
212 try:
216 proto,ip,port = split_url(url)
213 proto,ip,port = split_url(url)
217 except AssertionError:
214 except AssertionError:
218 # probably not tcp url; could be ipc, etc.
215 # probably not tcp url; could be ipc, etc.
219 return url
216 return url
220
217
221 ip = disambiguate_ip_address(ip,location)
218 ip = disambiguate_ip_address(ip,location)
222
219
223 return "%s://%s:%s"%(proto,ip,port)
220 return "%s://%s:%s"%(proto,ip,port)
224
221
225 def serialize_object(obj, threshold=64e-6):
226 """Serialize an object into a list of sendable buffers.
227
228 Parameters
229 ----------
230
231 obj : object
232 The object to be serialized
233 threshold : float
234 The threshold for not double-pickling the content.
235
236
237 Returns
238 -------
239 ('pmd', [bufs]) :
240 where pmd is the pickled metadata wrapper,
241 bufs is a list of data buffers
242 """
243 databuffers = []
244 if isinstance(obj, (list, tuple)):
245 clist = canSequence(obj)
246 slist = map(serialize, clist)
247 for s in slist:
248 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
249 databuffers.append(s.getData())
250 s.data = None
251 return pickle.dumps(slist,-1), databuffers
252 elif isinstance(obj, dict):
253 sobj = {}
254 for k in sorted(obj.iterkeys()):
255 s = serialize(can(obj[k]))
256 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
257 databuffers.append(s.getData())
258 s.data = None
259 sobj[k] = s
260 return pickle.dumps(sobj,-1),databuffers
261 else:
262 s = serialize(can(obj))
263 if s.typeDescriptor in ('buffer', 'ndarray') or s.getDataSize() > threshold:
264 databuffers.append(s.getData())
265 s.data = None
266 return pickle.dumps(s,-1),databuffers
267
268
269 def unserialize_object(bufs):
270 """reconstruct an object serialized by serialize_object from data buffers."""
271 bufs = list(bufs)
272 sobj = pickle.loads(bufs.pop(0))
273 if isinstance(sobj, (list, tuple)):
274 for s in sobj:
275 if s.data is None:
276 s.data = bufs.pop(0)
277 return uncanSequence(map(unserialize, sobj)), bufs
278 elif isinstance(sobj, dict):
279 newobj = {}
280 for k in sorted(sobj.iterkeys()):
281 s = sobj[k]
282 if s.data is None:
283 s.data = bufs.pop(0)
284 newobj[k] = uncan(unserialize(s))
285 return newobj, bufs
286 else:
287 if sobj.data is None:
288 sobj.data = bufs.pop(0)
289 return uncan(unserialize(sobj)), bufs
290
291 def pack_apply_message(f, args, kwargs, threshold=64e-6):
292 """pack up a function, args, and kwargs to be sent over the wire
293 as a series of buffers. Any object whose data is larger than `threshold`
294 will not have their data copied (currently only numpy arrays support zero-copy)"""
295 msg = [pickle.dumps(can(f),-1)]
296 databuffers = [] # for large objects
297 sargs, bufs = serialize_object(args,threshold)
298 msg.append(sargs)
299 databuffers.extend(bufs)
300 skwargs, bufs = serialize_object(kwargs,threshold)
301 msg.append(skwargs)
302 databuffers.extend(bufs)
303 msg.extend(databuffers)
304 return msg
305
306 def unpack_apply_message(bufs, g=None, copy=True):
307 """unpack f,args,kwargs from buffers packed by pack_apply_message()
308 Returns: original f,args,kwargs"""
309 bufs = list(bufs) # allow us to pop
310 assert len(bufs) >= 3, "not enough buffers!"
311 if not copy:
312 for i in range(3):
313 bufs[i] = bufs[i].bytes
314 cf = pickle.loads(bufs.pop(0))
315 sargs = list(pickle.loads(bufs.pop(0)))
316 skwargs = dict(pickle.loads(bufs.pop(0)))
317 # print sargs, skwargs
318 f = uncan(cf, g)
319 for sa in sargs:
320 if sa.data is None:
321 m = bufs.pop(0)
322 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
323 # always use a buffer, until memoryviews get sorted out
324 sa.data = buffer(m)
325 # disable memoryview support
326 # if copy:
327 # sa.data = buffer(m)
328 # else:
329 # sa.data = m.buffer
330 else:
331 if copy:
332 sa.data = m
333 else:
334 sa.data = m.bytes
335
336 args = uncanSequence(map(unserialize, sargs), g)
337 kwargs = {}
338 for k in sorted(skwargs.iterkeys()):
339 sa = skwargs[k]
340 if sa.data is None:
341 m = bufs.pop(0)
342 if sa.getTypeDescriptor() in ('buffer', 'ndarray'):
343 # always use a buffer, until memoryviews get sorted out
344 sa.data = buffer(m)
345 # disable memoryview support
346 # if copy:
347 # sa.data = buffer(m)
348 # else:
349 # sa.data = m.buffer
350 else:
351 if copy:
352 sa.data = m
353 else:
354 sa.data = m.bytes
355
356 kwargs[k] = uncan(unserialize(sa), g)
357
358 return f,args,kwargs
359
222
360 #--------------------------------------------------------------------------
223 #--------------------------------------------------------------------------
361 # helpers for implementing old MEC API via view.apply
224 # helpers for implementing old MEC API via view.apply
362 #--------------------------------------------------------------------------
225 #--------------------------------------------------------------------------
363
226
364 def interactive(f):
227 def interactive(f):
365 """decorator for making functions appear as interactively defined.
228 """decorator for making functions appear as interactively defined.
366 This results in the function being linked to the user_ns as globals()
229 This results in the function being linked to the user_ns as globals()
367 instead of the module globals().
230 instead of the module globals().
368 """
231 """
369 f.__module__ = '__main__'
232 f.__module__ = '__main__'
370 return f
233 return f
371
234
372 @interactive
235 @interactive
373 def _push(**ns):
236 def _push(**ns):
374 """helper method for implementing `client.push` via `client.apply`"""
237 """helper method for implementing `client.push` via `client.apply`"""
375 globals().update(ns)
238 globals().update(ns)
376
239
377 @interactive
240 @interactive
378 def _pull(keys):
241 def _pull(keys):
379 """helper method for implementing `client.pull` via `client.apply`"""
242 """helper method for implementing `client.pull` via `client.apply`"""
380 user_ns = globals()
243 user_ns = globals()
381 if isinstance(keys, (list,tuple, set)):
244 if isinstance(keys, (list,tuple, set)):
382 for key in keys:
245 for key in keys:
383 if not user_ns.has_key(key):
246 if not user_ns.has_key(key):
384 raise NameError("name '%s' is not defined"%key)
247 raise NameError("name '%s' is not defined"%key)
385 return map(user_ns.get, keys)
248 return map(user_ns.get, keys)
386 else:
249 else:
387 if not user_ns.has_key(keys):
250 if not user_ns.has_key(keys):
388 raise NameError("name '%s' is not defined"%keys)
251 raise NameError("name '%s' is not defined"%keys)
389 return user_ns.get(keys)
252 return user_ns.get(keys)
390
253
391 @interactive
254 @interactive
392 def _execute(code):
255 def _execute(code):
393 """helper method for implementing `client.execute` via `client.apply`"""
256 """helper method for implementing `client.execute` via `client.apply`"""
394 exec code in globals()
257 exec code in globals()
395
258
396 #--------------------------------------------------------------------------
259 #--------------------------------------------------------------------------
397 # extra process management utilities
260 # extra process management utilities
398 #--------------------------------------------------------------------------
261 #--------------------------------------------------------------------------
399
262
400 _random_ports = set()
263 _random_ports = set()
401
264
402 def select_random_ports(n):
265 def select_random_ports(n):
403 """Selects and return n random ports that are available."""
266 """Selects and return n random ports that are available."""
404 ports = []
267 ports = []
405 for i in xrange(n):
268 for i in xrange(n):
406 sock = socket.socket()
269 sock = socket.socket()
407 sock.bind(('', 0))
270 sock.bind(('', 0))
408 while sock.getsockname()[1] in _random_ports:
271 while sock.getsockname()[1] in _random_ports:
409 sock.close()
272 sock.close()
410 sock = socket.socket()
273 sock = socket.socket()
411 sock.bind(('', 0))
274 sock.bind(('', 0))
412 ports.append(sock)
275 ports.append(sock)
413 for i, sock in enumerate(ports):
276 for i, sock in enumerate(ports):
414 port = sock.getsockname()[1]
277 port = sock.getsockname()[1]
415 sock.close()
278 sock.close()
416 ports[i] = port
279 ports[i] = port
417 _random_ports.add(port)
280 _random_ports.add(port)
418 return ports
281 return ports
419
282
420 def signal_children(children):
283 def signal_children(children):
421 """Relay interupt/term signals to children, for more solid process cleanup."""
284 """Relay interupt/term signals to children, for more solid process cleanup."""
422 def terminate_children(sig, frame):
285 def terminate_children(sig, frame):
423 log = Application.instance().log
286 log = Application.instance().log
424 log.critical("Got signal %i, terminating children..."%sig)
287 log.critical("Got signal %i, terminating children..."%sig)
425 for child in children:
288 for child in children:
426 child.terminate()
289 child.terminate()
427
290
428 sys.exit(sig != SIGINT)
291 sys.exit(sig != SIGINT)
429 # sys.exit(sig)
292 # sys.exit(sig)
430 for sig in (SIGINT, SIGABRT, SIGTERM):
293 for sig in (SIGINT, SIGABRT, SIGTERM):
431 signal(sig, terminate_children)
294 signal(sig, terminate_children)
432
295
433 def generate_exec_key(keyfile):
296 def generate_exec_key(keyfile):
434 import uuid
297 import uuid
435 newkey = str(uuid.uuid4())
298 newkey = str(uuid.uuid4())
436 with open(keyfile, 'w') as f:
299 with open(keyfile, 'w') as f:
437 # f.write('ipython-key ')
300 # f.write('ipython-key ')
438 f.write(newkey+'\n')
301 f.write(newkey+'\n')
439 # set user-only RW permissions (0600)
302 # set user-only RW permissions (0600)
440 # this will have no effect on Windows
303 # this will have no effect on Windows
441 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
304 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
442
305
443
306
444 def integer_loglevel(loglevel):
307 def integer_loglevel(loglevel):
445 try:
308 try:
446 loglevel = int(loglevel)
309 loglevel = int(loglevel)
447 except ValueError:
310 except ValueError:
448 if isinstance(loglevel, str):
311 if isinstance(loglevel, str):
449 loglevel = getattr(logging, loglevel)
312 loglevel = getattr(logging, loglevel)
450 return loglevel
313 return loglevel
451
314
452 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
315 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
453 logger = logging.getLogger(logname)
316 logger = logging.getLogger(logname)
454 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
317 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
455 # don't add a second PUBHandler
318 # don't add a second PUBHandler
456 return
319 return
457 loglevel = integer_loglevel(loglevel)
320 loglevel = integer_loglevel(loglevel)
458 lsock = context.socket(zmq.PUB)
321 lsock = context.socket(zmq.PUB)
459 lsock.connect(iface)
322 lsock.connect(iface)
460 handler = handlers.PUBHandler(lsock)
323 handler = handlers.PUBHandler(lsock)
461 handler.setLevel(loglevel)
324 handler.setLevel(loglevel)
462 handler.root_topic = root
325 handler.root_topic = root
463 logger.addHandler(handler)
326 logger.addHandler(handler)
464 logger.setLevel(loglevel)
327 logger.setLevel(loglevel)
465
328
466 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
329 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
467 logger = logging.getLogger()
330 logger = logging.getLogger()
468 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
331 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
469 # don't add a second PUBHandler
332 # don't add a second PUBHandler
470 return
333 return
471 loglevel = integer_loglevel(loglevel)
334 loglevel = integer_loglevel(loglevel)
472 lsock = context.socket(zmq.PUB)
335 lsock = context.socket(zmq.PUB)
473 lsock.connect(iface)
336 lsock.connect(iface)
474 handler = EnginePUBHandler(engine, lsock)
337 handler = EnginePUBHandler(engine, lsock)
475 handler.setLevel(loglevel)
338 handler.setLevel(loglevel)
476 logger.addHandler(handler)
339 logger.addHandler(handler)
477 logger.setLevel(loglevel)
340 logger.setLevel(loglevel)
478 return logger
341 return logger
479
342
480 def local_logger(logname, loglevel=logging.DEBUG):
343 def local_logger(logname, loglevel=logging.DEBUG):
481 loglevel = integer_loglevel(loglevel)
344 loglevel = integer_loglevel(loglevel)
482 logger = logging.getLogger(logname)
345 logger = logging.getLogger(logname)
483 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
346 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
484 # don't add a second StreamHandler
347 # don't add a second StreamHandler
485 return
348 return
486 handler = logging.StreamHandler()
349 handler = logging.StreamHandler()
487 handler.setLevel(loglevel)
350 handler.setLevel(loglevel)
488 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
351 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
489 datefmt="%Y-%m-%d %H:%M:%S")
352 datefmt="%Y-%m-%d %H:%M:%S")
490 handler.setFormatter(formatter)
353 handler.setFormatter(formatter)
491
354
492 logger.addHandler(handler)
355 logger.addHandler(handler)
493 logger.setLevel(loglevel)
356 logger.setLevel(loglevel)
494 return logger
357 return logger
495
358
@@ -1,72 +1,73 b''
1 import __builtin__
1 import __builtin__
2 import sys
2 import sys
3 from base64 import encodestring
3 from base64 import encodestring
4
4
5 from IPython.core.displayhook import DisplayHook
5 from IPython.core.displayhook import DisplayHook
6 from IPython.utils.traitlets import Instance, Dict
6 from IPython.utils.traitlets import Instance, Dict
7 from session import extract_header, Session
7 from session import extract_header, Session
8
8
9 class ZMQDisplayHook(object):
9 class ZMQDisplayHook(object):
10 """A simple displayhook that publishes the object's repr over a ZeroMQ
10 """A simple displayhook that publishes the object's repr over a ZeroMQ
11 socket."""
11 socket."""
12 topic=None
12 topic=None
13
13
14 def __init__(self, session, pub_socket):
14 def __init__(self, session, pub_socket):
15 self.session = session
15 self.session = session
16 self.pub_socket = pub_socket
16 self.pub_socket = pub_socket
17 self.parent_header = {}
17 self.parent_header = {}
18
18
19 def __call__(self, obj):
19 def __call__(self, obj):
20 if obj is None:
20 if obj is None:
21 return
21 return
22
22
23 __builtin__._ = obj
23 __builtin__._ = obj
24 sys.stdout.flush()
24 sys.stdout.flush()
25 sys.stderr.flush()
25 sys.stderr.flush()
26 msg = self.session.send(self.pub_socket, u'pyout', {u'data':repr(obj)},
26 msg = self.session.send(self.pub_socket, u'pyout', {u'data':repr(obj)},
27 parent=self.parent_header, ident=self.topic)
27 parent=self.parent_header, ident=self.topic)
28
28
29 def set_parent(self, parent):
29 def set_parent(self, parent):
30 self.parent_header = extract_header(parent)
30 self.parent_header = extract_header(parent)
31
31
32
32
33 def _encode_binary(format_dict):
33 def _encode_binary(format_dict):
34 pngdata = format_dict.get('image/png')
34 pngdata = format_dict.get('image/png')
35 if pngdata is not None:
35 if pngdata is not None:
36 format_dict['image/png'] = encodestring(pngdata).decode('ascii')
36 format_dict['image/png'] = encodestring(pngdata).decode('ascii')
37 jpegdata = format_dict.get('image/jpeg')
37 jpegdata = format_dict.get('image/jpeg')
38 if jpegdata is not None:
38 if jpegdata is not None:
39 format_dict['image/jpeg'] = encodestring(jpegdata).decode('ascii')
39 format_dict['image/jpeg'] = encodestring(jpegdata).decode('ascii')
40
40
41
41
42 class ZMQShellDisplayHook(DisplayHook):
42 class ZMQShellDisplayHook(DisplayHook):
43 """A displayhook subclass that publishes data using ZeroMQ. This is intended
43 """A displayhook subclass that publishes data using ZeroMQ. This is intended
44 to work with an InteractiveShell instance. It sends a dict of different
44 to work with an InteractiveShell instance. It sends a dict of different
45 representations of the object."""
45 representations of the object."""
46 topic=None
46
47
47 session = Instance(Session)
48 session = Instance(Session)
48 pub_socket = Instance('zmq.Socket')
49 pub_socket = Instance('zmq.Socket')
49 parent_header = Dict({})
50 parent_header = Dict({})
50
51
51 def set_parent(self, parent):
52 def set_parent(self, parent):
52 """Set the parent for outbound messages."""
53 """Set the parent for outbound messages."""
53 self.parent_header = extract_header(parent)
54 self.parent_header = extract_header(parent)
54
55
55 def start_displayhook(self):
56 def start_displayhook(self):
56 self.msg = self.session.msg(u'pyout', {}, parent=self.parent_header)
57 self.msg = self.session.msg(u'pyout', {}, parent=self.parent_header)
57
58
58 def write_output_prompt(self):
59 def write_output_prompt(self):
59 """Write the output prompt."""
60 """Write the output prompt."""
60 self.msg['content']['execution_count'] = self.prompt_count
61 self.msg['content']['execution_count'] = self.prompt_count
61
62
62 def write_format_data(self, format_dict):
63 def write_format_data(self, format_dict):
63 _encode_binary(format_dict)
64 _encode_binary(format_dict)
64 self.msg['content']['data'] = format_dict
65 self.msg['content']['data'] = format_dict
65
66
66 def finish_displayhook(self):
67 def finish_displayhook(self):
67 """Finish up all displayhook activities."""
68 """Finish up all displayhook activities."""
68 sys.stdout.flush()
69 sys.stdout.flush()
69 sys.stderr.flush()
70 sys.stderr.flush()
70 self.session.send(self.pub_socket, self.msg)
71 self.session.send(self.pub_socket, self.msg, ident=self.topic)
71 self.msg = None
72 self.msg = None
72
73
@@ -1,214 +1,223 b''
1 # encoding: utf-8
1 # encoding: utf-8
2 """Event loop integration for the ZeroMQ-based kernels.
2 """Event loop integration for the ZeroMQ-based kernels.
3 """
3 """
4
4
5 #-----------------------------------------------------------------------------
5 #-----------------------------------------------------------------------------
6 # Copyright (C) 2011 The IPython Development Team
6 # Copyright (C) 2011 The IPython Development Team
7
7
8 # Distributed under the terms of the BSD License. The full license is in
8 # Distributed under the terms of the BSD License. The full license is in
9 # the file COPYING, distributed as part of this software.
9 # the file COPYING, distributed as part of this software.
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11
11
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16
16
17 import sys
17 import sys
18
18
19 # System library imports.
19 # System library imports.
20 import zmq
20 import zmq
21
21
22 # Local imports.
22 # Local imports.
23 from IPython.config.application import Application
23 from IPython.utils import io
24 from IPython.utils import io
24
25
26
25 #------------------------------------------------------------------------------
27 #------------------------------------------------------------------------------
26 # Eventloops for integrating the Kernel into different GUIs
28 # Eventloops for integrating the Kernel into different GUIs
27 #------------------------------------------------------------------------------
29 #------------------------------------------------------------------------------
28
30
29 def loop_qt4(kernel):
31 def loop_qt4(kernel):
30 """Start a kernel with PyQt4 event loop integration."""
32 """Start a kernel with PyQt4 event loop integration."""
31
33
32 from IPython.external.qt_for_kernel import QtCore
34 from IPython.external.qt_for_kernel import QtCore
33 from IPython.lib.guisupport import get_app_qt4, start_event_loop_qt4
35 from IPython.lib.guisupport import get_app_qt4, start_event_loop_qt4
34
36
35 kernel.app = get_app_qt4([" "])
37 kernel.app = get_app_qt4([" "])
36 kernel.app.setQuitOnLastWindowClosed(False)
38 kernel.app.setQuitOnLastWindowClosed(False)
37 kernel.timer = QtCore.QTimer()
39 kernel.timer = QtCore.QTimer()
38 kernel.timer.timeout.connect(kernel.do_one_iteration)
40 kernel.timer.timeout.connect(kernel.do_one_iteration)
39 # Units for the timer are in milliseconds
41 # Units for the timer are in milliseconds
40 kernel.timer.start(1000*kernel._poll_interval)
42 kernel.timer.start(1000*kernel._poll_interval)
41 start_event_loop_qt4(kernel.app)
43 start_event_loop_qt4(kernel.app)
42
44
43
45
44 def loop_wx(kernel):
46 def loop_wx(kernel):
45 """Start a kernel with wx event loop support."""
47 """Start a kernel with wx event loop support."""
46
48
47 import wx
49 import wx
48 from IPython.lib.guisupport import start_event_loop_wx
50 from IPython.lib.guisupport import start_event_loop_wx
49
51
50 doi = kernel.do_one_iteration
52 doi = kernel.do_one_iteration
51 # Wx uses milliseconds
53 # Wx uses milliseconds
52 poll_interval = int(1000*kernel._poll_interval)
54 poll_interval = int(1000*kernel._poll_interval)
53
55
54 # We have to put the wx.Timer in a wx.Frame for it to fire properly.
56 # We have to put the wx.Timer in a wx.Frame for it to fire properly.
55 # We make the Frame hidden when we create it in the main app below.
57 # We make the Frame hidden when we create it in the main app below.
56 class TimerFrame(wx.Frame):
58 class TimerFrame(wx.Frame):
57 def __init__(self, func):
59 def __init__(self, func):
58 wx.Frame.__init__(self, None, -1)
60 wx.Frame.__init__(self, None, -1)
59 self.timer = wx.Timer(self)
61 self.timer = wx.Timer(self)
60 # Units for the timer are in milliseconds
62 # Units for the timer are in milliseconds
61 self.timer.Start(poll_interval)
63 self.timer.Start(poll_interval)
62 self.Bind(wx.EVT_TIMER, self.on_timer)
64 self.Bind(wx.EVT_TIMER, self.on_timer)
63 self.func = func
65 self.func = func
64
66
65 def on_timer(self, event):
67 def on_timer(self, event):
66 self.func()
68 self.func()
67
69
68 # We need a custom wx.App to create our Frame subclass that has the
70 # We need a custom wx.App to create our Frame subclass that has the
69 # wx.Timer to drive the ZMQ event loop.
71 # wx.Timer to drive the ZMQ event loop.
70 class IPWxApp(wx.App):
72 class IPWxApp(wx.App):
71 def OnInit(self):
73 def OnInit(self):
72 self.frame = TimerFrame(doi)
74 self.frame = TimerFrame(doi)
73 self.frame.Show(False)
75 self.frame.Show(False)
74 return True
76 return True
75
77
76 # The redirect=False here makes sure that wx doesn't replace
78 # The redirect=False here makes sure that wx doesn't replace
77 # sys.stdout/stderr with its own classes.
79 # sys.stdout/stderr with its own classes.
78 kernel.app = IPWxApp(redirect=False)
80 kernel.app = IPWxApp(redirect=False)
79
81
80 # The import of wx on Linux sets the handler for signal.SIGINT
82 # The import of wx on Linux sets the handler for signal.SIGINT
81 # to 0. This is a bug in wx or gtk. We fix by just setting it
83 # to 0. This is a bug in wx or gtk. We fix by just setting it
82 # back to the Python default.
84 # back to the Python default.
83 import signal
85 import signal
84 if not callable(signal.getsignal(signal.SIGINT)):
86 if not callable(signal.getsignal(signal.SIGINT)):
85 signal.signal(signal.SIGINT, signal.default_int_handler)
87 signal.signal(signal.SIGINT, signal.default_int_handler)
86
88
87 start_event_loop_wx(kernel.app)
89 start_event_loop_wx(kernel.app)
88
90
89
91
90 def loop_tk(kernel):
92 def loop_tk(kernel):
91 """Start a kernel with the Tk event loop."""
93 """Start a kernel with the Tk event loop."""
92
94
93 import Tkinter
95 import Tkinter
94 doi = kernel.do_one_iteration
96 doi = kernel.do_one_iteration
95 # Tk uses milliseconds
97 # Tk uses milliseconds
96 poll_interval = int(1000*kernel._poll_interval)
98 poll_interval = int(1000*kernel._poll_interval)
97 # For Tkinter, we create a Tk object and call its withdraw method.
99 # For Tkinter, we create a Tk object and call its withdraw method.
98 class Timer(object):
100 class Timer(object):
99 def __init__(self, func):
101 def __init__(self, func):
100 self.app = Tkinter.Tk()
102 self.app = Tkinter.Tk()
101 self.app.withdraw()
103 self.app.withdraw()
102 self.func = func
104 self.func = func
103
105
104 def on_timer(self):
106 def on_timer(self):
105 self.func()
107 self.func()
106 self.app.after(poll_interval, self.on_timer)
108 self.app.after(poll_interval, self.on_timer)
107
109
108 def start(self):
110 def start(self):
109 self.on_timer() # Call it once to get things going.
111 self.on_timer() # Call it once to get things going.
110 self.app.mainloop()
112 self.app.mainloop()
111
113
112 kernel.timer = Timer(doi)
114 kernel.timer = Timer(doi)
113 kernel.timer.start()
115 kernel.timer.start()
114
116
115
117
116 def loop_gtk(kernel):
118 def loop_gtk(kernel):
117 """Start the kernel, coordinating with the GTK event loop"""
119 """Start the kernel, coordinating with the GTK event loop"""
118 from .gui.gtkembed import GTKEmbed
120 from .gui.gtkembed import GTKEmbed
119
121
120 gtk_kernel = GTKEmbed(kernel)
122 gtk_kernel = GTKEmbed(kernel)
121 gtk_kernel.start()
123 gtk_kernel.start()
122
124
123
125
124 def loop_cocoa(kernel):
126 def loop_cocoa(kernel):
125 """Start the kernel, coordinating with the Cocoa CFRunLoop event loop
127 """Start the kernel, coordinating with the Cocoa CFRunLoop event loop
126 via the matplotlib MacOSX backend.
128 via the matplotlib MacOSX backend.
127 """
129 """
128 import matplotlib
130 import matplotlib
129 if matplotlib.__version__ < '1.1.0':
131 if matplotlib.__version__ < '1.1.0':
130 kernel.log.warn(
132 kernel.log.warn(
131 "MacOSX backend in matplotlib %s doesn't have a Timer, "
133 "MacOSX backend in matplotlib %s doesn't have a Timer, "
132 "falling back on Tk for CFRunLoop integration. Note that "
134 "falling back on Tk for CFRunLoop integration. Note that "
133 "even this won't work if Tk is linked against X11 instead of "
135 "even this won't work if Tk is linked against X11 instead of "
134 "Cocoa (e.g. EPD). To use the MacOSX backend in the kernel, "
136 "Cocoa (e.g. EPD). To use the MacOSX backend in the kernel, "
135 "you must use matplotlib >= 1.1.0, or a native libtk."
137 "you must use matplotlib >= 1.1.0, or a native libtk."
136 )
138 )
137 return loop_tk(kernel)
139 return loop_tk(kernel)
138
140
139 from matplotlib.backends.backend_macosx import TimerMac, show
141 from matplotlib.backends.backend_macosx import TimerMac, show
140
142
141 # scale interval for sec->ms
143 # scale interval for sec->ms
142 poll_interval = int(1000*kernel._poll_interval)
144 poll_interval = int(1000*kernel._poll_interval)
143
145
144 real_excepthook = sys.excepthook
146 real_excepthook = sys.excepthook
145 def handle_int(etype, value, tb):
147 def handle_int(etype, value, tb):
146 """don't let KeyboardInterrupts look like crashes"""
148 """don't let KeyboardInterrupts look like crashes"""
147 if etype is KeyboardInterrupt:
149 if etype is KeyboardInterrupt:
148 io.raw_print("KeyboardInterrupt caught in CFRunLoop")
150 io.raw_print("KeyboardInterrupt caught in CFRunLoop")
149 else:
151 else:
150 real_excepthook(etype, value, tb)
152 real_excepthook(etype, value, tb)
151
153
152 # add doi() as a Timer to the CFRunLoop
154 # add doi() as a Timer to the CFRunLoop
153 def doi():
155 def doi():
154 # restore excepthook during IPython code
156 # restore excepthook during IPython code
155 sys.excepthook = real_excepthook
157 sys.excepthook = real_excepthook
156 kernel.do_one_iteration()
158 kernel.do_one_iteration()
157 # and back:
159 # and back:
158 sys.excepthook = handle_int
160 sys.excepthook = handle_int
159
161
160 t = TimerMac(poll_interval)
162 t = TimerMac(poll_interval)
161 t.add_callback(doi)
163 t.add_callback(doi)
162 t.start()
164 t.start()
163
165
164 # but still need a Poller for when there are no active windows,
166 # but still need a Poller for when there are no active windows,
165 # during which time mainloop() returns immediately
167 # during which time mainloop() returns immediately
166 poller = zmq.Poller()
168 poller = zmq.Poller()
167 poller.register(kernel.shell_socket, zmq.POLLIN)
169 if kernel.control_stream:
170 poller.register(kernel.control_stream.socket, zmq.POLLIN)
171 for stream in kernel.shell_streams:
172 poller.register(stream.socket, zmq.POLLIN)
168
173
169 while True:
174 while True:
170 try:
175 try:
171 # double nested try/except, to properly catch KeyboardInterrupt
176 # double nested try/except, to properly catch KeyboardInterrupt
172 # due to pyzmq Issue #130
177 # due to pyzmq Issue #130
173 try:
178 try:
174 # don't let interrupts during mainloop invoke crash_handler:
179 # don't let interrupts during mainloop invoke crash_handler:
175 sys.excepthook = handle_int
180 sys.excepthook = handle_int
176 show.mainloop()
181 show.mainloop()
177 sys.excepthook = real_excepthook
182 sys.excepthook = real_excepthook
178 # use poller if mainloop returned (no windows)
183 # use poller if mainloop returned (no windows)
179 # scale by extra factor of 10, since it's a real poll
184 # scale by extra factor of 10, since it's a real poll
180 poller.poll(10*poll_interval)
185 poller.poll(10*poll_interval)
181 kernel.do_one_iteration()
186 kernel.do_one_iteration()
182 except:
187 except:
183 raise
188 raise
184 except KeyboardInterrupt:
189 except KeyboardInterrupt:
185 # Ctrl-C shouldn't crash the kernel
190 # Ctrl-C shouldn't crash the kernel
186 io.raw_print("KeyboardInterrupt caught in kernel")
191 io.raw_print("KeyboardInterrupt caught in kernel")
187 finally:
192 finally:
188 # ensure excepthook is restored
193 # ensure excepthook is restored
189 sys.excepthook = real_excepthook
194 sys.excepthook = real_excepthook
190
195
191 # mapping of keys to loop functions
196 # mapping of keys to loop functions
192 loop_map = {
197 loop_map = {
193 'qt' : loop_qt4,
198 'qt' : loop_qt4,
194 'qt4': loop_qt4,
199 'qt4': loop_qt4,
195 'inline': None,
200 'inline': None,
196 'osx': loop_cocoa,
201 'osx': loop_cocoa,
197 'wx' : loop_wx,
202 'wx' : loop_wx,
198 'tk' : loop_tk,
203 'tk' : loop_tk,
199 'gtk': loop_gtk,
204 'gtk': loop_gtk,
200 None : None,
205 None : None,
201 }
206 }
202
207
203
208
204 def enable_gui(gui, kernel=None):
209 def enable_gui(gui, kernel=None):
205 """Enable integration with a given GUI"""
210 """Enable integration with a given GUI"""
206 if kernel is None:
207 from .ipkernel import IPKernelApp
208 kernel = IPKernelApp.instance().kernel
209 if gui not in loop_map:
211 if gui not in loop_map:
210 raise ValueError("GUI %r not supported" % gui)
212 raise ValueError("GUI %r not supported" % gui)
213 if kernel is None:
214 if Application.initialized():
215 kernel = getattr(Application.instance(), 'kernel', None)
216 if kernel is None:
217 raise RuntimeError("You didn't specify a kernel,"
218 " and no IPython Application with a kernel appears to be running."
219 )
211 loop = loop_map[gui]
220 loop = loop_map[gui]
212 if kernel.eventloop is not None and kernel.eventloop is not loop:
221 if kernel.eventloop is not None and kernel.eventloop is not loop:
213 raise RuntimeError("Cannot activate multiple GUI eventloops")
222 raise RuntimeError("Cannot activate multiple GUI eventloops")
214 kernel.eventloop = loop
223 kernel.eventloop = loop
@@ -1,710 +1,919 b''
1 #!/usr/bin/env python
1 #!/usr/bin/env python
2 """A simple interactive kernel that talks to a frontend over 0MQ.
2 """A simple interactive kernel that talks to a frontend over 0MQ.
3
3
4 Things to do:
4 Things to do:
5
5
6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
7 call set_parent on all the PUB objects with the message about to be executed.
7 call set_parent on all the PUB objects with the message about to be executed.
8 * Implement random port and security key logic.
8 * Implement random port and security key logic.
9 * Implement control messages.
9 * Implement control messages.
10 * Implement event loop and poll version.
10 * Implement event loop and poll version.
11 """
11 """
12
12
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 # Standard library imports.
18 # Standard library imports
19 import __builtin__
19 import __builtin__
20 import atexit
20 import atexit
21 import sys
21 import sys
22 import time
22 import time
23 import traceback
23 import traceback
24 import logging
24 import logging
25 import uuid
26
27 from datetime import datetime
25 from signal import (
28 from signal import (
26 signal, default_int_handler, SIGINT, SIG_IGN
29 signal, getsignal, default_int_handler, SIGINT, SIG_IGN
27 )
30 )
28 # System library imports.
31
32 # System library imports
29 import zmq
33 import zmq
34 from zmq.eventloop import ioloop
35 from zmq.eventloop.zmqstream import ZMQStream
30
36
31 # Local imports.
37 # Local imports
32 from IPython.core import pylabtools
38 from IPython.core import pylabtools
33 from IPython.config.configurable import Configurable
39 from IPython.config.configurable import Configurable
34 from IPython.config.application import boolean_flag, catch_config_error
40 from IPython.config.application import boolean_flag, catch_config_error
35 from IPython.core.application import ProfileDir
41 from IPython.core.application import ProfileDir
36 from IPython.core.error import StdinNotImplementedError
42 from IPython.core.error import StdinNotImplementedError
37 from IPython.core.shellapp import (
43 from IPython.core.shellapp import (
38 InteractiveShellApp, shell_flags, shell_aliases
44 InteractiveShellApp, shell_flags, shell_aliases
39 )
45 )
40 from IPython.utils import io
46 from IPython.utils import io
41 from IPython.utils import py3compat
47 from IPython.utils import py3compat
42 from IPython.utils.frame import extract_module_locals
48 from IPython.utils.frame import extract_module_locals
43 from IPython.utils.jsonutil import json_clean
49 from IPython.utils.jsonutil import json_clean
44 from IPython.utils.traitlets import (
50 from IPython.utils.traitlets import (
45 Any, Instance, Float, Dict, CaselessStrEnum
51 Any, Instance, Float, Dict, CaselessStrEnum, List, Set, Integer, Unicode
46 )
52 )
47
53
48 from entry_point import base_launch_kernel
54 from entry_point import base_launch_kernel
49 from kernelapp import KernelApp, kernel_flags, kernel_aliases
55 from kernelapp import KernelApp, kernel_flags, kernel_aliases
56 from serialize import serialize_object, unpack_apply_message
50 from session import Session, Message
57 from session import Session, Message
51 from zmqshell import ZMQInteractiveShell
58 from zmqshell import ZMQInteractiveShell
52
59
53
60
54 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
55 # Main kernel class
62 # Main kernel class
56 #-----------------------------------------------------------------------------
63 #-----------------------------------------------------------------------------
57
64
58 class Kernel(Configurable):
65 class Kernel(Configurable):
59
66
60 #---------------------------------------------------------------------------
67 #---------------------------------------------------------------------------
61 # Kernel interface
68 # Kernel interface
62 #---------------------------------------------------------------------------
69 #---------------------------------------------------------------------------
63
70
64 # attribute to override with a GUI
71 # attribute to override with a GUI
65 eventloop = Any(None)
72 eventloop = Any(None)
73 def _eventloop_changed(self, name, old, new):
74 """schedule call to eventloop from IOLoop"""
75 loop = ioloop.IOLoop.instance()
76 loop.add_timeout(time.time()+0.1, self.enter_eventloop)
66
77
67 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
78 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
68 session = Instance(Session)
79 session = Instance(Session)
69 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
80 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
70 shell_socket = Instance('zmq.Socket')
81 shell_streams = List()
71 iopub_socket = Instance('zmq.Socket')
82 control_stream = Instance(ZMQStream)
72 stdin_socket = Instance('zmq.Socket')
83 iopub_socket = Instance(zmq.Socket)
84 stdin_socket = Instance(zmq.Socket)
73 log = Instance(logging.Logger)
85 log = Instance(logging.Logger)
74
86
75 user_module = Instance('types.ModuleType')
87 user_module = Any()
76 def _user_module_changed(self, name, old, new):
88 def _user_module_changed(self, name, old, new):
77 if self.shell is not None:
89 if self.shell is not None:
78 self.shell.user_module = new
90 self.shell.user_module = new
79
91
80 user_ns = Dict(default_value=None)
92 user_ns = Dict(default_value=None)
81 def _user_ns_changed(self, name, old, new):
93 def _user_ns_changed(self, name, old, new):
82 if self.shell is not None:
94 if self.shell is not None:
83 self.shell.user_ns = new
95 self.shell.user_ns = new
84 self.shell.init_user_ns()
96 self.shell.init_user_ns()
85
97
86 # Private interface
98 # identities:
99 int_id = Integer(-1)
100 ident = Unicode()
101
102 def _ident_default(self):
103 return unicode(uuid.uuid4())
87
104
105
106 # Private interface
107
88 # Time to sleep after flushing the stdout/err buffers in each execute
108 # Time to sleep after flushing the stdout/err buffers in each execute
89 # cycle. While this introduces a hard limit on the minimal latency of the
109 # cycle. While this introduces a hard limit on the minimal latency of the
90 # execute cycle, it helps prevent output synchronization problems for
110 # execute cycle, it helps prevent output synchronization problems for
91 # clients.
111 # clients.
92 # Units are in seconds. The minimum zmq latency on local host is probably
112 # Units are in seconds. The minimum zmq latency on local host is probably
93 # ~150 microseconds, set this to 500us for now. We may need to increase it
113 # ~150 microseconds, set this to 500us for now. We may need to increase it
94 # a little if it's not enough after more interactive testing.
114 # a little if it's not enough after more interactive testing.
95 _execute_sleep = Float(0.0005, config=True)
115 _execute_sleep = Float(0.0005, config=True)
96
116
97 # Frequency of the kernel's event loop.
117 # Frequency of the kernel's event loop.
98 # Units are in seconds, kernel subclasses for GUI toolkits may need to
118 # Units are in seconds, kernel subclasses for GUI toolkits may need to
99 # adapt to milliseconds.
119 # adapt to milliseconds.
100 _poll_interval = Float(0.05, config=True)
120 _poll_interval = Float(0.05, config=True)
101
121
102 # If the shutdown was requested over the network, we leave here the
122 # If the shutdown was requested over the network, we leave here the
103 # necessary reply message so it can be sent by our registered atexit
123 # necessary reply message so it can be sent by our registered atexit
104 # handler. This ensures that the reply is only sent to clients truly at
124 # handler. This ensures that the reply is only sent to clients truly at
105 # the end of our shutdown process (which happens after the underlying
125 # the end of our shutdown process (which happens after the underlying
106 # IPython shell's own shutdown).
126 # IPython shell's own shutdown).
107 _shutdown_message = None
127 _shutdown_message = None
108
128
109 # This is a dict of port number that the kernel is listening on. It is set
129 # This is a dict of port number that the kernel is listening on. It is set
110 # by record_ports and used by connect_request.
130 # by record_ports and used by connect_request.
111 _recorded_ports = Dict()
131 _recorded_ports = Dict()
112
132
133 # set of aborted msg_ids
134 aborted = Set()
113
135
114
136
115 def __init__(self, **kwargs):
137 def __init__(self, **kwargs):
116 super(Kernel, self).__init__(**kwargs)
138 super(Kernel, self).__init__(**kwargs)
117
139
118 # Before we even start up the shell, register *first* our exit handlers
119 # so they come before the shell's
120 atexit.register(self._at_shutdown)
121
122 # Initialize the InteractiveShell subclass
140 # Initialize the InteractiveShell subclass
123 self.shell = ZMQInteractiveShell.instance(config=self.config,
141 self.shell = ZMQInteractiveShell.instance(config=self.config,
124 profile_dir = self.profile_dir,
142 profile_dir = self.profile_dir,
125 user_module = self.user_module,
143 user_module = self.user_module,
126 user_ns = self.user_ns,
144 user_ns = self.user_ns,
127 )
145 )
128 self.shell.displayhook.session = self.session
146 self.shell.displayhook.session = self.session
129 self.shell.displayhook.pub_socket = self.iopub_socket
147 self.shell.displayhook.pub_socket = self.iopub_socket
148 self.shell.displayhook.topic = self._topic('pyout')
130 self.shell.display_pub.session = self.session
149 self.shell.display_pub.session = self.session
131 self.shell.display_pub.pub_socket = self.iopub_socket
150 self.shell.display_pub.pub_socket = self.iopub_socket
132
151
133 # TMP - hack while developing
152 # TMP - hack while developing
134 self.shell._reply_content = None
153 self.shell._reply_content = None
135
154
136 # Build dict of handlers for message types
155 # Build dict of handlers for message types
137 msg_types = [ 'execute_request', 'complete_request',
156 msg_types = [ 'execute_request', 'complete_request',
138 'object_info_request', 'history_request',
157 'object_info_request', 'history_request',
139 'connect_request', 'shutdown_request']
158 'connect_request', 'shutdown_request',
140 self.handlers = {}
159 'apply_request',
160 ]
161 self.shell_handlers = {}
141 for msg_type in msg_types:
162 for msg_type in msg_types:
142 self.handlers[msg_type] = getattr(self, msg_type)
163 self.shell_handlers[msg_type] = getattr(self, msg_type)
143
164
144 def do_one_iteration(self):
165 control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
145 """Do one iteration of the kernel's evaluation loop.
166 self.control_handlers = {}
146 """
167 for msg_type in control_msg_types:
168 self.control_handlers[msg_type] = getattr(self, msg_type)
169
170 def dispatch_control(self, msg):
171 """dispatch control requests"""
172 idents,msg = self.session.feed_identities(msg, copy=False)
147 try:
173 try:
148 ident,msg = self.session.recv(self.shell_socket, zmq.NOBLOCK)
174 msg = self.session.unserialize(msg, content=True, copy=False)
149 except Exception:
175 except:
150 self.log.warn("Invalid Message:", exc_info=True)
176 self.log.error("Invalid Control Message", exc_info=True)
151 return
152 if msg is None:
153 return
177 return
154
178
155 msg_type = msg['header']['msg_type']
179 self.log.debug("Control received: %s", msg)
180
181 header = msg['header']
182 msg_id = header['msg_id']
183 msg_type = header['msg_type']
156
184
157 # This assert will raise in versions of zeromq 2.0.7 and lesser.
185 handler = self.control_handlers.get(msg_type, None)
158 # We now require 2.0.8 or above, so we can uncomment for safety.
186 if handler is None:
159 # print(ident,msg, file=sys.__stdout__)
187 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
160 assert ident is not None, "Missing message part."
188 else:
189 try:
190 handler(self.control_stream, idents, msg)
191 except Exception:
192 self.log.error("Exception in control handler:", exc_info=True)
193
194 def dispatch_shell(self, stream, msg):
195 """dispatch shell requests"""
196 # flush control requests first
197 if self.control_stream:
198 self.control_stream.flush()
199
200 idents,msg = self.session.feed_identities(msg, copy=False)
201 try:
202 msg = self.session.unserialize(msg, content=True, copy=False)
203 except:
204 self.log.error("Invalid Message", exc_info=True)
205 return
161
206
207 header = msg['header']
208 msg_id = header['msg_id']
209 msg_type = msg['header']['msg_type']
210
162 # Print some info about this message and leave a '--->' marker, so it's
211 # Print some info about this message and leave a '--->' marker, so it's
163 # easier to trace visually the message chain when debugging. Each
212 # easier to trace visually the message chain when debugging. Each
164 # handler prints its message at the end.
213 # handler prints its message at the end.
165 self.log.debug('\n*** MESSAGE TYPE:'+str(msg_type)+'***')
214 self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
166 self.log.debug(' Content: '+str(msg['content'])+'\n --->\n ')
215 self.log.debug(' Content: %s\n --->\n ', msg['content'])
167
216
168 # Find and call actual handler for message
217 if msg_id in self.aborted:
169 handler = self.handlers.get(msg_type, None)
218 self.aborted.remove(msg_id)
219 # is it safe to assume a msg_id will not be resubmitted?
220 reply_type = msg_type.split('_')[0] + '_reply'
221 status = {'status' : 'aborted'}
222 sub = {'engine' : self.ident}
223 sub.update(status)
224 reply_msg = self.session.send(stream, reply_type, subheader=sub,
225 content=status, parent=msg, ident=idents)
226 return
227
228 handler = self.shell_handlers.get(msg_type, None)
170 if handler is None:
229 if handler is None:
171 self.log.error("UNKNOWN MESSAGE TYPE:" +str(msg))
230 self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
172 else:
231 else:
173 handler(ident, msg)
232 # ensure default_int_handler during handler call
174
233 sig = signal(SIGINT, default_int_handler)
175 # Check whether we should exit, in case the incoming message set the
176 # exit flag on
177 if self.shell.exit_now:
178 self.log.debug('\nExiting IPython kernel...')
179 # We do a normal, clean exit, which allows any actions registered
180 # via atexit (such as history saving) to take place.
181 sys.exit(0)
182
183
184 def start(self):
185 """ Start the kernel main loop.
186 """
187 # a KeyboardInterrupt (SIGINT) can occur on any python statement, so
188 # let's ignore (SIG_IGN) them until we're in a place to handle them properly
189 signal(SIGINT,SIG_IGN)
190 poller = zmq.Poller()
191 poller.register(self.shell_socket, zmq.POLLIN)
192 # loop while self.eventloop has not been overridden
193 while self.eventloop is None:
194 try:
234 try:
195 # scale by extra factor of 10, because there is no
235 handler(stream, idents, msg)
196 # reason for this to be anything less than ~ 0.1s
236 except Exception:
197 # since it is a real poller and will respond
237 self.log.error("Exception in message handler:", exc_info=True)
198 # to events immediately
238 finally:
199
239 signal(SIGINT, sig)
200 # double nested try/except, to properly catch KeyboardInterrupt
240
201 # due to pyzmq Issue #130
241 def enter_eventloop(self):
202 try:
242 """enter eventloop"""
203 poller.poll(10*1000*self._poll_interval)
243 self.log.info("entering eventloop")
204 # restore raising of KeyboardInterrupt
244 # restore default_int_handler
205 signal(SIGINT, default_int_handler)
206 self.do_one_iteration()
207 except:
208 raise
209 finally:
210 # prevent raising of KeyboardInterrupt
211 signal(SIGINT,SIG_IGN)
212 except KeyboardInterrupt:
213 # Ctrl-C shouldn't crash the kernel
214 io.raw_print("KeyboardInterrupt caught in kernel")
215 # stop ignoring sigint, now that we are out of our own loop,
216 # we don't want to prevent future code from handling it
217 signal(SIGINT, default_int_handler)
245 signal(SIGINT, default_int_handler)
218 while self.eventloop is not None:
246 while self.eventloop is not None:
219 try:
247 try:
220 self.eventloop(self)
248 self.eventloop(self)
221 except KeyboardInterrupt:
249 except KeyboardInterrupt:
222 # Ctrl-C shouldn't crash the kernel
250 # Ctrl-C shouldn't crash the kernel
223 io.raw_print("KeyboardInterrupt caught in kernel")
251 self.log.error("KeyboardInterrupt caught in kernel")
224 continue
252 continue
225 else:
253 else:
226 # eventloop exited cleanly, this means we should stop (right?)
254 # eventloop exited cleanly, this means we should stop (right?)
227 self.eventloop = None
255 self.eventloop = None
228 break
256 break
257 self.log.info("exiting eventloop")
258 # if eventloop exits, IOLoop should stop
259 ioloop.IOLoop.instance().stop()
260
261 def start(self):
262 """register dispatchers for streams"""
263 self.shell.exit_now = False
264 if self.control_stream:
265 self.control_stream.on_recv(self.dispatch_control, copy=False)
266
267 def make_dispatcher(stream):
268 def dispatcher(msg):
269 return self.dispatch_shell(stream, msg)
270 return dispatcher
271
272 for s in self.shell_streams:
273 s.on_recv(make_dispatcher(s), copy=False)
274
275 def do_one_iteration(self):
276 """step eventloop just once"""
277 if self.control_stream:
278 self.control_stream.flush()
279 for stream in self.shell_streams:
280 # handle at most one request per iteration
281 stream.flush(zmq.POLLIN, 1)
282 stream.flush(zmq.POLLOUT)
229
283
230
284
231 def record_ports(self, ports):
285 def record_ports(self, ports):
232 """Record the ports that this kernel is using.
286 """Record the ports that this kernel is using.
233
287
234 The creator of the Kernel instance must call this methods if they
288 The creator of the Kernel instance must call this methods if they
235 want the :meth:`connect_request` method to return the port numbers.
289 want the :meth:`connect_request` method to return the port numbers.
236 """
290 """
237 self._recorded_ports = ports
291 self._recorded_ports = ports
238
292
239 #---------------------------------------------------------------------------
293 #---------------------------------------------------------------------------
240 # Kernel request handlers
294 # Kernel request handlers
241 #---------------------------------------------------------------------------
295 #---------------------------------------------------------------------------
242
296
297 def _make_subheader(self):
298 """init subheader dict, for execute/apply_reply"""
299 return {
300 'dependencies_met' : True,
301 'engine' : self.ident,
302 'started': datetime.now(),
303 }
304
243 def _publish_pyin(self, code, parent, execution_count):
305 def _publish_pyin(self, code, parent, execution_count):
244 """Publish the code request on the pyin stream."""
306 """Publish the code request on the pyin stream."""
245
307
246 self.session.send(self.iopub_socket, u'pyin', {u'code':code,
308 self.session.send(self.iopub_socket, u'pyin',
247 u'execution_count': execution_count}, parent=parent)
309 {u'code':code, u'execution_count': execution_count},
310 parent=parent, ident=self._topic('pyin')
311 )
248
312
249 def execute_request(self, ident, parent):
313 def execute_request(self, stream, ident, parent):
250
314
251 self.session.send(self.iopub_socket,
315 self.session.send(self.iopub_socket,
252 u'status',
316 u'status',
253 {u'execution_state':u'busy'},
317 {u'execution_state':u'busy'},
254 parent=parent )
318 parent=parent,
319 ident=self._topic('status'),
320 )
255
321
256 try:
322 try:
257 content = parent[u'content']
323 content = parent[u'content']
258 code = content[u'code']
324 code = content[u'code']
259 silent = content[u'silent']
325 silent = content[u'silent']
260 except:
326 except:
261 self.log.error("Got bad msg: ")
327 self.log.error("Got bad msg: ")
262 self.log.error(str(Message(parent)))
328 self.log.error("%s", parent)
263 return
329 return
330
331 sub = self._make_subheader()
264
332
265 shell = self.shell # we'll need this a lot here
333 shell = self.shell # we'll need this a lot here
266
334
267 # Replace raw_input. Note that is not sufficient to replace
335 # Replace raw_input. Note that is not sufficient to replace
268 # raw_input in the user namespace.
336 # raw_input in the user namespace.
269 if content.get('allow_stdin', False):
337 if content.get('allow_stdin', False):
270 raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
338 raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
271 else:
339 else:
272 raw_input = lambda prompt='' : self._no_raw_input()
340 raw_input = lambda prompt='' : self._no_raw_input()
273
341
274 if py3compat.PY3:
342 if py3compat.PY3:
275 __builtin__.input = raw_input
343 __builtin__.input = raw_input
276 else:
344 else:
277 __builtin__.raw_input = raw_input
345 __builtin__.raw_input = raw_input
278
346
279 # Set the parent message of the display hook and out streams.
347 # Set the parent message of the display hook and out streams.
280 shell.displayhook.set_parent(parent)
348 shell.displayhook.set_parent(parent)
281 shell.display_pub.set_parent(parent)
349 shell.display_pub.set_parent(parent)
282 sys.stdout.set_parent(parent)
350 sys.stdout.set_parent(parent)
283 sys.stderr.set_parent(parent)
351 sys.stderr.set_parent(parent)
284
352
285 # Re-broadcast our input for the benefit of listening clients, and
353 # Re-broadcast our input for the benefit of listening clients, and
286 # start computing output
354 # start computing output
287 if not silent:
355 if not silent:
288 self._publish_pyin(code, parent, shell.execution_count)
356 self._publish_pyin(code, parent, shell.execution_count)
289
357
290 reply_content = {}
358 reply_content = {}
291 try:
359 try:
292 if silent:
360 # FIXME: the shell calls the exception handler itself.
293 # run_code uses 'exec' mode, so no displayhook will fire, and it
361 shell.run_cell(code, store_history=not silent, silent=silent)
294 # doesn't call logging or history manipulations. Print
295 # statements in that code will obviously still execute.
296 shell.run_code(code)
297 else:
298 # FIXME: the shell calls the exception handler itself.
299 shell.run_cell(code, store_history=True)
300 except:
362 except:
301 status = u'error'
363 status = u'error'
302 # FIXME: this code right now isn't being used yet by default,
364 # FIXME: this code right now isn't being used yet by default,
303 # because the run_cell() call above directly fires off exception
365 # because the run_cell() call above directly fires off exception
304 # reporting. This code, therefore, is only active in the scenario
366 # reporting. This code, therefore, is only active in the scenario
305 # where runlines itself has an unhandled exception. We need to
367 # where runlines itself has an unhandled exception. We need to
306 # uniformize this, for all exception construction to come from a
368 # uniformize this, for all exception construction to come from a
307 # single location in the codbase.
369 # single location in the codbase.
308 etype, evalue, tb = sys.exc_info()
370 etype, evalue, tb = sys.exc_info()
309 tb_list = traceback.format_exception(etype, evalue, tb)
371 tb_list = traceback.format_exception(etype, evalue, tb)
310 reply_content.update(shell._showtraceback(etype, evalue, tb_list))
372 reply_content.update(shell._showtraceback(etype, evalue, tb_list))
311 else:
373 else:
312 status = u'ok'
374 status = u'ok'
313
375
314 reply_content[u'status'] = status
376 reply_content[u'status'] = status
315
377
316 # Return the execution counter so clients can display prompts
378 # Return the execution counter so clients can display prompts
317 reply_content['execution_count'] = shell.execution_count -1
379 reply_content['execution_count'] = shell.execution_count - 1
318
380
319 # FIXME - fish exception info out of shell, possibly left there by
381 # FIXME - fish exception info out of shell, possibly left there by
320 # runlines. We'll need to clean up this logic later.
382 # runlines. We'll need to clean up this logic later.
321 if shell._reply_content is not None:
383 if shell._reply_content is not None:
322 reply_content.update(shell._reply_content)
384 reply_content.update(shell._reply_content)
323 # reset after use
385 # reset after use
324 shell._reply_content = None
386 shell._reply_content = None
325
387
326 # At this point, we can tell whether the main code execution succeeded
388 # At this point, we can tell whether the main code execution succeeded
327 # or not. If it did, we proceed to evaluate user_variables/expressions
389 # or not. If it did, we proceed to evaluate user_variables/expressions
328 if reply_content['status'] == 'ok':
390 if reply_content['status'] == 'ok':
329 reply_content[u'user_variables'] = \
391 reply_content[u'user_variables'] = \
330 shell.user_variables(content[u'user_variables'])
392 shell.user_variables(content.get(u'user_variables', []))
331 reply_content[u'user_expressions'] = \
393 reply_content[u'user_expressions'] = \
332 shell.user_expressions(content[u'user_expressions'])
394 shell.user_expressions(content.get(u'user_expressions', {}))
333 else:
395 else:
334 # If there was an error, don't even try to compute variables or
396 # If there was an error, don't even try to compute variables or
335 # expressions
397 # expressions
336 reply_content[u'user_variables'] = {}
398 reply_content[u'user_variables'] = {}
337 reply_content[u'user_expressions'] = {}
399 reply_content[u'user_expressions'] = {}
338
400
339 # Payloads should be retrieved regardless of outcome, so we can both
401 # Payloads should be retrieved regardless of outcome, so we can both
340 # recover partial output (that could have been generated early in a
402 # recover partial output (that could have been generated early in a
341 # block, before an error) and clear the payload system always.
403 # block, before an error) and clear the payload system always.
342 reply_content[u'payload'] = shell.payload_manager.read_payload()
404 reply_content[u'payload'] = shell.payload_manager.read_payload()
343 # Be agressive about clearing the payload because we don't want
405 # Be agressive about clearing the payload because we don't want
344 # it to sit in memory until the next execute_request comes in.
406 # it to sit in memory until the next execute_request comes in.
345 shell.payload_manager.clear_payload()
407 shell.payload_manager.clear_payload()
346
408
347 # Flush output before sending the reply.
409 # Flush output before sending the reply.
348 sys.stdout.flush()
410 sys.stdout.flush()
349 sys.stderr.flush()
411 sys.stderr.flush()
350 # FIXME: on rare occasions, the flush doesn't seem to make it to the
412 # FIXME: on rare occasions, the flush doesn't seem to make it to the
351 # clients... This seems to mitigate the problem, but we definitely need
413 # clients... This seems to mitigate the problem, but we definitely need
352 # to better understand what's going on.
414 # to better understand what's going on.
353 if self._execute_sleep:
415 if self._execute_sleep:
354 time.sleep(self._execute_sleep)
416 time.sleep(self._execute_sleep)
355
417
356 # Send the reply.
418 # Send the reply.
357 reply_content = json_clean(reply_content)
419 reply_content = json_clean(reply_content)
358 reply_msg = self.session.send(self.shell_socket, u'execute_reply',
420
359 reply_content, parent, ident=ident)
421 sub['status'] = reply_content['status']
360 self.log.debug(str(reply_msg))
422 if reply_content['status'] == 'error' and \
423 reply_content['ename'] == 'UnmetDependency':
424 sub['dependencies_met'] = False
425
426 reply_msg = self.session.send(stream, u'execute_reply',
427 reply_content, parent, subheader=sub,
428 ident=ident)
429
430 self.log.debug("%s", reply_msg)
361
431
362 if reply_msg['content']['status'] == u'error':
432 if not silent and reply_msg['content']['status'] == u'error':
363 self._abort_queue()
433 self._abort_queues()
364
434
365 self.session.send(self.iopub_socket,
435 self.session.send(self.iopub_socket,
366 u'status',
436 u'status',
367 {u'execution_state':u'idle'},
437 {u'execution_state':u'idle'},
368 parent=parent )
438 parent=parent,
439 ident=self._topic('status'))
369
440
370 def complete_request(self, ident, parent):
441 def complete_request(self, stream, ident, parent):
371 txt, matches = self._complete(parent)
442 txt, matches = self._complete(parent)
372 matches = {'matches' : matches,
443 matches = {'matches' : matches,
373 'matched_text' : txt,
444 'matched_text' : txt,
374 'status' : 'ok'}
445 'status' : 'ok'}
375 matches = json_clean(matches)
446 matches = json_clean(matches)
376 completion_msg = self.session.send(self.shell_socket, 'complete_reply',
447 completion_msg = self.session.send(stream, 'complete_reply',
377 matches, parent, ident)
448 matches, parent, ident)
378 self.log.debug(str(completion_msg))
449 self.log.debug("%s", completion_msg)
379
450
380 def object_info_request(self, ident, parent):
451 def object_info_request(self, stream, ident, parent):
381 content = parent['content']
452 content = parent['content']
382 object_info = self.shell.object_inspect(content['oname'],
453 object_info = self.shell.object_inspect(content['oname'],
383 detail_level = content.get('detail_level', 0)
454 detail_level = content.get('detail_level', 0)
384 )
455 )
385 # Before we send this object over, we scrub it for JSON usage
456 # Before we send this object over, we scrub it for JSON usage
386 oinfo = json_clean(object_info)
457 oinfo = json_clean(object_info)
387 msg = self.session.send(self.shell_socket, 'object_info_reply',
458 msg = self.session.send(stream, 'object_info_reply',
388 oinfo, parent, ident)
459 oinfo, parent, ident)
389 self.log.debug(msg)
460 self.log.debug("%s", msg)
390
461
391 def history_request(self, ident, parent):
462 def history_request(self, stream, ident, parent):
392 # We need to pull these out, as passing **kwargs doesn't work with
463 # We need to pull these out, as passing **kwargs doesn't work with
393 # unicode keys before Python 2.6.5.
464 # unicode keys before Python 2.6.5.
394 hist_access_type = parent['content']['hist_access_type']
465 hist_access_type = parent['content']['hist_access_type']
395 raw = parent['content']['raw']
466 raw = parent['content']['raw']
396 output = parent['content']['output']
467 output = parent['content']['output']
397 if hist_access_type == 'tail':
468 if hist_access_type == 'tail':
398 n = parent['content']['n']
469 n = parent['content']['n']
399 hist = self.shell.history_manager.get_tail(n, raw=raw, output=output,
470 hist = self.shell.history_manager.get_tail(n, raw=raw, output=output,
400 include_latest=True)
471 include_latest=True)
401
472
402 elif hist_access_type == 'range':
473 elif hist_access_type == 'range':
403 session = parent['content']['session']
474 session = parent['content']['session']
404 start = parent['content']['start']
475 start = parent['content']['start']
405 stop = parent['content']['stop']
476 stop = parent['content']['stop']
406 hist = self.shell.history_manager.get_range(session, start, stop,
477 hist = self.shell.history_manager.get_range(session, start, stop,
407 raw=raw, output=output)
478 raw=raw, output=output)
408
479
409 elif hist_access_type == 'search':
480 elif hist_access_type == 'search':
410 pattern = parent['content']['pattern']
481 pattern = parent['content']['pattern']
411 hist = self.shell.history_manager.search(pattern, raw=raw,
482 hist = self.shell.history_manager.search(pattern, raw=raw,
412 output=output)
483 output=output)
413
484
414 else:
485 else:
415 hist = []
486 hist = []
416 hist = list(hist)
487 hist = list(hist)
417 content = {'history' : hist}
488 content = {'history' : hist}
418 content = json_clean(content)
489 content = json_clean(content)
419 msg = self.session.send(self.shell_socket, 'history_reply',
490 msg = self.session.send(stream, 'history_reply',
420 content, parent, ident)
491 content, parent, ident)
421 self.log.debug("Sending history reply with %i entries", len(hist))
492 self.log.debug("Sending history reply with %i entries", len(hist))
422
493
423 def connect_request(self, ident, parent):
494 def connect_request(self, stream, ident, parent):
424 if self._recorded_ports is not None:
495 if self._recorded_ports is not None:
425 content = self._recorded_ports.copy()
496 content = self._recorded_ports.copy()
426 else:
497 else:
427 content = {}
498 content = {}
428 msg = self.session.send(self.shell_socket, 'connect_reply',
499 msg = self.session.send(stream, 'connect_reply',
429 content, parent, ident)
500 content, parent, ident)
430 self.log.debug(msg)
501 self.log.debug("%s", msg)
431
502
432 def shutdown_request(self, ident, parent):
503 def shutdown_request(self, stream, ident, parent):
433 self.shell.exit_now = True
504 self.shell.exit_now = True
505 content = dict(status='ok')
506 content.update(parent['content'])
507 self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
508 # same content, but different msg_id for broadcasting on IOPub
434 self._shutdown_message = self.session.msg(u'shutdown_reply',
509 self._shutdown_message = self.session.msg(u'shutdown_reply',
435 parent['content'], parent)
510 content, parent
436 sys.exit(0)
511 )
512
513 self._at_shutdown()
514 # call sys.exit after a short delay
515 loop = ioloop.IOLoop.instance()
516 loop.add_timeout(time.time()+0.1, loop.stop)
517
518 #---------------------------------------------------------------------------
519 # Engine methods
520 #---------------------------------------------------------------------------
521
522 def apply_request(self, stream, ident, parent):
523 try:
524 content = parent[u'content']
525 bufs = parent[u'buffers']
526 msg_id = parent['header']['msg_id']
527 except:
528 self.log.error("Got bad msg: %s", parent, exc_info=True)
529 return
530
531 # Set the parent message of the display hook and out streams.
532 self.shell.displayhook.set_parent(parent)
533 self.shell.display_pub.set_parent(parent)
534 sys.stdout.set_parent(parent)
535 sys.stderr.set_parent(parent)
536
537 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
538 # self.iopub_socket.send(pyin_msg)
539 # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
540 sub = self._make_subheader()
541 try:
542 working = self.shell.user_ns
543
544 prefix = "_"+str(msg_id).replace("-","")+"_"
545
546 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
547
548 fname = getattr(f, '__name__', 'f')
549
550 fname = prefix+"f"
551 argname = prefix+"args"
552 kwargname = prefix+"kwargs"
553 resultname = prefix+"result"
554
555 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
556 # print ns
557 working.update(ns)
558 code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
559 try:
560 exec code in self.shell.user_global_ns, self.shell.user_ns
561 result = working.get(resultname)
562 finally:
563 for key in ns.iterkeys():
564 working.pop(key)
565
566 packed_result,buf = serialize_object(result)
567 result_buf = [packed_result]+buf
568 except:
569 exc_content = self._wrap_exception('apply')
570 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
571 self.session.send(self.iopub_socket, u'pyerr', exc_content, parent=parent,
572 ident=self._topic('pyerr'))
573 reply_content = exc_content
574 result_buf = []
575
576 if exc_content['ename'] == 'UnmetDependency':
577 sub['dependencies_met'] = False
578 else:
579 reply_content = {'status' : 'ok'}
580
581 # put 'ok'/'error' status in header, for scheduler introspection:
582 sub['status'] = reply_content['status']
583
584 # flush i/o
585 sys.stdout.flush()
586 sys.stderr.flush()
587
588 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
589 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
590
591 #---------------------------------------------------------------------------
592 # Control messages
593 #---------------------------------------------------------------------------
594
595 def abort_request(self, stream, ident, parent):
596 """abort a specifig msg by id"""
597 msg_ids = parent['content'].get('msg_ids', None)
598 if isinstance(msg_ids, basestring):
599 msg_ids = [msg_ids]
600 if not msg_ids:
601 self.abort_queues()
602 for mid in msg_ids:
603 self.aborted.add(str(mid))
604
605 content = dict(status='ok')
606 reply_msg = self.session.send(stream, 'abort_reply', content=content,
607 parent=parent, ident=ident)
608 self.log.debug("%s", reply_msg)
609
610 def clear_request(self, stream, idents, parent):
611 """Clear our namespace."""
612 self.shell.reset(False)
613 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
614 content = dict(status='ok'))
615
437
616
438 #---------------------------------------------------------------------------
617 #---------------------------------------------------------------------------
439 # Protected interface
618 # Protected interface
440 #---------------------------------------------------------------------------
619 #---------------------------------------------------------------------------
441
620
442 def _abort_queue(self):
621
622 def _wrap_exception(self, method=None):
623 # import here, because _wrap_exception is only used in parallel,
624 # and parallel has higher min pyzmq version
625 from IPython.parallel.error import wrap_exception
626 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
627 content = wrap_exception(e_info)
628 return content
629
630 def _topic(self, topic):
631 """prefixed topic for IOPub messages"""
632 if self.int_id >= 0:
633 base = "engine.%i" % self.int_id
634 else:
635 base = "kernel.%s" % self.ident
636
637 return py3compat.cast_bytes("%s.%s" % (base, topic))
638
639 def _abort_queues(self):
640 for stream in self.shell_streams:
641 if stream:
642 self._abort_queue(stream)
643
644 def _abort_queue(self, stream):
645 poller = zmq.Poller()
646 poller.register(stream.socket, zmq.POLLIN)
443 while True:
647 while True:
444 try:
648 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
445 ident,msg = self.session.recv(self.shell_socket, zmq.NOBLOCK)
446 except Exception:
447 self.log.warn("Invalid Message:", exc_info=True)
448 continue
449 if msg is None:
649 if msg is None:
450 break
650 return
451 else:
452 assert ident is not None, \
453 "Unexpected missing message part."
454
651
455 self.log.debug("Aborting:\n"+str(Message(msg)))
652 self.log.info("Aborting:")
653 self.log.info("%s", msg)
456 msg_type = msg['header']['msg_type']
654 msg_type = msg['header']['msg_type']
457 reply_type = msg_type.split('_')[0] + '_reply'
655 reply_type = msg_type.split('_')[0] + '_reply'
458 reply_msg = self.session.send(self.shell_socket, reply_type,
656
459 {'status' : 'aborted'}, msg, ident=ident)
657 status = {'status' : 'aborted'}
460 self.log.debug(reply_msg)
658 sub = {'engine' : self.ident}
659 sub.update(status)
660 reply_msg = self.session.send(stream, reply_type, subheader=sub,
661 content=status, parent=msg, ident=idents)
662 self.log.debug("%s", reply_msg)
461 # We need to wait a bit for requests to come in. This can probably
663 # We need to wait a bit for requests to come in. This can probably
462 # be set shorter for true asynchronous clients.
664 # be set shorter for true asynchronous clients.
463 time.sleep(0.1)
665 poller.poll(50)
666
464
667
465 def _no_raw_input(self):
668 def _no_raw_input(self):
466 """Raise StdinNotImplentedError if active frontend doesn't support
669 """Raise StdinNotImplentedError if active frontend doesn't support
467 stdin."""
670 stdin."""
468 raise StdinNotImplementedError("raw_input was called, but this "
671 raise StdinNotImplementedError("raw_input was called, but this "
469 "frontend does not support stdin.")
672 "frontend does not support stdin.")
470
673
471 def _raw_input(self, prompt, ident, parent):
674 def _raw_input(self, prompt, ident, parent):
472 # Flush output before making the request.
675 # Flush output before making the request.
473 sys.stderr.flush()
676 sys.stderr.flush()
474 sys.stdout.flush()
677 sys.stdout.flush()
475
678
476 # Send the input request.
679 # Send the input request.
477 content = json_clean(dict(prompt=prompt))
680 content = json_clean(dict(prompt=prompt))
478 self.session.send(self.stdin_socket, u'input_request', content, parent,
681 self.session.send(self.stdin_socket, u'input_request', content, parent,
479 ident=ident)
682 ident=ident)
480
683
481 # Await a response.
684 # Await a response.
482 while True:
685 while True:
483 try:
686 try:
484 ident, reply = self.session.recv(self.stdin_socket, 0)
687 ident, reply = self.session.recv(self.stdin_socket, 0)
485 except Exception:
688 except Exception:
486 self.log.warn("Invalid Message:", exc_info=True)
689 self.log.warn("Invalid Message:", exc_info=True)
487 else:
690 else:
488 break
691 break
489 try:
692 try:
490 value = reply['content']['value']
693 value = reply['content']['value']
491 except:
694 except:
492 self.log.error("Got bad raw_input reply: ")
695 self.log.error("Got bad raw_input reply: ")
493 self.log.error(str(Message(parent)))
696 self.log.error("%s", parent)
494 value = ''
697 value = ''
495 if value == '\x04':
698 if value == '\x04':
496 # EOF
699 # EOF
497 raise EOFError
700 raise EOFError
498 return value
701 return value
499
702
500 def _complete(self, msg):
703 def _complete(self, msg):
501 c = msg['content']
704 c = msg['content']
502 try:
705 try:
503 cpos = int(c['cursor_pos'])
706 cpos = int(c['cursor_pos'])
504 except:
707 except:
505 # If we don't get something that we can convert to an integer, at
708 # If we don't get something that we can convert to an integer, at
506 # least attempt the completion guessing the cursor is at the end of
709 # least attempt the completion guessing the cursor is at the end of
507 # the text, if there's any, and otherwise of the line
710 # the text, if there's any, and otherwise of the line
508 cpos = len(c['text'])
711 cpos = len(c['text'])
509 if cpos==0:
712 if cpos==0:
510 cpos = len(c['line'])
713 cpos = len(c['line'])
511 return self.shell.complete(c['text'], c['line'], cpos)
714 return self.shell.complete(c['text'], c['line'], cpos)
512
715
513 def _object_info(self, context):
716 def _object_info(self, context):
514 symbol, leftover = self._symbol_from_context(context)
717 symbol, leftover = self._symbol_from_context(context)
515 if symbol is not None and not leftover:
718 if symbol is not None and not leftover:
516 doc = getattr(symbol, '__doc__', '')
719 doc = getattr(symbol, '__doc__', '')
517 else:
720 else:
518 doc = ''
721 doc = ''
519 object_info = dict(docstring = doc)
722 object_info = dict(docstring = doc)
520 return object_info
723 return object_info
521
724
522 def _symbol_from_context(self, context):
725 def _symbol_from_context(self, context):
523 if not context:
726 if not context:
524 return None, context
727 return None, context
525
728
526 base_symbol_string = context[0]
729 base_symbol_string = context[0]
527 symbol = self.shell.user_ns.get(base_symbol_string, None)
730 symbol = self.shell.user_ns.get(base_symbol_string, None)
528 if symbol is None:
731 if symbol is None:
529 symbol = __builtin__.__dict__.get(base_symbol_string, None)
732 symbol = __builtin__.__dict__.get(base_symbol_string, None)
530 if symbol is None:
733 if symbol is None:
531 return None, context
734 return None, context
532
735
533 context = context[1:]
736 context = context[1:]
534 for i, name in enumerate(context):
737 for i, name in enumerate(context):
535 new_symbol = getattr(symbol, name, None)
738 new_symbol = getattr(symbol, name, None)
536 if new_symbol is None:
739 if new_symbol is None:
537 return symbol, context[i:]
740 return symbol, context[i:]
538 else:
741 else:
539 symbol = new_symbol
742 symbol = new_symbol
540
743
541 return symbol, []
744 return symbol, []
542
745
543 def _at_shutdown(self):
746 def _at_shutdown(self):
544 """Actions taken at shutdown by the kernel, called by python's atexit.
747 """Actions taken at shutdown by the kernel, called by python's atexit.
545 """
748 """
546 # io.rprint("Kernel at_shutdown") # dbg
749 # io.rprint("Kernel at_shutdown") # dbg
547 if self._shutdown_message is not None:
750 if self._shutdown_message is not None:
548 self.session.send(self.shell_socket, self._shutdown_message)
751 self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
549 self.session.send(self.iopub_socket, self._shutdown_message)
752 self.log.debug("%s", self._shutdown_message)
550 self.log.debug(str(self._shutdown_message))
753 [ s.flush(zmq.POLLOUT) for s in self.shell_streams ]
551 # A very short sleep to give zmq time to flush its message buffers
552 # before Python truly shuts down.
553 time.sleep(0.01)
554
754
555 #-----------------------------------------------------------------------------
755 #-----------------------------------------------------------------------------
556 # Aliases and Flags for the IPKernelApp
756 # Aliases and Flags for the IPKernelApp
557 #-----------------------------------------------------------------------------
757 #-----------------------------------------------------------------------------
558
758
559 flags = dict(kernel_flags)
759 flags = dict(kernel_flags)
560 flags.update(shell_flags)
760 flags.update(shell_flags)
561
761
562 addflag = lambda *args: flags.update(boolean_flag(*args))
762 addflag = lambda *args: flags.update(boolean_flag(*args))
563
763
564 flags['pylab'] = (
764 flags['pylab'] = (
565 {'IPKernelApp' : {'pylab' : 'auto'}},
765 {'IPKernelApp' : {'pylab' : 'auto'}},
566 """Pre-load matplotlib and numpy for interactive use with
766 """Pre-load matplotlib and numpy for interactive use with
567 the default matplotlib backend."""
767 the default matplotlib backend."""
568 )
768 )
569
769
570 aliases = dict(kernel_aliases)
770 aliases = dict(kernel_aliases)
571 aliases.update(shell_aliases)
771 aliases.update(shell_aliases)
572
772
573 # it's possible we don't want short aliases for *all* of these:
773 # it's possible we don't want short aliases for *all* of these:
574 aliases.update(dict(
774 aliases.update(dict(
575 pylab='IPKernelApp.pylab',
775 pylab='IPKernelApp.pylab',
576 ))
776 ))
577
777
578 #-----------------------------------------------------------------------------
778 #-----------------------------------------------------------------------------
579 # The IPKernelApp class
779 # The IPKernelApp class
580 #-----------------------------------------------------------------------------
780 #-----------------------------------------------------------------------------
581
781
582 class IPKernelApp(KernelApp, InteractiveShellApp):
782 class IPKernelApp(KernelApp, InteractiveShellApp):
583 name = 'ipkernel'
783 name = 'ipkernel'
584
784
585 aliases = Dict(aliases)
785 aliases = Dict(aliases)
586 flags = Dict(flags)
786 flags = Dict(flags)
587 classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]
787 classes = [Kernel, ZMQInteractiveShell, ProfileDir, Session]
588
788
589 # configurables
789 # configurables
590 pylab = CaselessStrEnum(['tk', 'qt', 'wx', 'gtk', 'osx', 'inline', 'auto'],
790 pylab = CaselessStrEnum(['tk', 'qt', 'wx', 'gtk', 'osx', 'inline', 'auto'],
591 config=True,
791 config=True,
592 help="""Pre-load matplotlib and numpy for interactive use,
792 help="""Pre-load matplotlib and numpy for interactive use,
593 selecting a particular matplotlib backend and loop integration.
793 selecting a particular matplotlib backend and loop integration.
594 """
794 """
595 )
795 )
596
796
597 @catch_config_error
797 @catch_config_error
598 def initialize(self, argv=None):
798 def initialize(self, argv=None):
599 super(IPKernelApp, self).initialize(argv)
799 super(IPKernelApp, self).initialize(argv)
600 self.init_path()
800 self.init_path()
601 self.init_shell()
801 self.init_shell()
602 self.init_extensions()
802 self.init_extensions()
603 self.init_code()
803 self.init_code()
604
804
605 def init_kernel(self):
805 def init_kernel(self):
806
807 shell_stream = ZMQStream(self.shell_socket)
606
808
607 kernel = Kernel(config=self.config, session=self.session,
809 kernel = Kernel(config=self.config, session=self.session,
608 shell_socket=self.shell_socket,
810 shell_streams=[shell_stream],
609 iopub_socket=self.iopub_socket,
811 iopub_socket=self.iopub_socket,
610 stdin_socket=self.stdin_socket,
812 stdin_socket=self.stdin_socket,
611 log=self.log,
813 log=self.log,
612 profile_dir=self.profile_dir,
814 profile_dir=self.profile_dir,
613 )
815 )
614 self.kernel = kernel
816 self.kernel = kernel
615 kernel.record_ports(self.ports)
817 kernel.record_ports(self.ports)
616 shell = kernel.shell
818 shell = kernel.shell
617 if self.pylab:
819 if self.pylab:
618 try:
820 try:
619 gui, backend = pylabtools.find_gui_and_backend(self.pylab)
821 gui, backend = pylabtools.find_gui_and_backend(self.pylab)
620 shell.enable_pylab(gui, import_all=self.pylab_import_all)
822 shell.enable_pylab(gui, import_all=self.pylab_import_all)
621 except Exception:
823 except Exception:
622 self.log.error("Pylab initialization failed", exc_info=True)
824 self.log.error("Pylab initialization failed", exc_info=True)
623 # print exception straight to stdout, because normally
825 # print exception straight to stdout, because normally
624 # _showtraceback associates the reply with an execution,
826 # _showtraceback associates the reply with an execution,
625 # which means frontends will never draw it, as this exception
827 # which means frontends will never draw it, as this exception
626 # is not associated with any execute request.
828 # is not associated with any execute request.
627
829
628 # replace pyerr-sending traceback with stdout
830 # replace pyerr-sending traceback with stdout
629 _showtraceback = shell._showtraceback
831 _showtraceback = shell._showtraceback
630 def print_tb(etype, evalue, stb):
832 def print_tb(etype, evalue, stb):
631 print ("Error initializing pylab, pylab mode will not "
833 print ("Error initializing pylab, pylab mode will not "
632 "be active", file=io.stderr)
834 "be active", file=io.stderr)
633 print (shell.InteractiveTB.stb2text(stb), file=io.stdout)
835 print (shell.InteractiveTB.stb2text(stb), file=io.stdout)
634 shell._showtraceback = print_tb
836 shell._showtraceback = print_tb
635
837
636 # send the traceback over stdout
838 # send the traceback over stdout
637 shell.showtraceback(tb_offset=0)
839 shell.showtraceback(tb_offset=0)
638
840
639 # restore proper _showtraceback method
841 # restore proper _showtraceback method
640 shell._showtraceback = _showtraceback
842 shell._showtraceback = _showtraceback
641
843
642
844
643 def init_shell(self):
845 def init_shell(self):
644 self.shell = self.kernel.shell
846 self.shell = self.kernel.shell
645 self.shell.configurables.append(self)
847 self.shell.configurables.append(self)
646
848
647
849
648 #-----------------------------------------------------------------------------
850 #-----------------------------------------------------------------------------
649 # Kernel main and launch functions
851 # Kernel main and launch functions
650 #-----------------------------------------------------------------------------
852 #-----------------------------------------------------------------------------
651
853
652 def launch_kernel(*args, **kwargs):
854 def launch_kernel(*args, **kwargs):
653 """Launches a localhost IPython kernel, binding to the specified ports.
855 """Launches a localhost IPython kernel, binding to the specified ports.
654
856
655 This function simply calls entry_point.base_launch_kernel with the right
857 This function simply calls entry_point.base_launch_kernel with the right
656 first command to start an ipkernel. See base_launch_kernel for arguments.
858 first command to start an ipkernel. See base_launch_kernel for arguments.
657
859
658 Returns
860 Returns
659 -------
861 -------
660 A tuple of form:
862 A tuple of form:
661 (kernel_process, shell_port, iopub_port, stdin_port, hb_port)
863 (kernel_process, shell_port, iopub_port, stdin_port, hb_port)
662 where kernel_process is a Popen object and the ports are integers.
864 where kernel_process is a Popen object and the ports are integers.
663 """
865 """
664 return base_launch_kernel('from IPython.zmq.ipkernel import main; main()',
866 return base_launch_kernel('from IPython.zmq.ipkernel import main; main()',
665 *args, **kwargs)
867 *args, **kwargs)
666
868
667
869
668 def embed_kernel(module=None, local_ns=None, **kwargs):
870 def embed_kernel(module=None, local_ns=None, **kwargs):
669 """Embed and start an IPython kernel in a given scope.
871 """Embed and start an IPython kernel in a given scope.
670
872
671 Parameters
873 Parameters
672 ----------
874 ----------
673 module : ModuleType, optional
875 module : ModuleType, optional
674 The module to load into IPython globals (default: caller)
876 The module to load into IPython globals (default: caller)
675 local_ns : dict, optional
877 local_ns : dict, optional
676 The namespace to load into IPython user namespace (default: caller)
878 The namespace to load into IPython user namespace (default: caller)
677
879
678 kwargs : various, optional
880 kwargs : various, optional
679 Further keyword args are relayed to the KernelApp constructor,
881 Further keyword args are relayed to the KernelApp constructor,
680 allowing configuration of the Kernel. Will only have an effect
882 allowing configuration of the Kernel. Will only have an effect
681 on the first embed_kernel call for a given process.
883 on the first embed_kernel call for a given process.
682
884
683 """
885 """
684 # get the app if it exists, or set it up if it doesn't
886 # get the app if it exists, or set it up if it doesn't
685 if IPKernelApp.initialized():
887 if IPKernelApp.initialized():
686 app = IPKernelApp.instance()
888 app = IPKernelApp.instance()
687 else:
889 else:
688 app = IPKernelApp.instance(**kwargs)
890 app = IPKernelApp.instance(**kwargs)
689 app.initialize([])
891 app.initialize([])
892 # Undo unnecessary sys module mangling from init_sys_modules.
893 # This would not be necessary if we could prevent it
894 # in the first place by using a different InteractiveShell
895 # subclass, as in the regular embed case.
896 main = app.kernel.shell._orig_sys_modules_main_mod
897 if main is not None:
898 sys.modules[app.kernel.shell._orig_sys_modules_main_name] = main
690
899
691 # load the calling scope if not given
900 # load the calling scope if not given
692 (caller_module, caller_locals) = extract_module_locals(1)
901 (caller_module, caller_locals) = extract_module_locals(1)
693 if module is None:
902 if module is None:
694 module = caller_module
903 module = caller_module
695 if local_ns is None:
904 if local_ns is None:
696 local_ns = caller_locals
905 local_ns = caller_locals
697
906
698 app.kernel.user_module = module
907 app.kernel.user_module = module
699 app.kernel.user_ns = local_ns
908 app.kernel.user_ns = local_ns
700 app.start()
909 app.start()
701
910
702 def main():
911 def main():
703 """Run an IPKernel as an application"""
912 """Run an IPKernel as an application"""
704 app = IPKernelApp.instance()
913 app = IPKernelApp.instance()
705 app.initialize()
914 app.initialize()
706 app.start()
915 app.start()
707
916
708
917
709 if __name__ == '__main__':
918 if __name__ == '__main__':
710 main()
919 main()
@@ -1,306 +1,333 b''
1 """An Application for launching a kernel
1 """An Application for launching a kernel
2
2
3 Authors
3 Authors
4 -------
4 -------
5 * MinRK
5 * MinRK
6 """
6 """
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2011 The IPython Development Team
8 # Copyright (C) 2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING.txt, distributed as part of this software.
11 # the file COPYING.txt, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports
19 import atexit
19 import json
20 import json
20 import os
21 import os
21 import sys
22 import sys
23 import signal
22
24
23 # System library imports.
25 # System library imports
24 import zmq
26 import zmq
27 from zmq.eventloop import ioloop
25
28
26 # IPython imports.
29 # IPython imports
27 from IPython.core.ultratb import FormattedTB
30 from IPython.core.ultratb import FormattedTB
28 from IPython.core.application import (
31 from IPython.core.application import (
29 BaseIPythonApplication, base_flags, base_aliases, catch_config_error
32 BaseIPythonApplication, base_flags, base_aliases, catch_config_error
30 )
33 )
31 from IPython.utils import io
34 from IPython.utils import io
32 from IPython.utils.localinterfaces import LOCALHOST
35 from IPython.utils.localinterfaces import LOCALHOST
33 from IPython.utils.path import filefind
36 from IPython.utils.path import filefind
34 from IPython.utils.py3compat import str_to_bytes
37 from IPython.utils.py3compat import str_to_bytes
35 from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Integer, Bool,
38 from IPython.utils.traitlets import (Any, Instance, Dict, Unicode, Integer, Bool,
36 DottedObjectName)
39 DottedObjectName)
37 from IPython.utils.importstring import import_item
40 from IPython.utils.importstring import import_item
38 # local imports
41 # local imports
39 from IPython.zmq.entry_point import write_connection_file
42 from IPython.zmq.entry_point import write_connection_file
40 from IPython.zmq.heartbeat import Heartbeat
43 from IPython.zmq.heartbeat import Heartbeat
41 from IPython.zmq.parentpoller import ParentPollerUnix, ParentPollerWindows
44 from IPython.zmq.parentpoller import ParentPollerUnix, ParentPollerWindows
42 from IPython.zmq.session import (
45 from IPython.zmq.session import (
43 Session, session_flags, session_aliases, default_secure,
46 Session, session_flags, session_aliases, default_secure,
44 )
47 )
45
48
46
49
47 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
48 # Flags and Aliases
51 # Flags and Aliases
49 #-----------------------------------------------------------------------------
52 #-----------------------------------------------------------------------------
50
53
51 kernel_aliases = dict(base_aliases)
54 kernel_aliases = dict(base_aliases)
52 kernel_aliases.update({
55 kernel_aliases.update({
53 'ip' : 'KernelApp.ip',
56 'ip' : 'KernelApp.ip',
54 'hb' : 'KernelApp.hb_port',
57 'hb' : 'KernelApp.hb_port',
55 'shell' : 'KernelApp.shell_port',
58 'shell' : 'KernelApp.shell_port',
56 'iopub' : 'KernelApp.iopub_port',
59 'iopub' : 'KernelApp.iopub_port',
57 'stdin' : 'KernelApp.stdin_port',
60 'stdin' : 'KernelApp.stdin_port',
58 'f' : 'KernelApp.connection_file',
61 'f' : 'KernelApp.connection_file',
59 'parent': 'KernelApp.parent',
62 'parent': 'KernelApp.parent',
60 })
63 })
61 if sys.platform.startswith('win'):
64 if sys.platform.startswith('win'):
62 kernel_aliases['interrupt'] = 'KernelApp.interrupt'
65 kernel_aliases['interrupt'] = 'KernelApp.interrupt'
63
66
64 kernel_flags = dict(base_flags)
67 kernel_flags = dict(base_flags)
65 kernel_flags.update({
68 kernel_flags.update({
66 'no-stdout' : (
69 'no-stdout' : (
67 {'KernelApp' : {'no_stdout' : True}},
70 {'KernelApp' : {'no_stdout' : True}},
68 "redirect stdout to the null device"),
71 "redirect stdout to the null device"),
69 'no-stderr' : (
72 'no-stderr' : (
70 {'KernelApp' : {'no_stderr' : True}},
73 {'KernelApp' : {'no_stderr' : True}},
71 "redirect stderr to the null device"),
74 "redirect stderr to the null device"),
72 })
75 })
73
76
74 # inherit flags&aliases for Sessions
77 # inherit flags&aliases for Sessions
75 kernel_aliases.update(session_aliases)
78 kernel_aliases.update(session_aliases)
76 kernel_flags.update(session_flags)
79 kernel_flags.update(session_flags)
77
80
78
81
79
82
80 #-----------------------------------------------------------------------------
83 #-----------------------------------------------------------------------------
81 # Application class for starting a Kernel
84 # Application class for starting a Kernel
82 #-----------------------------------------------------------------------------
85 #-----------------------------------------------------------------------------
83
86
84 class KernelApp(BaseIPythonApplication):
87 class KernelApp(BaseIPythonApplication):
85 name='pykernel'
88 name='ipkernel'
86 aliases = Dict(kernel_aliases)
89 aliases = Dict(kernel_aliases)
87 flags = Dict(kernel_flags)
90 flags = Dict(kernel_flags)
88 classes = [Session]
91 classes = [Session]
89 # the kernel class, as an importstring
92 # the kernel class, as an importstring
90 kernel_class = DottedObjectName('IPython.zmq.pykernel.Kernel')
93 kernel_class = DottedObjectName('IPython.zmq.ipkernel.Kernel')
91 kernel = Any()
94 kernel = Any()
92 poller = Any() # don't restrict this even though current pollers are all Threads
95 poller = Any() # don't restrict this even though current pollers are all Threads
93 heartbeat = Instance(Heartbeat)
96 heartbeat = Instance(Heartbeat)
94 session = Instance('IPython.zmq.session.Session')
97 session = Instance('IPython.zmq.session.Session')
95 ports = Dict()
98 ports = Dict()
99 _full_connection_file = Unicode()
96
100
97 # inherit config file name from parent:
101 # inherit config file name from parent:
98 parent_appname = Unicode(config=True)
102 parent_appname = Unicode(config=True)
99 def _parent_appname_changed(self, name, old, new):
103 def _parent_appname_changed(self, name, old, new):
100 if self.config_file_specified:
104 if self.config_file_specified:
101 # it was manually specified, ignore
105 # it was manually specified, ignore
102 return
106 return
103 self.config_file_name = new.replace('-','_') + u'_config.py'
107 self.config_file_name = new.replace('-','_') + u'_config.py'
104 # don't let this count as specifying the config file
108 # don't let this count as specifying the config file
105 self.config_file_specified = False
109 self.config_file_specified = False
106
110
107 # connection info:
111 # connection info:
108 ip = Unicode(LOCALHOST, config=True,
112 ip = Unicode(LOCALHOST, config=True,
109 help="Set the IP or interface on which the kernel will listen.")
113 help="Set the IP or interface on which the kernel will listen.")
110 hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
114 hb_port = Integer(0, config=True, help="set the heartbeat port [default: random]")
111 shell_port = Integer(0, config=True, help="set the shell (XREP) port [default: random]")
115 shell_port = Integer(0, config=True, help="set the shell (XREP) port [default: random]")
112 iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
116 iopub_port = Integer(0, config=True, help="set the iopub (PUB) port [default: random]")
113 stdin_port = Integer(0, config=True, help="set the stdin (XREQ) port [default: random]")
117 stdin_port = Integer(0, config=True, help="set the stdin (XREQ) port [default: random]")
114 connection_file = Unicode('', config=True,
118 connection_file = Unicode('', config=True,
115 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
119 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
116
120
117 This file will contain the IP, ports, and authentication key needed to connect
121 This file will contain the IP, ports, and authentication key needed to connect
118 clients to this kernel. By default, this file will be created in the security-dir
122 clients to this kernel. By default, this file will be created in the security-dir
119 of the current profile, but can be specified by absolute path.
123 of the current profile, but can be specified by absolute path.
120 """)
124 """)
121
125
122 # streams, etc.
126 # streams, etc.
123 no_stdout = Bool(False, config=True, help="redirect stdout to the null device")
127 no_stdout = Bool(False, config=True, help="redirect stdout to the null device")
124 no_stderr = Bool(False, config=True, help="redirect stderr to the null device")
128 no_stderr = Bool(False, config=True, help="redirect stderr to the null device")
125 outstream_class = DottedObjectName('IPython.zmq.iostream.OutStream',
129 outstream_class = DottedObjectName('IPython.zmq.iostream.OutStream',
126 config=True, help="The importstring for the OutStream factory")
130 config=True, help="The importstring for the OutStream factory")
127 displayhook_class = DottedObjectName('IPython.zmq.displayhook.ZMQDisplayHook',
131 displayhook_class = DottedObjectName('IPython.zmq.displayhook.ZMQDisplayHook',
128 config=True, help="The importstring for the DisplayHook factory")
132 config=True, help="The importstring for the DisplayHook factory")
129
133
130 # polling
134 # polling
131 parent = Integer(0, config=True,
135 parent = Integer(0, config=True,
132 help="""kill this process if its parent dies. On Windows, the argument
136 help="""kill this process if its parent dies. On Windows, the argument
133 specifies the HANDLE of the parent process, otherwise it is simply boolean.
137 specifies the HANDLE of the parent process, otherwise it is simply boolean.
134 """)
138 """)
135 interrupt = Integer(0, config=True,
139 interrupt = Integer(0, config=True,
136 help="""ONLY USED ON WINDOWS
140 help="""ONLY USED ON WINDOWS
137 Interrupt this process when the parent is signalled.
141 Interrupt this process when the parent is signalled.
138 """)
142 """)
139
143
140 def init_crash_handler(self):
144 def init_crash_handler(self):
141 # Install minimal exception handling
145 # Install minimal exception handling
142 sys.excepthook = FormattedTB(mode='Verbose', color_scheme='NoColor',
146 sys.excepthook = FormattedTB(mode='Verbose', color_scheme='NoColor',
143 ostream=sys.__stdout__)
147 ostream=sys.__stdout__)
144
148
145 def init_poller(self):
149 def init_poller(self):
146 if sys.platform == 'win32':
150 if sys.platform == 'win32':
147 if self.interrupt or self.parent:
151 if self.interrupt or self.parent:
148 self.poller = ParentPollerWindows(self.interrupt, self.parent)
152 self.poller = ParentPollerWindows(self.interrupt, self.parent)
149 elif self.parent:
153 elif self.parent:
150 self.poller = ParentPollerUnix()
154 self.poller = ParentPollerUnix()
151
155
152 def _bind_socket(self, s, port):
156 def _bind_socket(self, s, port):
153 iface = 'tcp://%s' % self.ip
157 iface = 'tcp://%s' % self.ip
154 if port <= 0:
158 if port <= 0:
155 port = s.bind_to_random_port(iface)
159 port = s.bind_to_random_port(iface)
156 else:
160 else:
157 s.bind(iface + ':%i'%port)
161 s.bind(iface + ':%i'%port)
158 return port
162 return port
159
163
160 def load_connection_file(self):
164 def load_connection_file(self):
161 """load ip/port/hmac config from JSON connection file"""
165 """load ip/port/hmac config from JSON connection file"""
162 try:
166 try:
163 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
167 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
164 except IOError:
168 except IOError:
165 self.log.debug("Connection file not found: %s", self.connection_file)
169 self.log.debug("Connection file not found: %s", self.connection_file)
170 # This means I own it, so I will clean it up:
171 atexit.register(self.cleanup_connection_file)
166 return
172 return
167 self.log.debug(u"Loading connection file %s", fname)
173 self.log.debug(u"Loading connection file %s", fname)
168 with open(fname) as f:
174 with open(fname) as f:
169 s = f.read()
175 s = f.read()
170 cfg = json.loads(s)
176 cfg = json.loads(s)
171 if self.ip == LOCALHOST and 'ip' in cfg:
177 if self.ip == LOCALHOST and 'ip' in cfg:
172 # not overridden by config or cl_args
178 # not overridden by config or cl_args
173 self.ip = cfg['ip']
179 self.ip = cfg['ip']
174 for channel in ('hb', 'shell', 'iopub', 'stdin'):
180 for channel in ('hb', 'shell', 'iopub', 'stdin'):
175 name = channel + '_port'
181 name = channel + '_port'
176 if getattr(self, name) == 0 and name in cfg:
182 if getattr(self, name) == 0 and name in cfg:
177 # not overridden by config or cl_args
183 # not overridden by config or cl_args
178 setattr(self, name, cfg[name])
184 setattr(self, name, cfg[name])
179 if 'key' in cfg:
185 if 'key' in cfg:
180 self.config.Session.key = str_to_bytes(cfg['key'])
186 self.config.Session.key = str_to_bytes(cfg['key'])
181
187
182 def write_connection_file(self):
188 def write_connection_file(self):
183 """write connection info to JSON file"""
189 """write connection info to JSON file"""
184 if os.path.basename(self.connection_file) == self.connection_file:
190 if os.path.basename(self.connection_file) == self.connection_file:
185 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
191 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
186 else:
192 else:
187 cf = self.connection_file
193 cf = self.connection_file
188 write_connection_file(cf, ip=self.ip, key=self.session.key,
194 write_connection_file(cf, ip=self.ip, key=self.session.key,
189 shell_port=self.shell_port, stdin_port=self.stdin_port, hb_port=self.hb_port,
195 shell_port=self.shell_port, stdin_port=self.stdin_port, hb_port=self.hb_port,
190 iopub_port=self.iopub_port)
196 iopub_port=self.iopub_port)
197
198 self._full_connection_file = cf
199
200 def cleanup_connection_file(self):
201 cf = self._full_connection_file
202 self.log.debug("cleaning up connection file: %r", cf)
203 try:
204 os.remove(cf)
205 except (IOError, OSError):
206 pass
191
207
192 def init_connection_file(self):
208 def init_connection_file(self):
193 if not self.connection_file:
209 if not self.connection_file:
194 self.connection_file = "kernel-%s.json"%os.getpid()
210 self.connection_file = "kernel-%s.json"%os.getpid()
195 try:
211 try:
196 self.load_connection_file()
212 self.load_connection_file()
197 except Exception:
213 except Exception:
198 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
214 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
199 self.exit(1)
215 self.exit(1)
200
216
201 def init_sockets(self):
217 def init_sockets(self):
202 # Create a context, a session, and the kernel sockets.
218 # Create a context, a session, and the kernel sockets.
203 self.log.info("Starting the kernel at pid: %i", os.getpid())
219 self.log.info("Starting the kernel at pid: %i", os.getpid())
204 context = zmq.Context.instance()
220 context = zmq.Context.instance()
205 # Uncomment this to try closing the context.
221 # Uncomment this to try closing the context.
206 # atexit.register(context.term)
222 # atexit.register(context.term)
207
223
208 self.shell_socket = context.socket(zmq.ROUTER)
224 self.shell_socket = context.socket(zmq.ROUTER)
209 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
225 self.shell_port = self._bind_socket(self.shell_socket, self.shell_port)
210 self.log.debug("shell ROUTER Channel on port: %i"%self.shell_port)
226 self.log.debug("shell ROUTER Channel on port: %i"%self.shell_port)
211
227
212 self.iopub_socket = context.socket(zmq.PUB)
228 self.iopub_socket = context.socket(zmq.PUB)
213 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
229 self.iopub_port = self._bind_socket(self.iopub_socket, self.iopub_port)
214 self.log.debug("iopub PUB Channel on port: %i"%self.iopub_port)
230 self.log.debug("iopub PUB Channel on port: %i"%self.iopub_port)
215
231
216 self.stdin_socket = context.socket(zmq.ROUTER)
232 self.stdin_socket = context.socket(zmq.ROUTER)
217 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
233 self.stdin_port = self._bind_socket(self.stdin_socket, self.stdin_port)
218 self.log.debug("stdin ROUTER Channel on port: %i"%self.stdin_port)
234 self.log.debug("stdin ROUTER Channel on port: %i"%self.stdin_port)
219
235
236 def init_heartbeat(self):
237 """start the heart beating"""
220 # heartbeat doesn't share context, because it mustn't be blocked
238 # heartbeat doesn't share context, because it mustn't be blocked
221 # by the GIL, which is accessed by libzmq when freeing zero-copy messages
239 # by the GIL, which is accessed by libzmq when freeing zero-copy messages
222 hb_ctx = zmq.Context()
240 hb_ctx = zmq.Context()
223 self.heartbeat = Heartbeat(hb_ctx, (self.ip, self.hb_port))
241 self.heartbeat = Heartbeat(hb_ctx, (self.ip, self.hb_port))
224 self.hb_port = self.heartbeat.port
242 self.hb_port = self.heartbeat.port
225 self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
243 self.log.debug("Heartbeat REP Channel on port: %i"%self.hb_port)
244 self.heartbeat.start()
226
245
227 # Helper to make it easier to connect to an existing kernel.
246 # Helper to make it easier to connect to an existing kernel.
228 # set log-level to critical, to make sure it is output
247 # set log-level to critical, to make sure it is output
229 self.log.critical("To connect another client to this kernel, use:")
248 self.log.critical("To connect another client to this kernel, use:")
230
249
250 def log_connection_info(self):
251 """display connection info, and store ports"""
231 basename = os.path.basename(self.connection_file)
252 basename = os.path.basename(self.connection_file)
232 if basename == self.connection_file or \
253 if basename == self.connection_file or \
233 os.path.dirname(self.connection_file) == self.profile_dir.security_dir:
254 os.path.dirname(self.connection_file) == self.profile_dir.security_dir:
234 # use shortname
255 # use shortname
235 tail = basename
256 tail = basename
236 if self.profile != 'default':
257 if self.profile != 'default':
237 tail += " --profile %s" % self.profile
258 tail += " --profile %s" % self.profile
238 else:
259 else:
239 tail = self.connection_file
260 tail = self.connection_file
240 self.log.critical("--existing %s", tail)
261 self.log.critical("--existing %s", tail)
241
262
242
263
243 self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
264 self.ports = dict(shell=self.shell_port, iopub=self.iopub_port,
244 stdin=self.stdin_port, hb=self.hb_port)
265 stdin=self.stdin_port, hb=self.hb_port)
245
266
246 def init_session(self):
267 def init_session(self):
247 """create our session object"""
268 """create our session object"""
248 default_secure(self.config)
269 default_secure(self.config)
249 self.session = Session(config=self.config, username=u'kernel')
270 self.session = Session(config=self.config, username=u'kernel')
250
271
251 def init_blackhole(self):
272 def init_blackhole(self):
252 """redirects stdout/stderr to devnull if necessary"""
273 """redirects stdout/stderr to devnull if necessary"""
253 if self.no_stdout or self.no_stderr:
274 if self.no_stdout or self.no_stderr:
254 blackhole = open(os.devnull, 'w')
275 blackhole = open(os.devnull, 'w')
255 if self.no_stdout:
276 if self.no_stdout:
256 sys.stdout = sys.__stdout__ = blackhole
277 sys.stdout = sys.__stdout__ = blackhole
257 if self.no_stderr:
278 if self.no_stderr:
258 sys.stderr = sys.__stderr__ = blackhole
279 sys.stderr = sys.__stderr__ = blackhole
259
280
260 def init_io(self):
281 def init_io(self):
261 """Redirect input streams and set a display hook."""
282 """Redirect input streams and set a display hook."""
262 if self.outstream_class:
283 if self.outstream_class:
263 outstream_factory = import_item(str(self.outstream_class))
284 outstream_factory = import_item(str(self.outstream_class))
264 sys.stdout = outstream_factory(self.session, self.iopub_socket, u'stdout')
285 sys.stdout = outstream_factory(self.session, self.iopub_socket, u'stdout')
265 sys.stderr = outstream_factory(self.session, self.iopub_socket, u'stderr')
286 sys.stderr = outstream_factory(self.session, self.iopub_socket, u'stderr')
266 if self.displayhook_class:
287 if self.displayhook_class:
267 displayhook_factory = import_item(str(self.displayhook_class))
288 displayhook_factory = import_item(str(self.displayhook_class))
268 sys.displayhook = displayhook_factory(self.session, self.iopub_socket)
289 sys.displayhook = displayhook_factory(self.session, self.iopub_socket)
269
290
291 def init_signal(self):
292 signal.signal(signal.SIGINT, signal.SIG_IGN)
293
270 def init_kernel(self):
294 def init_kernel(self):
271 """Create the Kernel object itself"""
295 """Create the Kernel object itself"""
272 kernel_factory = import_item(str(self.kernel_class))
296 kernel_factory = import_item(str(self.kernel_class))
273 self.kernel = kernel_factory(config=self.config, session=self.session,
297 self.kernel = kernel_factory(config=self.config, session=self.session,
274 shell_socket=self.shell_socket,
298 shell_socket=self.shell_socket,
275 iopub_socket=self.iopub_socket,
299 iopub_socket=self.iopub_socket,
276 stdin_socket=self.stdin_socket,
300 stdin_socket=self.stdin_socket,
277 log=self.log
301 log=self.log
278 )
302 )
279 self.kernel.record_ports(self.ports)
303 self.kernel.record_ports(self.ports)
280
304
281 @catch_config_error
305 @catch_config_error
282 def initialize(self, argv=None):
306 def initialize(self, argv=None):
283 super(KernelApp, self).initialize(argv)
307 super(KernelApp, self).initialize(argv)
284 self.init_blackhole()
308 self.init_blackhole()
285 self.init_connection_file()
309 self.init_connection_file()
286 self.init_session()
310 self.init_session()
287 self.init_poller()
311 self.init_poller()
288 self.init_sockets()
312 self.init_sockets()
289 # writing connection file must be *after* init_sockets
313 self.init_heartbeat()
314 # writing/displaying connection info must be *after* init_sockets/heartbeat
315 self.log_connection_info()
290 self.write_connection_file()
316 self.write_connection_file()
291 self.init_io()
317 self.init_io()
318 self.init_signal()
292 self.init_kernel()
319 self.init_kernel()
293 # flush stdout/stderr, so that anything written to these streams during
320 # flush stdout/stderr, so that anything written to these streams during
294 # initialization do not get associated with the first execution request
321 # initialization do not get associated with the first execution request
295 sys.stdout.flush()
322 sys.stdout.flush()
296 sys.stderr.flush()
323 sys.stderr.flush()
297
324
298 def start(self):
325 def start(self):
299 self.heartbeat.start()
300 if self.poller is not None:
326 if self.poller is not None:
301 self.poller.start()
327 self.poller.start()
328 self.kernel.start()
302 try:
329 try:
303 self.kernel.start()
330 ioloop.IOLoop.instance().start()
304 except KeyboardInterrupt:
331 except KeyboardInterrupt:
305 pass
332 pass
306
333
@@ -1,1000 +1,994 b''
1 """Base classes to manage the interaction with a running kernel.
1 """Base classes to manage the interaction with a running kernel.
2
2
3 TODO
3 TODO
4 * Create logger to handle debugging and console messages.
4 * Create logger to handle debugging and console messages.
5 """
5 """
6
6
7 #-----------------------------------------------------------------------------
7 #-----------------------------------------------------------------------------
8 # Copyright (C) 2008-2011 The IPython Development Team
8 # Copyright (C) 2008-2011 The IPython Development Team
9 #
9 #
10 # Distributed under the terms of the BSD License. The full license is in
10 # Distributed under the terms of the BSD License. The full license is in
11 # the file COPYING, distributed as part of this software.
11 # the file COPYING, distributed as part of this software.
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13
13
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15 # Imports
15 # Imports
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 # Standard library imports.
18 # Standard library imports.
19 import atexit
19 import atexit
20 import errno
20 import errno
21 import json
21 import json
22 from subprocess import Popen
22 from subprocess import Popen
23 import os
23 import os
24 import signal
24 import signal
25 import sys
25 import sys
26 from threading import Thread
26 from threading import Thread
27 import time
27 import time
28
28
29 # System library imports.
29 # System library imports.
30 import zmq
30 import zmq
31 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
31 # import ZMQError in top-level namespace, to avoid ugly attribute-error messages
32 # during garbage collection of threads at exit:
32 # during garbage collection of threads at exit:
33 from zmq import ZMQError
33 from zmq import ZMQError
34 from zmq.eventloop import ioloop, zmqstream
34 from zmq.eventloop import ioloop, zmqstream
35
35
36 # Local imports.
36 # Local imports.
37 from IPython.config.loader import Config
37 from IPython.config.loader import Config
38 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
38 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
39 from IPython.utils.traitlets import (
39 from IPython.utils.traitlets import (
40 HasTraits, Any, Instance, Type, Unicode, Integer, Bool
40 HasTraits, Any, Instance, Type, Unicode, Integer, Bool
41 )
41 )
42 from IPython.utils.py3compat import str_to_bytes
42 from IPython.utils.py3compat import str_to_bytes
43 from IPython.zmq.entry_point import write_connection_file
43 from IPython.zmq.entry_point import write_connection_file
44 from session import Session
44 from session import Session
45
45
46 #-----------------------------------------------------------------------------
46 #-----------------------------------------------------------------------------
47 # Constants and exceptions
47 # Constants and exceptions
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49
49
50 class InvalidPortNumber(Exception):
50 class InvalidPortNumber(Exception):
51 pass
51 pass
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Utility functions
54 # Utility functions
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57 # some utilities to validate message structure, these might get moved elsewhere
57 # some utilities to validate message structure, these might get moved elsewhere
58 # if they prove to have more generic utility
58 # if they prove to have more generic utility
59
59
60 def validate_string_list(lst):
60 def validate_string_list(lst):
61 """Validate that the input is a list of strings.
61 """Validate that the input is a list of strings.
62
62
63 Raises ValueError if not."""
63 Raises ValueError if not."""
64 if not isinstance(lst, list):
64 if not isinstance(lst, list):
65 raise ValueError('input %r must be a list' % lst)
65 raise ValueError('input %r must be a list' % lst)
66 for x in lst:
66 for x in lst:
67 if not isinstance(x, basestring):
67 if not isinstance(x, basestring):
68 raise ValueError('element %r in list must be a string' % x)
68 raise ValueError('element %r in list must be a string' % x)
69
69
70
70
71 def validate_string_dict(dct):
71 def validate_string_dict(dct):
72 """Validate that the input is a dict with string keys and values.
72 """Validate that the input is a dict with string keys and values.
73
73
74 Raises ValueError if not."""
74 Raises ValueError if not."""
75 for k,v in dct.iteritems():
75 for k,v in dct.iteritems():
76 if not isinstance(k, basestring):
76 if not isinstance(k, basestring):
77 raise ValueError('key %r in dict must be a string' % k)
77 raise ValueError('key %r in dict must be a string' % k)
78 if not isinstance(v, basestring):
78 if not isinstance(v, basestring):
79 raise ValueError('value %r in dict must be a string' % v)
79 raise ValueError('value %r in dict must be a string' % v)
80
80
81
81
82 #-----------------------------------------------------------------------------
82 #-----------------------------------------------------------------------------
83 # ZMQ Socket Channel classes
83 # ZMQ Socket Channel classes
84 #-----------------------------------------------------------------------------
84 #-----------------------------------------------------------------------------
85
85
86 class ZMQSocketChannel(Thread):
86 class ZMQSocketChannel(Thread):
87 """The base class for the channels that use ZMQ sockets.
87 """The base class for the channels that use ZMQ sockets.
88 """
88 """
89 context = None
89 context = None
90 session = None
90 session = None
91 socket = None
91 socket = None
92 ioloop = None
92 ioloop = None
93 stream = None
93 stream = None
94 _address = None
94 _address = None
95 _exiting = False
95 _exiting = False
96
96
97 def __init__(self, context, session, address):
97 def __init__(self, context, session, address):
98 """Create a channel
98 """Create a channel
99
99
100 Parameters
100 Parameters
101 ----------
101 ----------
102 context : :class:`zmq.Context`
102 context : :class:`zmq.Context`
103 The ZMQ context to use.
103 The ZMQ context to use.
104 session : :class:`session.Session`
104 session : :class:`session.Session`
105 The session to use.
105 The session to use.
106 address : tuple
106 address : tuple
107 Standard (ip, port) tuple that the kernel is listening on.
107 Standard (ip, port) tuple that the kernel is listening on.
108 """
108 """
109 super(ZMQSocketChannel, self).__init__()
109 super(ZMQSocketChannel, self).__init__()
110 self.daemon = True
110 self.daemon = True
111
111
112 self.context = context
112 self.context = context
113 self.session = session
113 self.session = session
114 if address[1] == 0:
114 if address[1] == 0:
115 message = 'The port number for a channel cannot be 0.'
115 message = 'The port number for a channel cannot be 0.'
116 raise InvalidPortNumber(message)
116 raise InvalidPortNumber(message)
117 self._address = address
117 self._address = address
118 atexit.register(self._notice_exit)
118 atexit.register(self._notice_exit)
119
119
120 def _notice_exit(self):
120 def _notice_exit(self):
121 self._exiting = True
121 self._exiting = True
122
122
123 def _run_loop(self):
123 def _run_loop(self):
124 """Run my loop, ignoring EINTR events in the poller"""
124 """Run my loop, ignoring EINTR events in the poller"""
125 while True:
125 while True:
126 try:
126 try:
127 self.ioloop.start()
127 self.ioloop.start()
128 except ZMQError as e:
128 except ZMQError as e:
129 if e.errno == errno.EINTR:
129 if e.errno == errno.EINTR:
130 continue
130 continue
131 else:
131 else:
132 raise
132 raise
133 except Exception:
133 except Exception:
134 if self._exiting:
134 if self._exiting:
135 break
135 break
136 else:
136 else:
137 raise
137 raise
138 else:
138 else:
139 break
139 break
140
140
141 def stop(self):
141 def stop(self):
142 """Stop the channel's activity.
142 """Stop the channel's activity.
143
143
144 This calls :method:`Thread.join` and returns when the thread
144 This calls :method:`Thread.join` and returns when the thread
145 terminates. :class:`RuntimeError` will be raised if
145 terminates. :class:`RuntimeError` will be raised if
146 :method:`self.start` is called again.
146 :method:`self.start` is called again.
147 """
147 """
148 self.join()
148 self.join()
149
149
150 @property
150 @property
151 def address(self):
151 def address(self):
152 """Get the channel's address as an (ip, port) tuple.
152 """Get the channel's address as an (ip, port) tuple.
153
153
154 By the default, the address is (localhost, 0), where 0 means a random
154 By the default, the address is (localhost, 0), where 0 means a random
155 port.
155 port.
156 """
156 """
157 return self._address
157 return self._address
158
158
159 def _queue_send(self, msg):
159 def _queue_send(self, msg):
160 """Queue a message to be sent from the IOLoop's thread.
160 """Queue a message to be sent from the IOLoop's thread.
161
161
162 Parameters
162 Parameters
163 ----------
163 ----------
164 msg : message to send
164 msg : message to send
165
165
166 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
166 This is threadsafe, as it uses IOLoop.add_callback to give the loop's
167 thread control of the action.
167 thread control of the action.
168 """
168 """
169 def thread_send():
169 def thread_send():
170 self.session.send(self.stream, msg)
170 self.session.send(self.stream, msg)
171 self.ioloop.add_callback(thread_send)
171 self.ioloop.add_callback(thread_send)
172
172
173 def _handle_recv(self, msg):
173 def _handle_recv(self, msg):
174 """callback for stream.on_recv
174 """callback for stream.on_recv
175
175
176 unpacks message, and calls handlers with it.
176 unpacks message, and calls handlers with it.
177 """
177 """
178 ident,smsg = self.session.feed_identities(msg)
178 ident,smsg = self.session.feed_identities(msg)
179 self.call_handlers(self.session.unserialize(smsg))
179 self.call_handlers(self.session.unserialize(smsg))
180
180
181
181
182
182
183 class ShellSocketChannel(ZMQSocketChannel):
183 class ShellSocketChannel(ZMQSocketChannel):
184 """The XREQ channel for issues request/replies to the kernel.
184 """The XREQ channel for issues request/replies to the kernel.
185 """
185 """
186
186
187 command_queue = None
187 command_queue = None
188 # flag for whether execute requests should be allowed to call raw_input:
188 # flag for whether execute requests should be allowed to call raw_input:
189 allow_stdin = True
189 allow_stdin = True
190
190
191 def __init__(self, context, session, address):
191 def __init__(self, context, session, address):
192 super(ShellSocketChannel, self).__init__(context, session, address)
192 super(ShellSocketChannel, self).__init__(context, session, address)
193 self.ioloop = ioloop.IOLoop()
193 self.ioloop = ioloop.IOLoop()
194
194
195 def run(self):
195 def run(self):
196 """The thread's main activity. Call start() instead."""
196 """The thread's main activity. Call start() instead."""
197 self.socket = self.context.socket(zmq.DEALER)
197 self.socket = self.context.socket(zmq.DEALER)
198 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
198 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
199 self.socket.connect('tcp://%s:%i' % self.address)
199 self.socket.connect('tcp://%s:%i' % self.address)
200 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
200 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
201 self.stream.on_recv(self._handle_recv)
201 self.stream.on_recv(self._handle_recv)
202 self._run_loop()
202 self._run_loop()
203 try:
203 try:
204 self.socket.close()
204 self.socket.close()
205 except:
205 except:
206 pass
206 pass
207
207
208 def stop(self):
208 def stop(self):
209 self.ioloop.stop()
209 self.ioloop.stop()
210 super(ShellSocketChannel, self).stop()
210 super(ShellSocketChannel, self).stop()
211
211
212 def call_handlers(self, msg):
212 def call_handlers(self, msg):
213 """This method is called in the ioloop thread when a message arrives.
213 """This method is called in the ioloop thread when a message arrives.
214
214
215 Subclasses should override this method to handle incoming messages.
215 Subclasses should override this method to handle incoming messages.
216 It is important to remember that this method is called in the thread
216 It is important to remember that this method is called in the thread
217 so that some logic must be done to ensure that the application leve
217 so that some logic must be done to ensure that the application leve
218 handlers are called in the application thread.
218 handlers are called in the application thread.
219 """
219 """
220 raise NotImplementedError('call_handlers must be defined in a subclass.')
220 raise NotImplementedError('call_handlers must be defined in a subclass.')
221
221
222 def execute(self, code, silent=False,
222 def execute(self, code, silent=False,
223 user_variables=None, user_expressions=None, allow_stdin=None):
223 user_variables=None, user_expressions=None, allow_stdin=None):
224 """Execute code in the kernel.
224 """Execute code in the kernel.
225
225
226 Parameters
226 Parameters
227 ----------
227 ----------
228 code : str
228 code : str
229 A string of Python code.
229 A string of Python code.
230
230
231 silent : bool, optional (default False)
231 silent : bool, optional (default False)
232 If set, the kernel will execute the code as quietly possible.
232 If set, the kernel will execute the code as quietly possible.
233
233
234 user_variables : list, optional
234 user_variables : list, optional
235 A list of variable names to pull from the user's namespace. They
235 A list of variable names to pull from the user's namespace. They
236 will come back as a dict with these names as keys and their
236 will come back as a dict with these names as keys and their
237 :func:`repr` as values.
237 :func:`repr` as values.
238
238
239 user_expressions : dict, optional
239 user_expressions : dict, optional
240 A dict with string keys and to pull from the user's
240 A dict with string keys and to pull from the user's
241 namespace. They will come back as a dict with these names as keys
241 namespace. They will come back as a dict with these names as keys
242 and their :func:`repr` as values.
242 and their :func:`repr` as values.
243
243
244 allow_stdin : bool, optional
244 allow_stdin : bool, optional
245 Flag for
245 Flag for
246 A dict with string keys and to pull from the user's
246 A dict with string keys and to pull from the user's
247 namespace. They will come back as a dict with these names as keys
247 namespace. They will come back as a dict with these names as keys
248 and their :func:`repr` as values.
248 and their :func:`repr` as values.
249
249
250 Returns
250 Returns
251 -------
251 -------
252 The msg_id of the message sent.
252 The msg_id of the message sent.
253 """
253 """
254 if user_variables is None:
254 if user_variables is None:
255 user_variables = []
255 user_variables = []
256 if user_expressions is None:
256 if user_expressions is None:
257 user_expressions = {}
257 user_expressions = {}
258 if allow_stdin is None:
258 if allow_stdin is None:
259 allow_stdin = self.allow_stdin
259 allow_stdin = self.allow_stdin
260
260
261
261
262 # Don't waste network traffic if inputs are invalid
262 # Don't waste network traffic if inputs are invalid
263 if not isinstance(code, basestring):
263 if not isinstance(code, basestring):
264 raise ValueError('code %r must be a string' % code)
264 raise ValueError('code %r must be a string' % code)
265 validate_string_list(user_variables)
265 validate_string_list(user_variables)
266 validate_string_dict(user_expressions)
266 validate_string_dict(user_expressions)
267
267
268 # Create class for content/msg creation. Related to, but possibly
268 # Create class for content/msg creation. Related to, but possibly
269 # not in Session.
269 # not in Session.
270 content = dict(code=code, silent=silent,
270 content = dict(code=code, silent=silent,
271 user_variables=user_variables,
271 user_variables=user_variables,
272 user_expressions=user_expressions,
272 user_expressions=user_expressions,
273 allow_stdin=allow_stdin,
273 allow_stdin=allow_stdin,
274 )
274 )
275 msg = self.session.msg('execute_request', content)
275 msg = self.session.msg('execute_request', content)
276 self._queue_send(msg)
276 self._queue_send(msg)
277 return msg['header']['msg_id']
277 return msg['header']['msg_id']
278
278
279 def complete(self, text, line, cursor_pos, block=None):
279 def complete(self, text, line, cursor_pos, block=None):
280 """Tab complete text in the kernel's namespace.
280 """Tab complete text in the kernel's namespace.
281
281
282 Parameters
282 Parameters
283 ----------
283 ----------
284 text : str
284 text : str
285 The text to complete.
285 The text to complete.
286 line : str
286 line : str
287 The full line of text that is the surrounding context for the
287 The full line of text that is the surrounding context for the
288 text to complete.
288 text to complete.
289 cursor_pos : int
289 cursor_pos : int
290 The position of the cursor in the line where the completion was
290 The position of the cursor in the line where the completion was
291 requested.
291 requested.
292 block : str, optional
292 block : str, optional
293 The full block of code in which the completion is being requested.
293 The full block of code in which the completion is being requested.
294
294
295 Returns
295 Returns
296 -------
296 -------
297 The msg_id of the message sent.
297 The msg_id of the message sent.
298 """
298 """
299 content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
299 content = dict(text=text, line=line, block=block, cursor_pos=cursor_pos)
300 msg = self.session.msg('complete_request', content)
300 msg = self.session.msg('complete_request', content)
301 self._queue_send(msg)
301 self._queue_send(msg)
302 return msg['header']['msg_id']
302 return msg['header']['msg_id']
303
303
304 def object_info(self, oname, detail_level=0):
304 def object_info(self, oname, detail_level=0):
305 """Get metadata information about an object.
305 """Get metadata information about an object.
306
306
307 Parameters
307 Parameters
308 ----------
308 ----------
309 oname : str
309 oname : str
310 A string specifying the object name.
310 A string specifying the object name.
311 detail_level : int, optional
311 detail_level : int, optional
312 The level of detail for the introspection (0-2)
312 The level of detail for the introspection (0-2)
313
313
314 Returns
314 Returns
315 -------
315 -------
316 The msg_id of the message sent.
316 The msg_id of the message sent.
317 """
317 """
318 content = dict(oname=oname, detail_level=detail_level)
318 content = dict(oname=oname, detail_level=detail_level)
319 msg = self.session.msg('object_info_request', content)
319 msg = self.session.msg('object_info_request', content)
320 self._queue_send(msg)
320 self._queue_send(msg)
321 return msg['header']['msg_id']
321 return msg['header']['msg_id']
322
322
323 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
323 def history(self, raw=True, output=False, hist_access_type='range', **kwargs):
324 """Get entries from the history list.
324 """Get entries from the history list.
325
325
326 Parameters
326 Parameters
327 ----------
327 ----------
328 raw : bool
328 raw : bool
329 If True, return the raw input.
329 If True, return the raw input.
330 output : bool
330 output : bool
331 If True, then return the output as well.
331 If True, then return the output as well.
332 hist_access_type : str
332 hist_access_type : str
333 'range' (fill in session, start and stop params), 'tail' (fill in n)
333 'range' (fill in session, start and stop params), 'tail' (fill in n)
334 or 'search' (fill in pattern param).
334 or 'search' (fill in pattern param).
335
335
336 session : int
336 session : int
337 For a range request, the session from which to get lines. Session
337 For a range request, the session from which to get lines. Session
338 numbers are positive integers; negative ones count back from the
338 numbers are positive integers; negative ones count back from the
339 current session.
339 current session.
340 start : int
340 start : int
341 The first line number of a history range.
341 The first line number of a history range.
342 stop : int
342 stop : int
343 The final (excluded) line number of a history range.
343 The final (excluded) line number of a history range.
344
344
345 n : int
345 n : int
346 The number of lines of history to get for a tail request.
346 The number of lines of history to get for a tail request.
347
347
348 pattern : str
348 pattern : str
349 The glob-syntax pattern for a search request.
349 The glob-syntax pattern for a search request.
350
350
351 Returns
351 Returns
352 -------
352 -------
353 The msg_id of the message sent.
353 The msg_id of the message sent.
354 """
354 """
355 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
355 content = dict(raw=raw, output=output, hist_access_type=hist_access_type,
356 **kwargs)
356 **kwargs)
357 msg = self.session.msg('history_request', content)
357 msg = self.session.msg('history_request', content)
358 self._queue_send(msg)
358 self._queue_send(msg)
359 return msg['header']['msg_id']
359 return msg['header']['msg_id']
360
360
361 def shutdown(self, restart=False):
361 def shutdown(self, restart=False):
362 """Request an immediate kernel shutdown.
362 """Request an immediate kernel shutdown.
363
363
364 Upon receipt of the (empty) reply, client code can safely assume that
364 Upon receipt of the (empty) reply, client code can safely assume that
365 the kernel has shut down and it's safe to forcefully terminate it if
365 the kernel has shut down and it's safe to forcefully terminate it if
366 it's still alive.
366 it's still alive.
367
367
368 The kernel will send the reply via a function registered with Python's
368 The kernel will send the reply via a function registered with Python's
369 atexit module, ensuring it's truly done as the kernel is done with all
369 atexit module, ensuring it's truly done as the kernel is done with all
370 normal operation.
370 normal operation.
371 """
371 """
372 # Send quit message to kernel. Once we implement kernel-side setattr,
372 # Send quit message to kernel. Once we implement kernel-side setattr,
373 # this should probably be done that way, but for now this will do.
373 # this should probably be done that way, but for now this will do.
374 msg = self.session.msg('shutdown_request', {'restart':restart})
374 msg = self.session.msg('shutdown_request', {'restart':restart})
375 self._queue_send(msg)
375 self._queue_send(msg)
376 return msg['header']['msg_id']
376 return msg['header']['msg_id']
377
377
378
378
379
379
380 class SubSocketChannel(ZMQSocketChannel):
380 class SubSocketChannel(ZMQSocketChannel):
381 """The SUB channel which listens for messages that the kernel publishes.
381 """The SUB channel which listens for messages that the kernel publishes.
382 """
382 """
383
383
384 def __init__(self, context, session, address):
384 def __init__(self, context, session, address):
385 super(SubSocketChannel, self).__init__(context, session, address)
385 super(SubSocketChannel, self).__init__(context, session, address)
386 self.ioloop = ioloop.IOLoop()
386 self.ioloop = ioloop.IOLoop()
387
387
388 def run(self):
388 def run(self):
389 """The thread's main activity. Call start() instead."""
389 """The thread's main activity. Call start() instead."""
390 self.socket = self.context.socket(zmq.SUB)
390 self.socket = self.context.socket(zmq.SUB)
391 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
391 self.socket.setsockopt(zmq.SUBSCRIBE,b'')
392 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
392 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
393 self.socket.connect('tcp://%s:%i' % self.address)
393 self.socket.connect('tcp://%s:%i' % self.address)
394 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
394 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
395 self.stream.on_recv(self._handle_recv)
395 self.stream.on_recv(self._handle_recv)
396 self._run_loop()
396 self._run_loop()
397 try:
397 try:
398 self.socket.close()
398 self.socket.close()
399 except:
399 except:
400 pass
400 pass
401
401
402 def stop(self):
402 def stop(self):
403 self.ioloop.stop()
403 self.ioloop.stop()
404 super(SubSocketChannel, self).stop()
404 super(SubSocketChannel, self).stop()
405
405
406 def call_handlers(self, msg):
406 def call_handlers(self, msg):
407 """This method is called in the ioloop thread when a message arrives.
407 """This method is called in the ioloop thread when a message arrives.
408
408
409 Subclasses should override this method to handle incoming messages.
409 Subclasses should override this method to handle incoming messages.
410 It is important to remember that this method is called in the thread
410 It is important to remember that this method is called in the thread
411 so that some logic must be done to ensure that the application leve
411 so that some logic must be done to ensure that the application leve
412 handlers are called in the application thread.
412 handlers are called in the application thread.
413 """
413 """
414 raise NotImplementedError('call_handlers must be defined in a subclass.')
414 raise NotImplementedError('call_handlers must be defined in a subclass.')
415
415
416 def flush(self, timeout=1.0):
416 def flush(self, timeout=1.0):
417 """Immediately processes all pending messages on the SUB channel.
417 """Immediately processes all pending messages on the SUB channel.
418
418
419 Callers should use this method to ensure that :method:`call_handlers`
419 Callers should use this method to ensure that :method:`call_handlers`
420 has been called for all messages that have been received on the
420 has been called for all messages that have been received on the
421 0MQ SUB socket of this channel.
421 0MQ SUB socket of this channel.
422
422
423 This method is thread safe.
423 This method is thread safe.
424
424
425 Parameters
425 Parameters
426 ----------
426 ----------
427 timeout : float, optional
427 timeout : float, optional
428 The maximum amount of time to spend flushing, in seconds. The
428 The maximum amount of time to spend flushing, in seconds. The
429 default is one second.
429 default is one second.
430 """
430 """
431 # We do the IOLoop callback process twice to ensure that the IOLoop
431 # We do the IOLoop callback process twice to ensure that the IOLoop
432 # gets to perform at least one full poll.
432 # gets to perform at least one full poll.
433 stop_time = time.time() + timeout
433 stop_time = time.time() + timeout
434 for i in xrange(2):
434 for i in xrange(2):
435 self._flushed = False
435 self._flushed = False
436 self.ioloop.add_callback(self._flush)
436 self.ioloop.add_callback(self._flush)
437 while not self._flushed and time.time() < stop_time:
437 while not self._flushed and time.time() < stop_time:
438 time.sleep(0.01)
438 time.sleep(0.01)
439
439
440 def _flush(self):
440 def _flush(self):
441 """Callback for :method:`self.flush`."""
441 """Callback for :method:`self.flush`."""
442 self.stream.flush()
442 self.stream.flush()
443 self._flushed = True
443 self._flushed = True
444
444
445
445
446 class StdInSocketChannel(ZMQSocketChannel):
446 class StdInSocketChannel(ZMQSocketChannel):
447 """A reply channel to handle raw_input requests that the kernel makes."""
447 """A reply channel to handle raw_input requests that the kernel makes."""
448
448
449 msg_queue = None
449 msg_queue = None
450
450
451 def __init__(self, context, session, address):
451 def __init__(self, context, session, address):
452 super(StdInSocketChannel, self).__init__(context, session, address)
452 super(StdInSocketChannel, self).__init__(context, session, address)
453 self.ioloop = ioloop.IOLoop()
453 self.ioloop = ioloop.IOLoop()
454
454
455 def run(self):
455 def run(self):
456 """The thread's main activity. Call start() instead."""
456 """The thread's main activity. Call start() instead."""
457 self.socket = self.context.socket(zmq.DEALER)
457 self.socket = self.context.socket(zmq.DEALER)
458 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
458 self.socket.setsockopt(zmq.IDENTITY, self.session.bsession)
459 self.socket.connect('tcp://%s:%i' % self.address)
459 self.socket.connect('tcp://%s:%i' % self.address)
460 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
460 self.stream = zmqstream.ZMQStream(self.socket, self.ioloop)
461 self.stream.on_recv(self._handle_recv)
461 self.stream.on_recv(self._handle_recv)
462 self._run_loop()
462 self._run_loop()
463 try:
463 try:
464 self.socket.close()
464 self.socket.close()
465 except:
465 except:
466 pass
466 pass
467
467
468
468
469 def stop(self):
469 def stop(self):
470 self.ioloop.stop()
470 self.ioloop.stop()
471 super(StdInSocketChannel, self).stop()
471 super(StdInSocketChannel, self).stop()
472
472
473 def call_handlers(self, msg):
473 def call_handlers(self, msg):
474 """This method is called in the ioloop thread when a message arrives.
474 """This method is called in the ioloop thread when a message arrives.
475
475
476 Subclasses should override this method to handle incoming messages.
476 Subclasses should override this method to handle incoming messages.
477 It is important to remember that this method is called in the thread
477 It is important to remember that this method is called in the thread
478 so that some logic must be done to ensure that the application leve
478 so that some logic must be done to ensure that the application leve
479 handlers are called in the application thread.
479 handlers are called in the application thread.
480 """
480 """
481 raise NotImplementedError('call_handlers must be defined in a subclass.')
481 raise NotImplementedError('call_handlers must be defined in a subclass.')
482
482
483 def input(self, string):
483 def input(self, string):
484 """Send a string of raw input to the kernel."""
484 """Send a string of raw input to the kernel."""
485 content = dict(value=string)
485 content = dict(value=string)
486 msg = self.session.msg('input_reply', content)
486 msg = self.session.msg('input_reply', content)
487 self._queue_send(msg)
487 self._queue_send(msg)
488
488
489
489
490 class HBSocketChannel(ZMQSocketChannel):
490 class HBSocketChannel(ZMQSocketChannel):
491 """The heartbeat channel which monitors the kernel heartbeat.
491 """The heartbeat channel which monitors the kernel heartbeat.
492
492
493 Note that the heartbeat channel is paused by default. As long as you start
493 Note that the heartbeat channel is paused by default. As long as you start
494 this channel, the kernel manager will ensure that it is paused and un-paused
494 this channel, the kernel manager will ensure that it is paused and un-paused
495 as appropriate.
495 as appropriate.
496 """
496 """
497
497
498 time_to_dead = 3.0
498 time_to_dead = 3.0
499 socket = None
499 socket = None
500 poller = None
500 poller = None
501 _running = None
501 _running = None
502 _pause = None
502 _pause = None
503 _beating = None
503 _beating = None
504
504
505 def __init__(self, context, session, address):
505 def __init__(self, context, session, address):
506 super(HBSocketChannel, self).__init__(context, session, address)
506 super(HBSocketChannel, self).__init__(context, session, address)
507 self._running = False
507 self._running = False
508 self._pause =True
508 self._pause =True
509 self.poller = zmq.Poller()
509 self.poller = zmq.Poller()
510
510
511 def _create_socket(self):
511 def _create_socket(self):
512 if self.socket is not None:
512 if self.socket is not None:
513 # close previous socket, before opening a new one
513 # close previous socket, before opening a new one
514 self.poller.unregister(self.socket)
514 self.poller.unregister(self.socket)
515 self.socket.close()
515 self.socket.close()
516 self.socket = self.context.socket(zmq.REQ)
516 self.socket = self.context.socket(zmq.REQ)
517 self.socket.setsockopt(zmq.LINGER, 0)
517 self.socket.setsockopt(zmq.LINGER, 0)
518 self.socket.connect('tcp://%s:%i' % self.address)
518 self.socket.connect('tcp://%s:%i' % self.address)
519
519
520 self.poller.register(self.socket, zmq.POLLIN)
520 self.poller.register(self.socket, zmq.POLLIN)
521
521
522 def _poll(self, start_time):
522 def _poll(self, start_time):
523 """poll for heartbeat replies until we reach self.time_to_dead
523 """poll for heartbeat replies until we reach self.time_to_dead
524
524
525 Ignores interrupts, and returns the result of poll(), which
525 Ignores interrupts, and returns the result of poll(), which
526 will be an empty list if no messages arrived before the timeout,
526 will be an empty list if no messages arrived before the timeout,
527 or the event tuple if there is a message to receive.
527 or the event tuple if there is a message to receive.
528 """
528 """
529
529
530 until_dead = self.time_to_dead - (time.time() - start_time)
530 until_dead = self.time_to_dead - (time.time() - start_time)
531 # ensure poll at least once
531 # ensure poll at least once
532 until_dead = max(until_dead, 1e-3)
532 until_dead = max(until_dead, 1e-3)
533 events = []
533 events = []
534 while True:
534 while True:
535 try:
535 try:
536 events = self.poller.poll(1000 * until_dead)
536 events = self.poller.poll(1000 * until_dead)
537 except ZMQError as e:
537 except ZMQError as e:
538 if e.errno == errno.EINTR:
538 if e.errno == errno.EINTR:
539 # ignore interrupts during heartbeat
539 # ignore interrupts during heartbeat
540 # this may never actually happen
540 # this may never actually happen
541 until_dead = self.time_to_dead - (time.time() - start_time)
541 until_dead = self.time_to_dead - (time.time() - start_time)
542 until_dead = max(until_dead, 1e-3)
542 until_dead = max(until_dead, 1e-3)
543 pass
543 pass
544 else:
544 else:
545 raise
545 raise
546 except Exception:
546 except Exception:
547 if self._exiting:
547 if self._exiting:
548 break
548 break
549 else:
549 else:
550 raise
550 raise
551 else:
551 else:
552 break
552 break
553 return events
553 return events
554
554
555 def run(self):
555 def run(self):
556 """The thread's main activity. Call start() instead."""
556 """The thread's main activity. Call start() instead."""
557 self._create_socket()
557 self._create_socket()
558 self._running = True
558 self._running = True
559 self._beating = True
559 self._beating = True
560
560
561 while self._running:
561 while self._running:
562 if self._pause:
562 if self._pause:
563 # just sleep, and skip the rest of the loop
563 # just sleep, and skip the rest of the loop
564 time.sleep(self.time_to_dead)
564 time.sleep(self.time_to_dead)
565 continue
565 continue
566
566
567 since_last_heartbeat = 0.0
567 since_last_heartbeat = 0.0
568 # io.rprint('Ping from HB channel') # dbg
568 # io.rprint('Ping from HB channel') # dbg
569 # no need to catch EFSM here, because the previous event was
569 # no need to catch EFSM here, because the previous event was
570 # either a recv or connect, which cannot be followed by EFSM
570 # either a recv or connect, which cannot be followed by EFSM
571 self.socket.send(b'ping')
571 self.socket.send(b'ping')
572 request_time = time.time()
572 request_time = time.time()
573 ready = self._poll(request_time)
573 ready = self._poll(request_time)
574 if ready:
574 if ready:
575 self._beating = True
575 self._beating = True
576 # the poll above guarantees we have something to recv
576 # the poll above guarantees we have something to recv
577 self.socket.recv()
577 self.socket.recv()
578 # sleep the remainder of the cycle
578 # sleep the remainder of the cycle
579 remainder = self.time_to_dead - (time.time() - request_time)
579 remainder = self.time_to_dead - (time.time() - request_time)
580 if remainder > 0:
580 if remainder > 0:
581 time.sleep(remainder)
581 time.sleep(remainder)
582 continue
582 continue
583 else:
583 else:
584 # nothing was received within the time limit, signal heart failure
584 # nothing was received within the time limit, signal heart failure
585 self._beating = False
585 self._beating = False
586 since_last_heartbeat = time.time() - request_time
586 since_last_heartbeat = time.time() - request_time
587 self.call_handlers(since_last_heartbeat)
587 self.call_handlers(since_last_heartbeat)
588 # and close/reopen the socket, because the REQ/REP cycle has been broken
588 # and close/reopen the socket, because the REQ/REP cycle has been broken
589 self._create_socket()
589 self._create_socket()
590 continue
590 continue
591 try:
591 try:
592 self.socket.close()
592 self.socket.close()
593 except:
593 except:
594 pass
594 pass
595
595
596 def pause(self):
596 def pause(self):
597 """Pause the heartbeat."""
597 """Pause the heartbeat."""
598 self._pause = True
598 self._pause = True
599
599
600 def unpause(self):
600 def unpause(self):
601 """Unpause the heartbeat."""
601 """Unpause the heartbeat."""
602 self._pause = False
602 self._pause = False
603
603
604 def is_beating(self):
604 def is_beating(self):
605 """Is the heartbeat running and responsive (and not paused)."""
605 """Is the heartbeat running and responsive (and not paused)."""
606 if self.is_alive() and not self._pause and self._beating:
606 if self.is_alive() and not self._pause and self._beating:
607 return True
607 return True
608 else:
608 else:
609 return False
609 return False
610
610
611 def stop(self):
611 def stop(self):
612 self._running = False
612 self._running = False
613 super(HBSocketChannel, self).stop()
613 super(HBSocketChannel, self).stop()
614
614
615 def call_handlers(self, since_last_heartbeat):
615 def call_handlers(self, since_last_heartbeat):
616 """This method is called in the ioloop thread when a message arrives.
616 """This method is called in the ioloop thread when a message arrives.
617
617
618 Subclasses should override this method to handle incoming messages.
618 Subclasses should override this method to handle incoming messages.
619 It is important to remember that this method is called in the thread
619 It is important to remember that this method is called in the thread
620 so that some logic must be done to ensure that the application level
620 so that some logic must be done to ensure that the application level
621 handlers are called in the application thread.
621 handlers are called in the application thread.
622 """
622 """
623 raise NotImplementedError('call_handlers must be defined in a subclass.')
623 raise NotImplementedError('call_handlers must be defined in a subclass.')
624
624
625
625
626 #-----------------------------------------------------------------------------
626 #-----------------------------------------------------------------------------
627 # Main kernel manager class
627 # Main kernel manager class
628 #-----------------------------------------------------------------------------
628 #-----------------------------------------------------------------------------
629
629
630 class KernelManager(HasTraits):
630 class KernelManager(HasTraits):
631 """ Manages a kernel for a frontend.
631 """ Manages a kernel for a frontend.
632
632
633 The SUB channel is for the frontend to receive messages published by the
633 The SUB channel is for the frontend to receive messages published by the
634 kernel.
634 kernel.
635
635
636 The REQ channel is for the frontend to make requests of the kernel.
636 The REQ channel is for the frontend to make requests of the kernel.
637
637
638 The REP channel is for the kernel to request stdin (raw_input) from the
638 The REP channel is for the kernel to request stdin (raw_input) from the
639 frontend.
639 frontend.
640 """
640 """
641 # config object for passing to child configurables
641 # config object for passing to child configurables
642 config = Instance(Config)
642 config = Instance(Config)
643
643
644 # The PyZMQ Context to use for communication with the kernel.
644 # The PyZMQ Context to use for communication with the kernel.
645 context = Instance(zmq.Context)
645 context = Instance(zmq.Context)
646 def _context_default(self):
646 def _context_default(self):
647 return zmq.Context.instance()
647 return zmq.Context.instance()
648
648
649 # The Session to use for communication with the kernel.
649 # The Session to use for communication with the kernel.
650 session = Instance(Session)
650 session = Instance(Session)
651
651
652 # The kernel process with which the KernelManager is communicating.
652 # The kernel process with which the KernelManager is communicating.
653 kernel = Instance(Popen)
653 kernel = Instance(Popen)
654
654
655 # The addresses for the communication channels.
655 # The addresses for the communication channels.
656 connection_file = Unicode('')
656 connection_file = Unicode('')
657 ip = Unicode(LOCALHOST)
657 ip = Unicode(LOCALHOST)
658 def _ip_changed(self, name, old, new):
658 def _ip_changed(self, name, old, new):
659 if new == '*':
659 if new == '*':
660 self.ip = '0.0.0.0'
660 self.ip = '0.0.0.0'
661 shell_port = Integer(0)
661 shell_port = Integer(0)
662 iopub_port = Integer(0)
662 iopub_port = Integer(0)
663 stdin_port = Integer(0)
663 stdin_port = Integer(0)
664 hb_port = Integer(0)
664 hb_port = Integer(0)
665
665
666 # The classes to use for the various channels.
666 # The classes to use for the various channels.
667 shell_channel_class = Type(ShellSocketChannel)
667 shell_channel_class = Type(ShellSocketChannel)
668 sub_channel_class = Type(SubSocketChannel)
668 sub_channel_class = Type(SubSocketChannel)
669 stdin_channel_class = Type(StdInSocketChannel)
669 stdin_channel_class = Type(StdInSocketChannel)
670 hb_channel_class = Type(HBSocketChannel)
670 hb_channel_class = Type(HBSocketChannel)
671
671
672 # Protected traits.
672 # Protected traits.
673 _launch_args = Any
673 _launch_args = Any
674 _shell_channel = Any
674 _shell_channel = Any
675 _sub_channel = Any
675 _sub_channel = Any
676 _stdin_channel = Any
676 _stdin_channel = Any
677 _hb_channel = Any
677 _hb_channel = Any
678 _connection_file_written=Bool(False)
678 _connection_file_written=Bool(False)
679
679
680 def __init__(self, **kwargs):
680 def __init__(self, **kwargs):
681 super(KernelManager, self).__init__(**kwargs)
681 super(KernelManager, self).__init__(**kwargs)
682 if self.session is None:
682 if self.session is None:
683 self.session = Session(config=self.config)
683 self.session = Session(config=self.config)
684
684
685 def __del__(self):
685 def __del__(self):
686 self.cleanup_connection_file()
686 self.cleanup_connection_file()
687
687
688
688
689 #--------------------------------------------------------------------------
689 #--------------------------------------------------------------------------
690 # Channel management methods:
690 # Channel management methods:
691 #--------------------------------------------------------------------------
691 #--------------------------------------------------------------------------
692
692
693 def start_channels(self, shell=True, sub=True, stdin=True, hb=True):
693 def start_channels(self, shell=True, sub=True, stdin=True, hb=True):
694 """Starts the channels for this kernel.
694 """Starts the channels for this kernel.
695
695
696 This will create the channels if they do not exist and then start
696 This will create the channels if they do not exist and then start
697 them. If port numbers of 0 are being used (random ports) then you
697 them. If port numbers of 0 are being used (random ports) then you
698 must first call :method:`start_kernel`. If the channels have been
698 must first call :method:`start_kernel`. If the channels have been
699 stopped and you call this, :class:`RuntimeError` will be raised.
699 stopped and you call this, :class:`RuntimeError` will be raised.
700 """
700 """
701 if shell:
701 if shell:
702 self.shell_channel.start()
702 self.shell_channel.start()
703 if sub:
703 if sub:
704 self.sub_channel.start()
704 self.sub_channel.start()
705 if stdin:
705 if stdin:
706 self.stdin_channel.start()
706 self.stdin_channel.start()
707 self.shell_channel.allow_stdin = True
707 self.shell_channel.allow_stdin = True
708 else:
708 else:
709 self.shell_channel.allow_stdin = False
709 self.shell_channel.allow_stdin = False
710 if hb:
710 if hb:
711 self.hb_channel.start()
711 self.hb_channel.start()
712
712
713 def stop_channels(self):
713 def stop_channels(self):
714 """Stops all the running channels for this kernel.
714 """Stops all the running channels for this kernel.
715 """
715 """
716 if self.shell_channel.is_alive():
716 if self.shell_channel.is_alive():
717 self.shell_channel.stop()
717 self.shell_channel.stop()
718 if self.sub_channel.is_alive():
718 if self.sub_channel.is_alive():
719 self.sub_channel.stop()
719 self.sub_channel.stop()
720 if self.stdin_channel.is_alive():
720 if self.stdin_channel.is_alive():
721 self.stdin_channel.stop()
721 self.stdin_channel.stop()
722 if self.hb_channel.is_alive():
722 if self.hb_channel.is_alive():
723 self.hb_channel.stop()
723 self.hb_channel.stop()
724
724
725 @property
725 @property
726 def channels_running(self):
726 def channels_running(self):
727 """Are any of the channels created and running?"""
727 """Are any of the channels created and running?"""
728 return (self.shell_channel.is_alive() or self.sub_channel.is_alive() or
728 return (self.shell_channel.is_alive() or self.sub_channel.is_alive() or
729 self.stdin_channel.is_alive() or self.hb_channel.is_alive())
729 self.stdin_channel.is_alive() or self.hb_channel.is_alive())
730
730
731 #--------------------------------------------------------------------------
731 #--------------------------------------------------------------------------
732 # Kernel process management methods:
732 # Kernel process management methods:
733 #--------------------------------------------------------------------------
733 #--------------------------------------------------------------------------
734
734
735 def cleanup_connection_file(self):
735 def cleanup_connection_file(self):
736 """cleanup connection file *if we wrote it*
736 """cleanup connection file *if we wrote it*
737
737
738 Will not raise if the connection file was already removed somehow.
738 Will not raise if the connection file was already removed somehow.
739 """
739 """
740 if self._connection_file_written:
740 if self._connection_file_written:
741 # cleanup connection files on full shutdown of kernel we started
741 # cleanup connection files on full shutdown of kernel we started
742 self._connection_file_written = False
742 self._connection_file_written = False
743 try:
743 try:
744 os.remove(self.connection_file)
744 os.remove(self.connection_file)
745 except OSError:
745 except OSError:
746 pass
746 pass
747
747
748 def load_connection_file(self):
748 def load_connection_file(self):
749 """load connection info from JSON dict in self.connection_file"""
749 """load connection info from JSON dict in self.connection_file"""
750 with open(self.connection_file) as f:
750 with open(self.connection_file) as f:
751 cfg = json.loads(f.read())
751 cfg = json.loads(f.read())
752
752
753 self.ip = cfg['ip']
753 self.ip = cfg['ip']
754 self.shell_port = cfg['shell_port']
754 self.shell_port = cfg['shell_port']
755 self.stdin_port = cfg['stdin_port']
755 self.stdin_port = cfg['stdin_port']
756 self.iopub_port = cfg['iopub_port']
756 self.iopub_port = cfg['iopub_port']
757 self.hb_port = cfg['hb_port']
757 self.hb_port = cfg['hb_port']
758 self.session.key = str_to_bytes(cfg['key'])
758 self.session.key = str_to_bytes(cfg['key'])
759
759
760 def write_connection_file(self):
760 def write_connection_file(self):
761 """write connection info to JSON dict in self.connection_file"""
761 """write connection info to JSON dict in self.connection_file"""
762 if self._connection_file_written:
762 if self._connection_file_written:
763 return
763 return
764 self.connection_file,cfg = write_connection_file(self.connection_file,
764 self.connection_file,cfg = write_connection_file(self.connection_file,
765 ip=self.ip, key=self.session.key,
765 ip=self.ip, key=self.session.key,
766 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
766 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
767 shell_port=self.shell_port, hb_port=self.hb_port)
767 shell_port=self.shell_port, hb_port=self.hb_port)
768 # write_connection_file also sets default ports:
768 # write_connection_file also sets default ports:
769 self.shell_port = cfg['shell_port']
769 self.shell_port = cfg['shell_port']
770 self.stdin_port = cfg['stdin_port']
770 self.stdin_port = cfg['stdin_port']
771 self.iopub_port = cfg['iopub_port']
771 self.iopub_port = cfg['iopub_port']
772 self.hb_port = cfg['hb_port']
772 self.hb_port = cfg['hb_port']
773
773
774 self._connection_file_written = True
774 self._connection_file_written = True
775
775
776 def start_kernel(self, **kw):
776 def start_kernel(self, **kw):
777 """Starts a kernel process and configures the manager to use it.
777 """Starts a kernel process and configures the manager to use it.
778
778
779 If random ports (port=0) are being used, this method must be called
779 If random ports (port=0) are being used, this method must be called
780 before the channels are created.
780 before the channels are created.
781
781
782 Parameters:
782 Parameters:
783 -----------
783 -----------
784 ipython : bool, optional (default True)
785 Whether to use an IPython kernel instead of a plain Python kernel.
786
787 launcher : callable, optional (default None)
784 launcher : callable, optional (default None)
788 A custom function for launching the kernel process (generally a
785 A custom function for launching the kernel process (generally a
789 wrapper around ``entry_point.base_launch_kernel``). In most cases,
786 wrapper around ``entry_point.base_launch_kernel``). In most cases,
790 it should not be necessary to use this parameter.
787 it should not be necessary to use this parameter.
791
788
792 **kw : optional
789 **kw : optional
793 See respective options for IPython and Python kernels.
790 See respective options for IPython and Python kernels.
794 """
791 """
795 if self.ip not in LOCAL_IPS:
792 if self.ip not in LOCAL_IPS:
796 raise RuntimeError("Can only launch a kernel on a local interface. "
793 raise RuntimeError("Can only launch a kernel on a local interface. "
797 "Make sure that the '*_address' attributes are "
794 "Make sure that the '*_address' attributes are "
798 "configured properly. "
795 "configured properly. "
799 "Currently valid addresses are: %s"%LOCAL_IPS
796 "Currently valid addresses are: %s"%LOCAL_IPS
800 )
797 )
801
798
802 # write connection file / get default ports
799 # write connection file / get default ports
803 self.write_connection_file()
800 self.write_connection_file()
804
801
805 self._launch_args = kw.copy()
802 self._launch_args = kw.copy()
806 launch_kernel = kw.pop('launcher', None)
803 launch_kernel = kw.pop('launcher', None)
807 if launch_kernel is None:
804 if launch_kernel is None:
808 if kw.pop('ipython', True):
805 from ipkernel import launch_kernel
809 from ipkernel import launch_kernel
810 else:
811 from pykernel import launch_kernel
812 self.kernel = launch_kernel(fname=self.connection_file, **kw)
806 self.kernel = launch_kernel(fname=self.connection_file, **kw)
813
807
814 def shutdown_kernel(self, restart=False):
808 def shutdown_kernel(self, restart=False):
815 """ Attempts to the stop the kernel process cleanly. If the kernel
809 """ Attempts to the stop the kernel process cleanly. If the kernel
816 cannot be stopped, it is killed, if possible.
810 cannot be stopped, it is killed, if possible.
817 """
811 """
818 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
812 # FIXME: Shutdown does not work on Windows due to ZMQ errors!
819 if sys.platform == 'win32':
813 if sys.platform == 'win32':
820 self.kill_kernel()
814 self.kill_kernel()
821 return
815 return
822
816
823 # Pause the heart beat channel if it exists.
817 # Pause the heart beat channel if it exists.
824 if self._hb_channel is not None:
818 if self._hb_channel is not None:
825 self._hb_channel.pause()
819 self._hb_channel.pause()
826
820
827 # Don't send any additional kernel kill messages immediately, to give
821 # Don't send any additional kernel kill messages immediately, to give
828 # the kernel a chance to properly execute shutdown actions. Wait for at
822 # the kernel a chance to properly execute shutdown actions. Wait for at
829 # most 1s, checking every 0.1s.
823 # most 1s, checking every 0.1s.
830 self.shell_channel.shutdown(restart=restart)
824 self.shell_channel.shutdown(restart=restart)
831 for i in range(10):
825 for i in range(10):
832 if self.is_alive:
826 if self.is_alive:
833 time.sleep(0.1)
827 time.sleep(0.1)
834 else:
828 else:
835 break
829 break
836 else:
830 else:
837 # OK, we've waited long enough.
831 # OK, we've waited long enough.
838 if self.has_kernel:
832 if self.has_kernel:
839 self.kill_kernel()
833 self.kill_kernel()
840
834
841 if not restart and self._connection_file_written:
835 if not restart and self._connection_file_written:
842 # cleanup connection files on full shutdown of kernel we started
836 # cleanup connection files on full shutdown of kernel we started
843 self._connection_file_written = False
837 self._connection_file_written = False
844 try:
838 try:
845 os.remove(self.connection_file)
839 os.remove(self.connection_file)
846 except IOError:
840 except IOError:
847 pass
841 pass
848
842
849 def restart_kernel(self, now=False, **kw):
843 def restart_kernel(self, now=False, **kw):
850 """Restarts a kernel with the arguments that were used to launch it.
844 """Restarts a kernel with the arguments that were used to launch it.
851
845
852 If the old kernel was launched with random ports, the same ports will be
846 If the old kernel was launched with random ports, the same ports will be
853 used for the new kernel.
847 used for the new kernel.
854
848
855 Parameters
849 Parameters
856 ----------
850 ----------
857 now : bool, optional
851 now : bool, optional
858 If True, the kernel is forcefully restarted *immediately*, without
852 If True, the kernel is forcefully restarted *immediately*, without
859 having a chance to do any cleanup action. Otherwise the kernel is
853 having a chance to do any cleanup action. Otherwise the kernel is
860 given 1s to clean up before a forceful restart is issued.
854 given 1s to clean up before a forceful restart is issued.
861
855
862 In all cases the kernel is restarted, the only difference is whether
856 In all cases the kernel is restarted, the only difference is whether
863 it is given a chance to perform a clean shutdown or not.
857 it is given a chance to perform a clean shutdown or not.
864
858
865 **kw : optional
859 **kw : optional
866 Any options specified here will replace those used to launch the
860 Any options specified here will replace those used to launch the
867 kernel.
861 kernel.
868 """
862 """
869 if self._launch_args is None:
863 if self._launch_args is None:
870 raise RuntimeError("Cannot restart the kernel. "
864 raise RuntimeError("Cannot restart the kernel. "
871 "No previous call to 'start_kernel'.")
865 "No previous call to 'start_kernel'.")
872 else:
866 else:
873 # Stop currently running kernel.
867 # Stop currently running kernel.
874 if self.has_kernel:
868 if self.has_kernel:
875 if now:
869 if now:
876 self.kill_kernel()
870 self.kill_kernel()
877 else:
871 else:
878 self.shutdown_kernel(restart=True)
872 self.shutdown_kernel(restart=True)
879
873
880 # Start new kernel.
874 # Start new kernel.
881 self._launch_args.update(kw)
875 self._launch_args.update(kw)
882 self.start_kernel(**self._launch_args)
876 self.start_kernel(**self._launch_args)
883
877
884 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
878 # FIXME: Messages get dropped in Windows due to probable ZMQ bug
885 # unless there is some delay here.
879 # unless there is some delay here.
886 if sys.platform == 'win32':
880 if sys.platform == 'win32':
887 time.sleep(0.2)
881 time.sleep(0.2)
888
882
889 @property
883 @property
890 def has_kernel(self):
884 def has_kernel(self):
891 """Returns whether a kernel process has been specified for the kernel
885 """Returns whether a kernel process has been specified for the kernel
892 manager.
886 manager.
893 """
887 """
894 return self.kernel is not None
888 return self.kernel is not None
895
889
896 def kill_kernel(self):
890 def kill_kernel(self):
897 """ Kill the running kernel. """
891 """ Kill the running kernel. """
898 if self.has_kernel:
892 if self.has_kernel:
899 # Pause the heart beat channel if it exists.
893 # Pause the heart beat channel if it exists.
900 if self._hb_channel is not None:
894 if self._hb_channel is not None:
901 self._hb_channel.pause()
895 self._hb_channel.pause()
902
896
903 # Attempt to kill the kernel.
897 # Attempt to kill the kernel.
904 try:
898 try:
905 self.kernel.kill()
899 self.kernel.kill()
906 except OSError, e:
900 except OSError, e:
907 # In Windows, we will get an Access Denied error if the process
901 # In Windows, we will get an Access Denied error if the process
908 # has already terminated. Ignore it.
902 # has already terminated. Ignore it.
909 if sys.platform == 'win32':
903 if sys.platform == 'win32':
910 if e.winerror != 5:
904 if e.winerror != 5:
911 raise
905 raise
912 # On Unix, we may get an ESRCH error if the process has already
906 # On Unix, we may get an ESRCH error if the process has already
913 # terminated. Ignore it.
907 # terminated. Ignore it.
914 else:
908 else:
915 from errno import ESRCH
909 from errno import ESRCH
916 if e.errno != ESRCH:
910 if e.errno != ESRCH:
917 raise
911 raise
918 self.kernel = None
912 self.kernel = None
919 else:
913 else:
920 raise RuntimeError("Cannot kill kernel. No kernel is running!")
914 raise RuntimeError("Cannot kill kernel. No kernel is running!")
921
915
922 def interrupt_kernel(self):
916 def interrupt_kernel(self):
923 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
917 """ Interrupts the kernel. Unlike ``signal_kernel``, this operation is
924 well supported on all platforms.
918 well supported on all platforms.
925 """
919 """
926 if self.has_kernel:
920 if self.has_kernel:
927 if sys.platform == 'win32':
921 if sys.platform == 'win32':
928 from parentpoller import ParentPollerWindows as Poller
922 from parentpoller import ParentPollerWindows as Poller
929 Poller.send_interrupt(self.kernel.win32_interrupt_event)
923 Poller.send_interrupt(self.kernel.win32_interrupt_event)
930 else:
924 else:
931 self.kernel.send_signal(signal.SIGINT)
925 self.kernel.send_signal(signal.SIGINT)
932 else:
926 else:
933 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
927 raise RuntimeError("Cannot interrupt kernel. No kernel is running!")
934
928
935 def signal_kernel(self, signum):
929 def signal_kernel(self, signum):
936 """ Sends a signal to the kernel. Note that since only SIGTERM is
930 """ Sends a signal to the kernel. Note that since only SIGTERM is
937 supported on Windows, this function is only useful on Unix systems.
931 supported on Windows, this function is only useful on Unix systems.
938 """
932 """
939 if self.has_kernel:
933 if self.has_kernel:
940 self.kernel.send_signal(signum)
934 self.kernel.send_signal(signum)
941 else:
935 else:
942 raise RuntimeError("Cannot signal kernel. No kernel is running!")
936 raise RuntimeError("Cannot signal kernel. No kernel is running!")
943
937
944 @property
938 @property
945 def is_alive(self):
939 def is_alive(self):
946 """Is the kernel process still running?"""
940 """Is the kernel process still running?"""
947 if self.has_kernel:
941 if self.has_kernel:
948 if self.kernel.poll() is None:
942 if self.kernel.poll() is None:
949 return True
943 return True
950 else:
944 else:
951 return False
945 return False
952 elif self._hb_channel is not None:
946 elif self._hb_channel is not None:
953 # We didn't start the kernel with this KernelManager so we
947 # We didn't start the kernel with this KernelManager so we
954 # use the heartbeat.
948 # use the heartbeat.
955 return self._hb_channel.is_beating()
949 return self._hb_channel.is_beating()
956 else:
950 else:
957 # no heartbeat and not local, we can't tell if it's running,
951 # no heartbeat and not local, we can't tell if it's running,
958 # so naively return True
952 # so naively return True
959 return True
953 return True
960
954
961 #--------------------------------------------------------------------------
955 #--------------------------------------------------------------------------
962 # Channels used for communication with the kernel:
956 # Channels used for communication with the kernel:
963 #--------------------------------------------------------------------------
957 #--------------------------------------------------------------------------
964
958
965 @property
959 @property
966 def shell_channel(self):
960 def shell_channel(self):
967 """Get the REQ socket channel object to make requests of the kernel."""
961 """Get the REQ socket channel object to make requests of the kernel."""
968 if self._shell_channel is None:
962 if self._shell_channel is None:
969 self._shell_channel = self.shell_channel_class(self.context,
963 self._shell_channel = self.shell_channel_class(self.context,
970 self.session,
964 self.session,
971 (self.ip, self.shell_port))
965 (self.ip, self.shell_port))
972 return self._shell_channel
966 return self._shell_channel
973
967
974 @property
968 @property
975 def sub_channel(self):
969 def sub_channel(self):
976 """Get the SUB socket channel object."""
970 """Get the SUB socket channel object."""
977 if self._sub_channel is None:
971 if self._sub_channel is None:
978 self._sub_channel = self.sub_channel_class(self.context,
972 self._sub_channel = self.sub_channel_class(self.context,
979 self.session,
973 self.session,
980 (self.ip, self.iopub_port))
974 (self.ip, self.iopub_port))
981 return self._sub_channel
975 return self._sub_channel
982
976
983 @property
977 @property
984 def stdin_channel(self):
978 def stdin_channel(self):
985 """Get the REP socket channel object to handle stdin (raw_input)."""
979 """Get the REP socket channel object to handle stdin (raw_input)."""
986 if self._stdin_channel is None:
980 if self._stdin_channel is None:
987 self._stdin_channel = self.stdin_channel_class(self.context,
981 self._stdin_channel = self.stdin_channel_class(self.context,
988 self.session,
982 self.session,
989 (self.ip, self.stdin_port))
983 (self.ip, self.stdin_port))
990 return self._stdin_channel
984 return self._stdin_channel
991
985
992 @property
986 @property
993 def hb_channel(self):
987 def hb_channel(self):
994 """Get the heartbeat socket channel object to check that the
988 """Get the heartbeat socket channel object to check that the
995 kernel is alive."""
989 kernel is alive."""
996 if self._hb_channel is None:
990 if self._hb_channel is None:
997 self._hb_channel = self.hb_channel_class(self.context,
991 self._hb_channel = self.hb_channel_class(self.context,
998 self.session,
992 self.session,
999 (self.ip, self.hb_port))
993 (self.ip, self.hb_port))
1000 return self._hb_channel
994 return self._hb_channel
@@ -1,756 +1,756 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hmac
27 import hmac
28 import logging
28 import logging
29 import os
29 import os
30 import pprint
30 import pprint
31 import uuid
31 import uuid
32 from datetime import datetime
32 from datetime import datetime
33
33
34 try:
34 try:
35 import cPickle
35 import cPickle
36 pickle = cPickle
36 pickle = cPickle
37 except:
37 except:
38 cPickle = None
38 cPickle = None
39 import pickle
39 import pickle
40
40
41 import zmq
41 import zmq
42 from zmq.utils import jsonapi
42 from zmq.utils import jsonapi
43 from zmq.eventloop.ioloop import IOLoop
43 from zmq.eventloop.ioloop import IOLoop
44 from zmq.eventloop.zmqstream import ZMQStream
44 from zmq.eventloop.zmqstream import ZMQStream
45
45
46 from IPython.config.application import Application, boolean_flag
46 from IPython.config.application import Application, boolean_flag
47 from IPython.config.configurable import Configurable, LoggingConfigurable
47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.utils.importstring import import_item
48 from IPython.utils.importstring import import_item
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
49 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
50 from IPython.utils.py3compat import str_to_bytes
50 from IPython.utils.py3compat import str_to_bytes
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
51 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
52 DottedObjectName, CUnicode)
52 DottedObjectName, CUnicode)
53
53
54 #-----------------------------------------------------------------------------
54 #-----------------------------------------------------------------------------
55 # utility functions
55 # utility functions
56 #-----------------------------------------------------------------------------
56 #-----------------------------------------------------------------------------
57
57
def squash_unicode(obj):
    """coerce unicode back to bytestrings.

    Dicts and lists are squashed recursively and modified in place;
    unicode dict keys are re-keyed to their bytes form.
    """
    if isinstance(obj, dict):
        # keys() gives a list on Python 2, so mutating while iterating is safe
        for key in obj.keys():
            obj[key] = squash_unicode(obj[key])
            if isinstance(key, unicode):
                # move the entry under the bytes version of the key
                obj[squash_unicode(key)] = obj.pop(key)
    elif isinstance(obj, list):
        # in-place elementwise squash
        obj[:] = [squash_unicode(item) for item in obj]
    elif isinstance(obj, unicode):
        obj = obj.encode('utf8')
    return obj
71
71
72 #-----------------------------------------------------------------------------
72 #-----------------------------------------------------------------------------
73 # globals and defaults
73 # globals and defaults
74 #-----------------------------------------------------------------------------
74 #-----------------------------------------------------------------------------
75
75
76
76
# ISO8601-ify datetime objects
json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
json_unpacker = lambda s: extract_dates(jsonapi.loads(s))

# pickle with protocol -1 (the highest available protocol)
pickle_packer = lambda o: pickle.dumps(o,-1)
pickle_unpacker = pickle.loads

# JSON is the default message transport
default_packer = json_packer
default_unpacker = json_unpacker

# delimiter separating zmq routing identities from the signed message parts
DELIM=b"<IDS|MSG>"
88
88
89
89
90 #-----------------------------------------------------------------------------
90 #-----------------------------------------------------------------------------
91 # Mixin tools for apps that use Sessions
91 # Mixin tools for apps that use Sessions
92 #-----------------------------------------------------------------------------
92 #-----------------------------------------------------------------------------
93
93
# command-line aliases: short option name -> Session config trait
session_aliases = dict(
    ident = 'Session.session',
    user = 'Session.username',
    keyfile = 'Session.keyfile',
)

# command-line flags: flag name -> (config overrides, help text)
session_flags  = {
    'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
                            'keyfile' : '' }},
        """Use HMAC digests for authentication of messages.
        Setting this flag will generate a new UUID to use as the HMAC key.
        """),
    'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
        """Don't authenticate messages."""),
}
109
109
def default_secure(cfg):
    """Set the default behavior for a config environment to be secure.

    If Session.key/keyfile have not been set, set Session.key to
    a new random UUID.
    """
    if 'Session' in cfg and ('key' in cfg.Session or 'keyfile' in cfg.Session):
        # key or keyfile already configured explicitly; leave it alone
        return
    # key/keyfile not specified, generate new UUID:
    cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
122
122
123
123
124 #-----------------------------------------------------------------------------
124 #-----------------------------------------------------------------------------
125 # Classes
125 # Classes
126 #-----------------------------------------------------------------------------
126 #-----------------------------------------------------------------------------
127
127
class SessionFactory(LoggingConfigurable):
    """The Base class for configurables that have a Session, Context, logger,
    and IOLoop.
    """

    logname = Unicode('')
    def _logname_changed(self, name, old, new):
        # rebind the logger whenever the log name changes
        self.log = logging.getLogger(new)

    # not configurable:
    context = Instance('zmq.Context')
    def _context_default(self):
        # share the process-global zmq context by default
        return zmq.Context.instance()

    session = Instance('IPython.zmq.session.Session')

    loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
    def _loop_default(self):
        # share the process-global IOLoop by default
        return IOLoop.instance()

    def __init__(self, **kwargs):
        super(SessionFactory, self).__init__(**kwargs)

        if self.session is None:
            # construct the session
            self.session = Session(**kwargs)
154
154
155
155
class Message(object):
    """A simple message object that maps dict keys to attributes.

    A Message can be created from a dict and a dict from a Message instance
    simply by calling dict(msg_obj)."""

    def __init__(self, msg_dict):
        dct = self.__dict__
        # items() rather than iteritems(): identical behavior on Python 2
        # (the dicts are small) and also runs on Python 3
        for k, v in dict(msg_dict).items():
            if isinstance(v, dict):
                # wrap nested dicts so attribute access works recursively
                v = Message(v)
            dct[k] = v

    # Having this iterator lets dict(msg_obj) work out of the box.
    def __iter__(self):
        return iter(self.__dict__.items())

    def __repr__(self):
        return repr(self.__dict__)

    def __str__(self):
        return pprint.pformat(self.__dict__)

    def __contains__(self, k):
        return k in self.__dict__

    def __getitem__(self, k):
        return self.__dict__[k]
184
184
185
185
def msg_header(msg_id, msg_type, username, session):
    """Build a header dict from the given fields, stamped with the current time."""
    return dict(msg_id=msg_id, msg_type=msg_type, username=username,
                session=session, date=datetime.now())
189
189
def extract_header(msg_or_header):
    """Given a message or header, return the header."""
    if not msg_or_header:
        return {}
    try:
        # See if msg_or_header is the entire message.
        h = msg_or_header['header']
    except KeyError:
        # Not a full message: treat it as a header itself, provided it
        # at least has a msg_id (otherwise the KeyError propagates).
        msg_or_header['msg_id']
        h = msg_or_header
    if not isinstance(h, dict):
        h = dict(h)
    return h
208
208
209 class Session(Configurable):
209 class Session(Configurable):
210 """Object for handling serialization and sending of messages.
210 """Object for handling serialization and sending of messages.
211
211
212 The Session object handles building messages and sending them
212 The Session object handles building messages and sending them
213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
213 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
214 other over the network via Session objects, and only need to work with the
214 other over the network via Session objects, and only need to work with the
215 dict-based IPython message spec. The Session will handle
215 dict-based IPython message spec. The Session will handle
216 serialization/deserialization, security, and metadata.
216 serialization/deserialization, security, and metadata.
217
217
    Sessions support configurable serialization via packer/unpacker traits,
    and signing with HMAC digests via the key/keyfile traits.
220
220
221 Parameters
221 Parameters
222 ----------
222 ----------
223
223
224 debug : bool
224 debug : bool
225 whether to trigger extra debugging statements
225 whether to trigger extra debugging statements
226 packer/unpacker : str : 'json', 'pickle' or import_string
226 packer/unpacker : str : 'json', 'pickle' or import_string
227 importstrings for methods to serialize message parts. If just
227 importstrings for methods to serialize message parts. If just
228 'json' or 'pickle', predefined JSON and pickle packers will be used.
228 'json' or 'pickle', predefined JSON and pickle packers will be used.
229 Otherwise, the entire importstring must be used.
229 Otherwise, the entire importstring must be used.
230
230
231 The functions must accept at least valid JSON input, and output *bytes*.
231 The functions must accept at least valid JSON input, and output *bytes*.
232
232
233 For example, to use msgpack:
233 For example, to use msgpack:
234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
234 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
235 pack/unpack : callables
235 pack/unpack : callables
236 You can also set the pack/unpack callables for serialization directly.
236 You can also set the pack/unpack callables for serialization directly.
237 session : bytes
237 session : bytes
238 the ID of this Session object. The default is to generate a new UUID.
238 the ID of this Session object. The default is to generate a new UUID.
239 username : unicode
239 username : unicode
240 username added to message headers. The default is to ask the OS.
240 username added to message headers. The default is to ask the OS.
241 key : bytes
241 key : bytes
242 The key used to initialize an HMAC signature. If unset, messages
242 The key used to initialize an HMAC signature. If unset, messages
243 will not be signed or checked.
243 will not be signed or checked.
244 keyfile : filepath
244 keyfile : filepath
245 The file containing a key. If this is set, `key` will be initialized
245 The file containing a key. If this is set, `key` will be initialized
246 to the contents of the file.
246 to the contents of the file.
247
247
248 """
248 """
249
249
    debug=Bool(False, config=True, help="""Debug output in the Session""")

    packer = DottedObjectName('json',config=True,
            help="""The name of the packer for serializing messages.
            Should be one of 'json', 'pickle', or an import name
            for a custom callable serializer.""")
    def _packer_changed(self, name, old, new):
        # 'json'/'pickle' install the matched pack/unpack pair; any other
        # value is treated as an import string for `pack` only, and
        # `unpacker` must be configured separately.
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.pack = import_item(str(new))

    unpacker = DottedObjectName('json', config=True,
        help="""The name of the unpacker for unserializing messages.
        Only used with custom functions for `packer`.""")
    def _unpacker_changed(self, name, old, new):
        # mirror of _packer_changed: custom import strings set `unpack` only
        if new.lower() == 'json':
            self.pack = json_packer
            self.unpack = json_unpacker
        elif new.lower() == 'pickle':
            self.pack = pickle_packer
            self.unpack = pickle_unpacker
        else:
            self.unpack = import_item(str(new))
278
278
    session = CUnicode(u'', config=True,
        help="""The UUID identifying this session.""")
    def _session_default(self):
        # generate a fresh UUID and keep the bytes mirror in sync
        u = unicode(uuid.uuid4())
        self.bsession = u.encode('ascii')
        return u

    def _session_changed(self, name, old, new):
        # keep the bytes mirror of the session id up to date
        self.bsession = self.session.encode('ascii')

    # bsession is the session as bytes
    bsession = CBytes(b'')

    username = Unicode(os.environ.get('USER',u'username'), config=True,
        help="""Username for the Session. Default is your system username.""")
294
294
    # message signature related traits:

    key = CBytes(b'', config=True,
        help="""execution key, for extra authentication.""")
    def _key_changed(self, name, old, new):
        # an empty key disables signing entirely (auth becomes None)
        if new:
            self.auth = hmac.HMAC(new)
        else:
            self.auth = None
    # prototype HMAC object; sign() copies it for each signature
    auth = Instance(hmac.HMAC)
    # digests seen so far -- presumably for rejecting replayed messages;
    # the consuming code is not in this file (verify against unserialize)
    digest_history = Set()

    keyfile = Unicode('', config=True,
        help="""path to file containing execution key.""")
    def _keyfile_changed(self, name, old, new):
        # load the key from the file; assignment triggers _key_changed
        with open(new, 'rb') as f:
            self.key = f.read().strip()
312
312
    # serialization traits:

    pack = Any(default_packer) # the actual packer function
    def _pack_changed(self, name, old, new):
        if not callable(new):
            raise TypeError("packer must be callable, not %s"%type(new))

    unpack = Any(default_unpacker) # the actual unpacker function
    def _unpack_changed(self, name, old, new):
        # the unpacker's output is not validated -- only that it is callable
        if not callable(new):
            raise TypeError("unpacker must be callable, not %s"%type(new))
325
325
    def __init__(self, **kwargs):
        """create a Session object

        Parameters
        ----------

        debug : bool
            whether to trigger extra debugging statements
        packer/unpacker : str : 'json', 'pickle' or import_string
            importstrings for methods to serialize message parts.  If just
            'json' or 'pickle', predefined JSON and pickle packers will be used.
            Otherwise, the entire importstring must be used.

            The functions must accept at least valid JSON input, and output
            *bytes*.

            For example, to use msgpack:
            packer = 'msgpack.packb', unpacker='msgpack.unpackb'
        pack/unpack : callables
            You can also set the pack/unpack callables for serialization
            directly.
        session : unicode (must be ascii)
            the ID of this Session object.  The default is to generate a new
            UUID.
        bsession : bytes
            The session as bytes
        username : unicode
            username added to message headers.  The default is to ask the OS.
        key : bytes
            The key used to initialize an HMAC signature.  If unset, messages
            will not be signed or checked.
        keyfile : filepath
            The file containing a key.  If this is set, `key` will be
            initialized to the contents of the file.
        """
        super(Session, self).__init__(**kwargs)
        # validate pack/unpack and wrap them for datetime support if needed
        self._check_packers()
        # cache the packed form of an empty dict, used when content is None
        self.none = self.pack({})
        # ensure self._session_default() if necessary, so bsession is defined:
        self.session
366
366
    @property
    def msg_id(self):
        """always return new uuid"""
        return str(uuid.uuid4())
371
371
    def _check_packers(self):
        """check packers for binary data and datetime support."""
        pack = self.pack
        unpack = self.unpack

        # check simple serialization
        msg = dict(a=[1,'hi'])
        try:
            packed = pack(msg)
        except Exception:
            raise ValueError("packer could not serialize a simple message")

        # ensure packed message is bytes
        if not isinstance(packed, bytes):
            raise ValueError("message packed to %r, but bytes are required"%type(packed))

        # check that unpack is pack's inverse
        try:
            unpacked = unpack(packed)
        except Exception:
            raise ValueError("unpacker could not handle the packer's output")

        # check datetime support
        msg = dict(t=datetime.now())
        try:
            unpacked = unpack(pack(msg))
        except Exception:
            # packer can't handle datetimes natively: wrap the pair to squash
            # datetimes before packing and extract them after unpacking
            self.pack = lambda o: pack(squash_dates(o))
            self.unpack = lambda s: extract_dates(unpack(s))
401
401
    def msg_header(self, msg_type):
        """Create a header dict for a message of type `msg_type`.

        Delegates to the module-level msg_header() (which this method
        shadows) using this session's identity and a fresh msg_id.
        """
        return msg_header(self.msg_id, msg_type, self.username, self.session)
404
404
405 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
405 def msg(self, msg_type, content=None, parent=None, subheader=None, header=None):
406 """Return the nested message dict.
406 """Return the nested message dict.
407
407
408 This format is different from what is sent over the wire. The
408 This format is different from what is sent over the wire. The
409 serialize/unserialize methods converts this nested message dict to the wire
409 serialize/unserialize methods converts this nested message dict to the wire
410 format, which is a list of message parts.
410 format, which is a list of message parts.
411 """
411 """
412 msg = {}
412 msg = {}
413 header = self.msg_header(msg_type) if header is None else header
413 header = self.msg_header(msg_type) if header is None else header
414 msg['header'] = header
414 msg['header'] = header
415 msg['msg_id'] = header['msg_id']
415 msg['msg_id'] = header['msg_id']
416 msg['msg_type'] = header['msg_type']
416 msg['msg_type'] = header['msg_type']
417 msg['parent_header'] = {} if parent is None else extract_header(parent)
417 msg['parent_header'] = {} if parent is None else extract_header(parent)
418 msg['content'] = {} if content is None else content
418 msg['content'] = {} if content is None else content
419 sub = {} if subheader is None else subheader
419 sub = {} if subheader is None else subheader
420 msg['header'].update(sub)
420 msg['header'].update(sub)
421 return msg
421 return msg
422
422
423 def sign(self, msg_list):
423 def sign(self, msg_list):
424 """Sign a message with HMAC digest. If no auth, return b''.
424 """Sign a message with HMAC digest. If no auth, return b''.
425
425
426 Parameters
426 Parameters
427 ----------
427 ----------
428 msg_list : list
428 msg_list : list
429 The [p_header,p_parent,p_content] part of the message list.
429 The [p_header,p_parent,p_content] part of the message list.
430 """
430 """
431 if self.auth is None:
431 if self.auth is None:
432 return b''
432 return b''
433 h = self.auth.copy()
433 h = self.auth.copy()
434 for m in msg_list:
434 for m in msg_list:
435 h.update(m)
435 h.update(m)
436 return str_to_bytes(h.hexdigest())
436 return str_to_bytes(h.hexdigest())
437
437
438 def serialize(self, msg, ident=None):
438 def serialize(self, msg, ident=None):
439 """Serialize the message components to bytes.
439 """Serialize the message components to bytes.
440
440
441 This is roughly the inverse of unserialize. The serialize/unserialize
441 This is roughly the inverse of unserialize. The serialize/unserialize
442 methods work with full message lists, whereas pack/unpack work with
442 methods work with full message lists, whereas pack/unpack work with
443 the individual message parts in the message list.
443 the individual message parts in the message list.
444
444
445 Parameters
445 Parameters
446 ----------
446 ----------
447 msg : dict or Message
447 msg : dict or Message
448 The nexted message dict as returned by the self.msg method.
448 The nexted message dict as returned by the self.msg method.
449
449
450 Returns
450 Returns
451 -------
451 -------
452 msg_list : list
452 msg_list : list
453 The list of bytes objects to be sent with the format:
453 The list of bytes objects to be sent with the format:
454 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
454 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
455 buffer1,buffer2,...]. In this list, the p_* entities are
455 buffer1,buffer2,...]. In this list, the p_* entities are
456 the packed or serialized versions, so if JSON is used, these
456 the packed or serialized versions, so if JSON is used, these
457 are utf8 encoded JSON strings.
457 are utf8 encoded JSON strings.
458 """
458 """
459 content = msg.get('content', {})
459 content = msg.get('content', {})
460 if content is None:
460 if content is None:
461 content = self.none
461 content = self.none
462 elif isinstance(content, dict):
462 elif isinstance(content, dict):
463 content = self.pack(content)
463 content = self.pack(content)
464 elif isinstance(content, bytes):
464 elif isinstance(content, bytes):
465 # content is already packed, as in a relayed message
465 # content is already packed, as in a relayed message
466 pass
466 pass
467 elif isinstance(content, unicode):
467 elif isinstance(content, unicode):
468 # should be bytes, but JSON often spits out unicode
468 # should be bytes, but JSON often spits out unicode
469 content = content.encode('utf8')
469 content = content.encode('utf8')
470 else:
470 else:
471 raise TypeError("Content incorrect type: %s"%type(content))
471 raise TypeError("Content incorrect type: %s"%type(content))
472
472
473 real_message = [self.pack(msg['header']),
473 real_message = [self.pack(msg['header']),
474 self.pack(msg['parent_header']),
474 self.pack(msg['parent_header']),
475 content
475 content
476 ]
476 ]
477
477
478 to_send = []
478 to_send = []
479
479
480 if isinstance(ident, list):
480 if isinstance(ident, list):
481 # accept list of idents
481 # accept list of idents
482 to_send.extend(ident)
482 to_send.extend(ident)
483 elif ident is not None:
483 elif ident is not None:
484 to_send.append(ident)
484 to_send.append(ident)
485 to_send.append(DELIM)
485 to_send.append(DELIM)
486
486
487 signature = self.sign(real_message)
487 signature = self.sign(real_message)
488 to_send.append(signature)
488 to_send.append(signature)
489
489
490 to_send.extend(real_message)
490 to_send.extend(real_message)
491
491
492 return to_send
492 return to_send
493
493
494 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
494 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
495 buffers=None, subheader=None, track=False, header=None):
495 buffers=None, subheader=None, track=False, header=None):
496 """Build and send a message via stream or socket.
496 """Build and send a message via stream or socket.
497
497
498 The message format used by this function internally is as follows:
498 The message format used by this function internally is as follows:
499
499
500 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
500 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
501 buffer1,buffer2,...]
501 buffer1,buffer2,...]
502
502
503 The serialize/unserialize methods convert the nested message dict into this
503 The serialize/unserialize methods convert the nested message dict into this
504 format.
504 format.
505
505
506 Parameters
506 Parameters
507 ----------
507 ----------
508
508
509 stream : zmq.Socket or ZMQStream
509 stream : zmq.Socket or ZMQStream
510 The socket-like object used to send the data.
510 The socket-like object used to send the data.
511 msg_or_type : str or Message/dict
511 msg_or_type : str or Message/dict
512 Normally, msg_or_type will be a msg_type unless a message is being
512 Normally, msg_or_type will be a msg_type unless a message is being
513 sent more than once. If a header is supplied, this can be set to
513 sent more than once. If a header is supplied, this can be set to
514 None and the msg_type will be pulled from the header.
514 None and the msg_type will be pulled from the header.
515
515
516 content : dict or None
516 content : dict or None
517 The content of the message (ignored if msg_or_type is a message).
517 The content of the message (ignored if msg_or_type is a message).
518 header : dict or None
518 header : dict or None
519 The header dict for the message (ignores if msg_to_type is a message).
519 The header dict for the message (ignores if msg_to_type is a message).
520 parent : Message or dict or None
520 parent : Message or dict or None
521 The parent or parent header describing the parent of this message
521 The parent or parent header describing the parent of this message
522 (ignored if msg_or_type is a message).
522 (ignored if msg_or_type is a message).
523 ident : bytes or list of bytes
523 ident : bytes or list of bytes
524 The zmq.IDENTITY routing path.
524 The zmq.IDENTITY routing path.
525 subheader : dict or None
525 subheader : dict or None
526 Extra header keys for this message's header (ignored if msg_or_type
526 Extra header keys for this message's header (ignored if msg_or_type
527 is a message).
527 is a message).
528 buffers : list or None
528 buffers : list or None
529 The already-serialized buffers to be appended to the message.
529 The already-serialized buffers to be appended to the message.
530 track : bool
530 track : bool
531 Whether to track. Only for use with Sockets, because ZMQStream
531 Whether to track. Only for use with Sockets, because ZMQStream
532 objects cannot track messages.
532 objects cannot track messages.
533
533
534 Returns
534 Returns
535 -------
535 -------
536 msg : dict
536 msg : dict
537 The constructed message.
537 The constructed message.
538 (msg,tracker) : (dict, MessageTracker)
538 (msg,tracker) : (dict, MessageTracker)
539 if track=True, then a 2-tuple will be returned,
539 if track=True, then a 2-tuple will be returned,
540 the first element being the constructed
540 the first element being the constructed
541 message, and the second being the MessageTracker
541 message, and the second being the MessageTracker
542
542
543 """
543 """
544
544
545 if not isinstance(stream, (zmq.Socket, ZMQStream)):
545 if not isinstance(stream, (zmq.Socket, ZMQStream)):
546 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
546 raise TypeError("stream must be Socket or ZMQStream, not %r"%type(stream))
547 elif track and isinstance(stream, ZMQStream):
547 elif track and isinstance(stream, ZMQStream):
548 raise TypeError("ZMQStream cannot track messages")
548 raise TypeError("ZMQStream cannot track messages")
549
549
550 if isinstance(msg_or_type, (Message, dict)):
550 if isinstance(msg_or_type, (Message, dict)):
551 # We got a Message or message dict, not a msg_type so don't
551 # We got a Message or message dict, not a msg_type so don't
552 # build a new Message.
552 # build a new Message.
553 msg = msg_or_type
553 msg = msg_or_type
554 else:
554 else:
555 msg = self.msg(msg_or_type, content=content, parent=parent,
555 msg = self.msg(msg_or_type, content=content, parent=parent,
556 subheader=subheader, header=header)
556 subheader=subheader, header=header)
557
557
558 buffers = [] if buffers is None else buffers
558 buffers = [] if buffers is None else buffers
559 to_send = self.serialize(msg, ident)
559 to_send = self.serialize(msg, ident)
560 flag = 0
560 flag = 0
561 if buffers:
561 if buffers:
562 flag = zmq.SNDMORE
562 flag = zmq.SNDMORE
563 _track = False
563 _track = False
564 else:
564 else:
565 _track=track
565 _track=track
566 if track:
566 if track:
567 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
567 tracker = stream.send_multipart(to_send, flag, copy=False, track=_track)
568 else:
568 else:
569 tracker = stream.send_multipart(to_send, flag, copy=False)
569 tracker = stream.send_multipart(to_send, flag, copy=False)
570 for b in buffers[:-1]:
570 for b in buffers[:-1]:
571 stream.send(b, flag, copy=False)
571 stream.send(b, flag, copy=False)
572 if buffers:
572 if buffers:
573 if track:
573 if track:
574 tracker = stream.send(buffers[-1], copy=False, track=track)
574 tracker = stream.send(buffers[-1], copy=False, track=track)
575 else:
575 else:
576 tracker = stream.send(buffers[-1], copy=False)
576 tracker = stream.send(buffers[-1], copy=False)
577
577
578 # omsg = Message(msg)
578 # omsg = Message(msg)
579 if self.debug:
579 if self.debug:
580 pprint.pprint(msg)
580 pprint.pprint(msg)
581 pprint.pprint(to_send)
581 pprint.pprint(to_send)
582 pprint.pprint(buffers)
582 pprint.pprint(buffers)
583
583
584 msg['tracker'] = tracker
584 msg['tracker'] = tracker
585
585
586 return msg
586 return msg
587
587
588 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
588 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
589 """Send a raw message via ident path.
589 """Send a raw message via ident path.
590
590
591 This method is used to send a already serialized message.
591 This method is used to send a already serialized message.
592
592
593 Parameters
593 Parameters
594 ----------
594 ----------
595 stream : ZMQStream or Socket
595 stream : ZMQStream or Socket
596 The ZMQ stream or socket to use for sending the message.
596 The ZMQ stream or socket to use for sending the message.
597 msg_list : list
597 msg_list : list
598 The serialized list of messages to send. This only includes the
598 The serialized list of messages to send. This only includes the
599 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
599 [p_header,p_parent,p_content,buffer1,buffer2,...] portion of
600 the message.
600 the message.
601 ident : ident or list
601 ident : ident or list
602 A single ident or a list of idents to use in sending.
602 A single ident or a list of idents to use in sending.
603 """
603 """
604 to_send = []
604 to_send = []
605 if isinstance(ident, bytes):
605 if isinstance(ident, bytes):
606 ident = [ident]
606 ident = [ident]
607 if ident is not None:
607 if ident is not None:
608 to_send.extend(ident)
608 to_send.extend(ident)
609
609
610 to_send.append(DELIM)
610 to_send.append(DELIM)
611 to_send.append(self.sign(msg_list))
611 to_send.append(self.sign(msg_list))
612 to_send.extend(msg_list)
612 to_send.extend(msg_list)
613 stream.send_multipart(msg_list, flags, copy=copy)
613 stream.send_multipart(msg_list, flags, copy=copy)
614
614
615 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
615 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
616 """Receive and unpack a message.
616 """Receive and unpack a message.
617
617
618 Parameters
618 Parameters
619 ----------
619 ----------
620 socket : ZMQStream or Socket
620 socket : ZMQStream or Socket
621 The socket or stream to use in receiving.
621 The socket or stream to use in receiving.
622
622
623 Returns
623 Returns
624 -------
624 -------
625 [idents], msg
625 [idents], msg
626 [idents] is a list of idents and msg is a nested message dict of
626 [idents] is a list of idents and msg is a nested message dict of
627 same format as self.msg returns.
627 same format as self.msg returns.
628 """
628 """
629 if isinstance(socket, ZMQStream):
629 if isinstance(socket, ZMQStream):
630 socket = socket.socket
630 socket = socket.socket
631 try:
631 try:
632 msg_list = socket.recv_multipart(mode)
632 msg_list = socket.recv_multipart(mode, copy=copy)
633 except zmq.ZMQError as e:
633 except zmq.ZMQError as e:
634 if e.errno == zmq.EAGAIN:
634 if e.errno == zmq.EAGAIN:
635 # We can convert EAGAIN to None as we know in this case
635 # We can convert EAGAIN to None as we know in this case
636 # recv_multipart won't return None.
636 # recv_multipart won't return None.
637 return None,None
637 return None,None
638 else:
638 else:
639 raise
639 raise
640 # split multipart message into identity list and message dict
640 # split multipart message into identity list and message dict
641 # invalid large messages can cause very expensive string comparisons
641 # invalid large messages can cause very expensive string comparisons
642 idents, msg_list = self.feed_identities(msg_list, copy)
642 idents, msg_list = self.feed_identities(msg_list, copy)
643 try:
643 try:
644 return idents, self.unserialize(msg_list, content=content, copy=copy)
644 return idents, self.unserialize(msg_list, content=content, copy=copy)
645 except Exception as e:
645 except Exception as e:
646 # TODO: handle it
646 # TODO: handle it
647 raise e
647 raise e
648
648
649 def feed_identities(self, msg_list, copy=True):
649 def feed_identities(self, msg_list, copy=True):
650 """Split the identities from the rest of the message.
650 """Split the identities from the rest of the message.
651
651
652 Feed until DELIM is reached, then return the prefix as idents and
652 Feed until DELIM is reached, then return the prefix as idents and
653 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
653 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
654 but that would be silly.
654 but that would be silly.
655
655
656 Parameters
656 Parameters
657 ----------
657 ----------
658 msg_list : a list of Message or bytes objects
658 msg_list : a list of Message or bytes objects
659 The message to be split.
659 The message to be split.
660 copy : bool
660 copy : bool
661 flag determining whether the arguments are bytes or Messages
661 flag determining whether the arguments are bytes or Messages
662
662
663 Returns
663 Returns
664 -------
664 -------
665 (idents, msg_list) : two lists
665 (idents, msg_list) : two lists
666 idents will always be a list of bytes, each of which is a ZMQ
666 idents will always be a list of bytes, each of which is a ZMQ
667 identity. msg_list will be a list of bytes or zmq.Messages of the
667 identity. msg_list will be a list of bytes or zmq.Messages of the
668 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
668 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
669 should be unpackable/unserializable via self.unserialize at this
669 should be unpackable/unserializable via self.unserialize at this
670 point.
670 point.
671 """
671 """
672 if copy:
672 if copy:
673 idx = msg_list.index(DELIM)
673 idx = msg_list.index(DELIM)
674 return msg_list[:idx], msg_list[idx+1:]
674 return msg_list[:idx], msg_list[idx+1:]
675 else:
675 else:
676 failed = True
676 failed = True
677 for idx,m in enumerate(msg_list):
677 for idx,m in enumerate(msg_list):
678 if m.bytes == DELIM:
678 if m.bytes == DELIM:
679 failed = False
679 failed = False
680 break
680 break
681 if failed:
681 if failed:
682 raise ValueError("DELIM not in msg_list")
682 raise ValueError("DELIM not in msg_list")
683 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
683 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
684 return [m.bytes for m in idents], msg_list
684 return [m.bytes for m in idents], msg_list
685
685
686 def unserialize(self, msg_list, content=True, copy=True):
686 def unserialize(self, msg_list, content=True, copy=True):
687 """Unserialize a msg_list to a nested message dict.
687 """Unserialize a msg_list to a nested message dict.
688
688
689 This is roughly the inverse of serialize. The serialize/unserialize
689 This is roughly the inverse of serialize. The serialize/unserialize
690 methods work with full message lists, whereas pack/unpack work with
690 methods work with full message lists, whereas pack/unpack work with
691 the individual message parts in the message list.
691 the individual message parts in the message list.
692
692
693 Parameters:
693 Parameters:
694 -----------
694 -----------
695 msg_list : list of bytes or Message objects
695 msg_list : list of bytes or Message objects
696 The list of message parts of the form [HMAC,p_header,p_parent,
696 The list of message parts of the form [HMAC,p_header,p_parent,
697 p_content,buffer1,buffer2,...].
697 p_content,buffer1,buffer2,...].
698 content : bool (True)
698 content : bool (True)
699 Whether to unpack the content dict (True), or leave it packed
699 Whether to unpack the content dict (True), or leave it packed
700 (False).
700 (False).
701 copy : bool (True)
701 copy : bool (True)
702 Whether to return the bytes (True), or the non-copying Message
702 Whether to return the bytes (True), or the non-copying Message
703 object in each place (False).
703 object in each place (False).
704
704
705 Returns
705 Returns
706 -------
706 -------
707 msg : dict
707 msg : dict
708 The nested message dict with top-level keys [header, parent_header,
708 The nested message dict with top-level keys [header, parent_header,
709 content, buffers].
709 content, buffers].
710 """
710 """
711 minlen = 4
711 minlen = 4
712 message = {}
712 message = {}
713 if not copy:
713 if not copy:
714 for i in range(minlen):
714 for i in range(minlen):
715 msg_list[i] = msg_list[i].bytes
715 msg_list[i] = msg_list[i].bytes
716 if self.auth is not None:
716 if self.auth is not None:
717 signature = msg_list[0]
717 signature = msg_list[0]
718 if not signature:
718 if not signature:
719 raise ValueError("Unsigned Message")
719 raise ValueError("Unsigned Message")
720 if signature in self.digest_history:
720 if signature in self.digest_history:
721 raise ValueError("Duplicate Signature: %r"%signature)
721 raise ValueError("Duplicate Signature: %r"%signature)
722 self.digest_history.add(signature)
722 self.digest_history.add(signature)
723 check = self.sign(msg_list[1:4])
723 check = self.sign(msg_list[1:4])
724 if not signature == check:
724 if not signature == check:
725 raise ValueError("Invalid Signature: %r"%signature)
725 raise ValueError("Invalid Signature: %r"%signature)
726 if not len(msg_list) >= minlen:
726 if not len(msg_list) >= minlen:
727 raise TypeError("malformed message, must have at least %i elements"%minlen)
727 raise TypeError("malformed message, must have at least %i elements"%minlen)
728 header = self.unpack(msg_list[1])
728 header = self.unpack(msg_list[1])
729 message['header'] = header
729 message['header'] = header
730 message['msg_id'] = header['msg_id']
730 message['msg_id'] = header['msg_id']
731 message['msg_type'] = header['msg_type']
731 message['msg_type'] = header['msg_type']
732 message['parent_header'] = self.unpack(msg_list[2])
732 message['parent_header'] = self.unpack(msg_list[2])
733 if content:
733 if content:
734 message['content'] = self.unpack(msg_list[3])
734 message['content'] = self.unpack(msg_list[3])
735 else:
735 else:
736 message['content'] = msg_list[3]
736 message['content'] = msg_list[3]
737
737
738 message['buffers'] = msg_list[4:]
738 message['buffers'] = msg_list[4:]
739 return message
739 return message
740
740
741 def test_msg2obj():
741 def test_msg2obj():
742 am = dict(x=1)
742 am = dict(x=1)
743 ao = Message(am)
743 ao = Message(am)
744 assert ao.x == am['x']
744 assert ao.x == am['x']
745
745
746 am['y'] = dict(z=1)
746 am['y'] = dict(z=1)
747 ao = Message(am)
747 ao = Message(am)
748 assert ao.y.z == am['y']['z']
748 assert ao.y.z == am['y']['z']
749
749
750 k1, k2 = 'y', 'z'
750 k1, k2 = 'y', 'z'
751 assert ao[k1][k2] == am[k1][k2]
751 assert ao[k1][k2] == am[k1][k2]
752
752
753 am2 = dict(ao)
753 am2 = dict(ao)
754 assert am['x'] == am2['x']
754 assert am['x'] == am2['x']
755 assert am['y']['z'] == am2['y']['z']
755 assert am['y']['z'] == am2['y']['z']
756
756
@@ -1,159 +1,191 b''
1 """test IPython.embed_kernel()"""
1 """test IPython.embed_kernel()"""
2
2
3 #-------------------------------------------------------------------------------
3 #-------------------------------------------------------------------------------
4 # Copyright (C) 2012 The IPython Development Team
4 # Copyright (C) 2012 The IPython Development Team
5 #
5 #
6 # Distributed under the terms of the BSD License. The full license is in
6 # Distributed under the terms of the BSD License. The full license is in
7 # the file COPYING, distributed as part of this software.
7 # the file COPYING, distributed as part of this software.
8 #-------------------------------------------------------------------------------
8 #-------------------------------------------------------------------------------
9
9
10 #-------------------------------------------------------------------------------
10 #-------------------------------------------------------------------------------
11 # Imports
11 # Imports
12 #-------------------------------------------------------------------------------
12 #-------------------------------------------------------------------------------
13
13
14 import os
14 import os
15 import shutil
15 import shutil
16 import sys
16 import sys
17 import tempfile
17 import tempfile
18 import time
18 import time
19
19
20 from contextlib import contextmanager
20 from contextlib import contextmanager
21 from subprocess import Popen, PIPE
21 from subprocess import Popen, PIPE
22
22
23 import nose.tools as nt
23 import nose.tools as nt
24
24
25 from IPython.zmq.blockingkernelmanager import BlockingKernelManager
25 from IPython.zmq.blockingkernelmanager import BlockingKernelManager
26 from IPython.utils import path
26 from IPython.utils import path, py3compat
27
27
28
28
29 #-------------------------------------------------------------------------------
29 #-------------------------------------------------------------------------------
30 # Tests
30 # Tests
31 #-------------------------------------------------------------------------------
31 #-------------------------------------------------------------------------------
32
32
33 def setup():
33 def setup():
34 """setup temporary IPYTHONDIR for tests"""
34 """setup temporary IPYTHONDIR for tests"""
35 global IPYTHONDIR
35 global IPYTHONDIR
36 global env
36 global env
37 global save_get_ipython_dir
37 global save_get_ipython_dir
38
38
39 IPYTHONDIR = tempfile.mkdtemp()
39 IPYTHONDIR = tempfile.mkdtemp()
40 env = dict(IPYTHONDIR=IPYTHONDIR)
40 env = dict(IPYTHONDIR=IPYTHONDIR)
41 save_get_ipython_dir = path.get_ipython_dir
41 save_get_ipython_dir = path.get_ipython_dir
42 path.get_ipython_dir = lambda : IPYTHONDIR
42 path.get_ipython_dir = lambda : IPYTHONDIR
43
43
44
44
45 def teardown():
45 def teardown():
46 path.get_ipython_dir = save_get_ipython_dir
46 path.get_ipython_dir = save_get_ipython_dir
47
47
48 try:
48 try:
49 shutil.rmtree(IPYTHONDIR)
49 shutil.rmtree(IPYTHONDIR)
50 except (OSError, IOError):
50 except (OSError, IOError):
51 # no such file
51 # no such file
52 pass
52 pass
53
53
54
54
55 @contextmanager
55 @contextmanager
56 def setup_kernel(cmd):
56 def setup_kernel(cmd):
57 """start an embedded kernel in a subprocess, and wait for it to be ready
57 """start an embedded kernel in a subprocess, and wait for it to be ready
58
58
59 Returns
59 Returns
60 -------
60 -------
61 kernel_manager: connected KernelManager instance
61 kernel_manager: connected KernelManager instance
62 """
62 """
63 kernel = Popen([sys.executable, '-c', cmd], stdout=PIPE, stderr=PIPE, env=env)
63 kernel = Popen([sys.executable, '-c', cmd], stdout=PIPE, stderr=PIPE, env=env)
64 connection_file = os.path.join(IPYTHONDIR,
64 connection_file = os.path.join(IPYTHONDIR,
65 'profile_default',
65 'profile_default',
66 'security',
66 'security',
67 'kernel-%i.json' % kernel.pid
67 'kernel-%i.json' % kernel.pid
68 )
68 )
69 # wait for connection file to exist, timeout after 5s
69 # wait for connection file to exist, timeout after 5s
70 tic = time.time()
70 tic = time.time()
71 while not os.path.exists(connection_file) and kernel.poll() is None and time.time() < tic + 5:
71 while not os.path.exists(connection_file) and kernel.poll() is None and time.time() < tic + 10:
72 time.sleep(0.1)
72 time.sleep(0.1)
73
73
74 if kernel.poll() is not None:
75 o,e = kernel.communicate()
76 e = py3compat.cast_unicode(e)
77 raise IOError("Kernel failed to start:\n%s" % e)
78
74 if not os.path.exists(connection_file):
79 if not os.path.exists(connection_file):
75 if kernel.poll() is None:
80 if kernel.poll() is None:
76 kernel.terminate()
81 kernel.terminate()
77 raise IOError("Connection file %r never arrived" % connection_file)
82 raise IOError("Connection file %r never arrived" % connection_file)
78
83
79 if kernel.poll() is not None:
80 raise IOError("Kernel failed to start")
81
82 km = BlockingKernelManager(connection_file=connection_file)
84 km = BlockingKernelManager(connection_file=connection_file)
83 km.load_connection_file()
85 km.load_connection_file()
84 km.start_channels()
86 km.start_channels()
85
87
86 try:
88 try:
87 yield km
89 yield km
88 finally:
90 finally:
89 km.stop_channels()
91 km.stop_channels()
90 kernel.terminate()
92 kernel.terminate()
91
93
92 def test_embed_kernel_basic():
94 def test_embed_kernel_basic():
93 """IPython.embed_kernel() is basically functional"""
95 """IPython.embed_kernel() is basically functional"""
94 cmd = '\n'.join([
96 cmd = '\n'.join([
95 'from IPython import embed_kernel',
97 'from IPython import embed_kernel',
96 'def go():',
98 'def go():',
97 ' a=5',
99 ' a=5',
98 ' b="hi there"',
100 ' b="hi there"',
99 ' embed_kernel()',
101 ' embed_kernel()',
100 'go()',
102 'go()',
101 '',
103 '',
102 ])
104 ])
103
105
104 with setup_kernel(cmd) as km:
106 with setup_kernel(cmd) as km:
105 shell = km.shell_channel
107 shell = km.shell_channel
106
108
107 # oinfo a (int)
109 # oinfo a (int)
108 msg_id = shell.object_info('a')
110 msg_id = shell.object_info('a')
109 msg = shell.get_msg(block=True, timeout=2)
111 msg = shell.get_msg(block=True, timeout=2)
110 content = msg['content']
112 content = msg['content']
111 nt.assert_true(content['found'])
113 nt.assert_true(content['found'])
112
114
113 msg_id = shell.execute("c=a*2")
115 msg_id = shell.execute("c=a*2")
114 msg = shell.get_msg(block=True, timeout=2)
116 msg = shell.get_msg(block=True, timeout=2)
115 content = msg['content']
117 content = msg['content']
116 nt.assert_equals(content['status'], u'ok')
118 nt.assert_equals(content['status'], u'ok')
117
119
118 # oinfo c (should be 10)
120 # oinfo c (should be 10)
119 msg_id = shell.object_info('c')
121 msg_id = shell.object_info('c')
120 msg = shell.get_msg(block=True, timeout=2)
122 msg = shell.get_msg(block=True, timeout=2)
121 content = msg['content']
123 content = msg['content']
122 nt.assert_true(content['found'])
124 nt.assert_true(content['found'])
123 nt.assert_equals(content['string_form'], u'10')
125 nt.assert_equals(content['string_form'], u'10')
124
126
125 def test_embed_kernel_namespace():
127 def test_embed_kernel_namespace():
126 """IPython.embed_kernel() inherits calling namespace"""
128 """IPython.embed_kernel() inherits calling namespace"""
127 cmd = '\n'.join([
129 cmd = '\n'.join([
128 'from IPython import embed_kernel',
130 'from IPython import embed_kernel',
129 'def go():',
131 'def go():',
130 ' a=5',
132 ' a=5',
131 ' b="hi there"',
133 ' b="hi there"',
132 ' embed_kernel()',
134 ' embed_kernel()',
133 'go()',
135 'go()',
134 '',
136 '',
135 ])
137 ])
136
138
137 with setup_kernel(cmd) as km:
139 with setup_kernel(cmd) as km:
138 shell = km.shell_channel
140 shell = km.shell_channel
139
141
140 # oinfo a (int)
142 # oinfo a (int)
141 msg_id = shell.object_info('a')
143 msg_id = shell.object_info('a')
142 msg = shell.get_msg(block=True, timeout=2)
144 msg = shell.get_msg(block=True, timeout=2)
143 content = msg['content']
145 content = msg['content']
144 nt.assert_true(content['found'])
146 nt.assert_true(content['found'])
145 nt.assert_equals(content['string_form'], u'5')
147 nt.assert_equals(content['string_form'], u'5')
146
148
147 # oinfo b (str)
149 # oinfo b (str)
148 msg_id = shell.object_info('b')
150 msg_id = shell.object_info('b')
149 msg = shell.get_msg(block=True, timeout=2)
151 msg = shell.get_msg(block=True, timeout=2)
150 content = msg['content']
152 content = msg['content']
151 nt.assert_true(content['found'])
153 nt.assert_true(content['found'])
152 nt.assert_equals(content['string_form'], u'hi there')
154 nt.assert_equals(content['string_form'], u'hi there')
153
155
154 # oinfo c (undefined)
156 # oinfo c (undefined)
155 msg_id = shell.object_info('c')
157 msg_id = shell.object_info('c')
156 msg = shell.get_msg(block=True, timeout=2)
158 msg = shell.get_msg(block=True, timeout=2)
157 content = msg['content']
159 content = msg['content']
158 nt.assert_false(content['found'])
160 nt.assert_false(content['found'])
159
161
162 def test_embed_kernel_reentrant():
163 """IPython.embed_kernel() can be called multiple times"""
164 cmd = '\n'.join([
165 'from IPython import embed_kernel',
166 'count = 0',
167 'def go():',
168 ' global count',
169 ' embed_kernel()',
170 ' count = count + 1',
171 '',
172 'while True:'
173 ' go()',
174 '',
175 ])
176
177 with setup_kernel(cmd) as km:
178 shell = km.shell_channel
179 for i in range(5):
180 msg_id = shell.object_info('count')
181 msg = shell.get_msg(block=True, timeout=2)
182 content = msg['content']
183 nt.assert_true(content['found'])
184 nt.assert_equals(content['string_form'], unicode(i))
185
186 # exit from embed_kernel
187 shell.execute("get_ipython().exit_now = True")
188 msg = shell.get_msg(block=True, timeout=2)
189 time.sleep(0.2)
190
191
@@ -1,525 +1,541 b''
1 """A ZMQ-based subclass of InteractiveShell.
1 """A ZMQ-based subclass of InteractiveShell.
2
2
3 This code is meant to ease the refactoring of the base InteractiveShell into
3 This code is meant to ease the refactoring of the base InteractiveShell into
4 something with a cleaner architecture for 2-process use, without actually
4 something with a cleaner architecture for 2-process use, without actually
5 breaking InteractiveShell itself. So we're doing something a bit ugly, where
5 breaking InteractiveShell itself. So we're doing something a bit ugly, where
6 we subclass and override what we want to fix. Once this is working well, we
6 we subclass and override what we want to fix. Once this is working well, we
7 can go back to the base class and refactor the code for a cleaner inheritance
7 can go back to the base class and refactor the code for a cleaner inheritance
8 implementation that doesn't rely on so much monkeypatching.
8 implementation that doesn't rely on so much monkeypatching.
9
9
10 But this lets us maintain a fully working IPython as we develop the new
10 But this lets us maintain a fully working IPython as we develop the new
11 machinery. This should thus be thought of as scaffolding.
11 machinery. This should thus be thought of as scaffolding.
12 """
12 """
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14 # Imports
14 # Imports
15 #-----------------------------------------------------------------------------
15 #-----------------------------------------------------------------------------
16 from __future__ import print_function
16 from __future__ import print_function
17
17
18 # Stdlib
18 # Stdlib
19 import inspect
19 import inspect
20 import os
20 import os
21 import sys
21 import sys
22 import time
22 from subprocess import Popen, PIPE
23 from subprocess import Popen, PIPE
23
24
25 # System library imports
26 from zmq.eventloop import ioloop
27
24 # Our own
28 # Our own
25 from IPython.core.interactiveshell import (
29 from IPython.core.interactiveshell import (
26 InteractiveShell, InteractiveShellABC
30 InteractiveShell, InteractiveShellABC
27 )
31 )
28 from IPython.core import page, pylabtools
32 from IPython.core import page, pylabtools
29 from IPython.core.autocall import ZMQExitAutocall
33 from IPython.core.autocall import ZMQExitAutocall
30 from IPython.core.displaypub import DisplayPublisher
34 from IPython.core.displaypub import DisplayPublisher
31 from IPython.core.macro import Macro
35 from IPython.core.macro import Macro
32 from IPython.core.magic import MacroToEdit
36 from IPython.core.magic import MacroToEdit
33 from IPython.core.payloadpage import install_payload_page
37 from IPython.core.payloadpage import install_payload_page
34 from IPython.lib.kernel import (
38 from IPython.lib.kernel import (
35 get_connection_file, get_connection_info, connect_qtconsole
39 get_connection_file, get_connection_info, connect_qtconsole
36 )
40 )
37 from IPython.testing.skipdoctest import skip_doctest
41 from IPython.testing.skipdoctest import skip_doctest
38 from IPython.utils import io
42 from IPython.utils import io
39 from IPython.utils.jsonutil import json_clean
43 from IPython.utils.jsonutil import json_clean
40 from IPython.utils.path import get_py_filename
44 from IPython.utils.path import get_py_filename
41 from IPython.utils.process import arg_split
45 from IPython.utils.process import arg_split
42 from IPython.utils.traitlets import Instance, Type, Dict, CBool
46 from IPython.utils.traitlets import Instance, Type, Dict, CBool, CBytes
43 from IPython.utils.warn import warn, error
47 from IPython.utils.warn import warn, error
44 from IPython.zmq.displayhook import ZMQShellDisplayHook, _encode_binary
48 from IPython.zmq.displayhook import ZMQShellDisplayHook, _encode_binary
45 from IPython.zmq.session import extract_header
49 from IPython.zmq.session import extract_header
46 from session import Session
50 from session import Session
47
51
48
52
49 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
50 # Functions and classes
54 # Functions and classes
51 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
52
56
53 class ZMQDisplayPublisher(DisplayPublisher):
57 class ZMQDisplayPublisher(DisplayPublisher):
54 """A display publisher that publishes data using a ZeroMQ PUB socket."""
58 """A display publisher that publishes data using a ZeroMQ PUB socket."""
55
59
56 session = Instance(Session)
60 session = Instance(Session)
57 pub_socket = Instance('zmq.Socket')
61 pub_socket = Instance('zmq.Socket')
58 parent_header = Dict({})
62 parent_header = Dict({})
63 topic = CBytes(b'displaypub')
59
64
60 def set_parent(self, parent):
65 def set_parent(self, parent):
61 """Set the parent for outbound messages."""
66 """Set the parent for outbound messages."""
62 self.parent_header = extract_header(parent)
67 self.parent_header = extract_header(parent)
63
68
64 def _flush_streams(self):
69 def _flush_streams(self):
65 """flush IO Streams prior to display"""
70 """flush IO Streams prior to display"""
66 sys.stdout.flush()
71 sys.stdout.flush()
67 sys.stderr.flush()
72 sys.stderr.flush()
68
73
69 def publish(self, source, data, metadata=None):
74 def publish(self, source, data, metadata=None):
70 self._flush_streams()
75 self._flush_streams()
71 if metadata is None:
76 if metadata is None:
72 metadata = {}
77 metadata = {}
73 self._validate_data(source, data, metadata)
78 self._validate_data(source, data, metadata)
74 content = {}
79 content = {}
75 content['source'] = source
80 content['source'] = source
76 _encode_binary(data)
81 _encode_binary(data)
77 content['data'] = data
82 content['data'] = data
78 content['metadata'] = metadata
83 content['metadata'] = metadata
79 self.session.send(
84 self.session.send(
80 self.pub_socket, u'display_data', json_clean(content),
85 self.pub_socket, u'display_data', json_clean(content),
81 parent=self.parent_header
86 parent=self.parent_header, ident=self.topic,
82 )
87 )
83
88
84 def clear_output(self, stdout=True, stderr=True, other=True):
89 def clear_output(self, stdout=True, stderr=True, other=True):
85 content = dict(stdout=stdout, stderr=stderr, other=other)
90 content = dict(stdout=stdout, stderr=stderr, other=other)
86
91
87 if stdout:
92 if stdout:
88 print('\r', file=sys.stdout, end='')
93 print('\r', file=sys.stdout, end='')
89 if stderr:
94 if stderr:
90 print('\r', file=sys.stderr, end='')
95 print('\r', file=sys.stderr, end='')
91
96
92 self._flush_streams()
97 self._flush_streams()
93
98
94 self.session.send(
99 self.session.send(
95 self.pub_socket, u'clear_output', content,
100 self.pub_socket, u'clear_output', content,
96 parent=self.parent_header
101 parent=self.parent_header, ident=self.topic,
97 )
102 )
98
103
99 class ZMQInteractiveShell(InteractiveShell):
104 class ZMQInteractiveShell(InteractiveShell):
100 """A subclass of InteractiveShell for ZMQ."""
105 """A subclass of InteractiveShell for ZMQ."""
101
106
102 displayhook_class = Type(ZMQShellDisplayHook)
107 displayhook_class = Type(ZMQShellDisplayHook)
103 display_pub_class = Type(ZMQDisplayPublisher)
108 display_pub_class = Type(ZMQDisplayPublisher)
104
109
105 # Override the traitlet in the parent class, because there's no point using
110 # Override the traitlet in the parent class, because there's no point using
106 # readline for the kernel. Can be removed when the readline code is moved
111 # readline for the kernel. Can be removed when the readline code is moved
107 # to the terminal frontend.
112 # to the terminal frontend.
108 colors_force = CBool(True)
113 colors_force = CBool(True)
109 readline_use = CBool(False)
114 readline_use = CBool(False)
110 # autoindent has no meaning in a zmqshell, and attempting to enable it
115 # autoindent has no meaning in a zmqshell, and attempting to enable it
111 # will print a warning in the absence of readline.
116 # will print a warning in the absence of readline.
112 autoindent = CBool(False)
117 autoindent = CBool(False)
113
118
114 exiter = Instance(ZMQExitAutocall)
119 exiter = Instance(ZMQExitAutocall)
115 def _exiter_default(self):
120 def _exiter_default(self):
116 return ZMQExitAutocall(self)
121 return ZMQExitAutocall(self)
122
123 def _exit_now_changed(self, name, old, new):
124 """stop eventloop when exit_now fires"""
125 if new:
126 loop = ioloop.IOLoop.instance()
127 loop.add_timeout(time.time()+0.1, loop.stop)
117
128
118 keepkernel_on_exit = None
129 keepkernel_on_exit = None
119
130
120 # Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
131 # Over ZeroMQ, GUI control isn't done with PyOS_InputHook as there is no
121 # interactive input being read; we provide event loop support in ipkernel
132 # interactive input being read; we provide event loop support in ipkernel
122 from .eventloops import enable_gui
133 from .eventloops import enable_gui
123 enable_gui = staticmethod(enable_gui)
134 enable_gui = staticmethod(enable_gui)
124
135
125 def init_environment(self):
136 def init_environment(self):
126 """Configure the user's environment.
137 """Configure the user's environment.
127
138
128 """
139 """
129 env = os.environ
140 env = os.environ
130 # These two ensure 'ls' produces nice coloring on BSD-derived systems
141 # These two ensure 'ls' produces nice coloring on BSD-derived systems
131 env['TERM'] = 'xterm-color'
142 env['TERM'] = 'xterm-color'
132 env['CLICOLOR'] = '1'
143 env['CLICOLOR'] = '1'
133 # Since normal pagers don't work at all (over pexpect we don't have
144 # Since normal pagers don't work at all (over pexpect we don't have
134 # single-key control of the subprocess), try to disable paging in
145 # single-key control of the subprocess), try to disable paging in
135 # subprocesses as much as possible.
146 # subprocesses as much as possible.
136 env['PAGER'] = 'cat'
147 env['PAGER'] = 'cat'
137 env['GIT_PAGER'] = 'cat'
148 env['GIT_PAGER'] = 'cat'
138
149
139 # And install the payload version of page.
150 # And install the payload version of page.
140 install_payload_page()
151 install_payload_page()
141
152
142 def auto_rewrite_input(self, cmd):
153 def auto_rewrite_input(self, cmd):
143 """Called to show the auto-rewritten input for autocall and friends.
154 """Called to show the auto-rewritten input for autocall and friends.
144
155
145 FIXME: this payload is currently not correctly processed by the
156 FIXME: this payload is currently not correctly processed by the
146 frontend.
157 frontend.
147 """
158 """
148 new = self.prompt_manager.render('rewrite') + cmd
159 new = self.prompt_manager.render('rewrite') + cmd
149 payload = dict(
160 payload = dict(
150 source='IPython.zmq.zmqshell.ZMQInteractiveShell.auto_rewrite_input',
161 source='IPython.zmq.zmqshell.ZMQInteractiveShell.auto_rewrite_input',
151 transformed_input=new,
162 transformed_input=new,
152 )
163 )
153 self.payload_manager.write_payload(payload)
164 self.payload_manager.write_payload(payload)
154
165
155 def ask_exit(self):
166 def ask_exit(self):
156 """Engage the exit actions."""
167 """Engage the exit actions."""
168 self.exit_now = True
157 payload = dict(
169 payload = dict(
158 source='IPython.zmq.zmqshell.ZMQInteractiveShell.ask_exit',
170 source='IPython.zmq.zmqshell.ZMQInteractiveShell.ask_exit',
159 exit=True,
171 exit=True,
160 keepkernel=self.keepkernel_on_exit,
172 keepkernel=self.keepkernel_on_exit,
161 )
173 )
162 self.payload_manager.write_payload(payload)
174 self.payload_manager.write_payload(payload)
163
175
164 def _showtraceback(self, etype, evalue, stb):
176 def _showtraceback(self, etype, evalue, stb):
165
177
166 exc_content = {
178 exc_content = {
167 u'traceback' : stb,
179 u'traceback' : stb,
168 u'ename' : unicode(etype.__name__),
180 u'ename' : unicode(etype.__name__),
169 u'evalue' : unicode(evalue)
181 u'evalue' : unicode(evalue)
170 }
182 }
171
183
172 dh = self.displayhook
184 dh = self.displayhook
173 # Send exception info over pub socket for other clients than the caller
185 # Send exception info over pub socket for other clients than the caller
174 # to pick up
186 # to pick up
175 exc_msg = dh.session.send(dh.pub_socket, u'pyerr', json_clean(exc_content), dh.parent_header)
187 topic = None
188 if dh.topic:
189 topic = dh.topic.replace(b'pyout', b'pyerr')
190
191 exc_msg = dh.session.send(dh.pub_socket, u'pyerr', json_clean(exc_content), dh.parent_header, ident=topic)
176
192
177 # FIXME - Hack: store exception info in shell object. Right now, the
193 # FIXME - Hack: store exception info in shell object. Right now, the
178 # caller is reading this info after the fact, we need to fix this logic
194 # caller is reading this info after the fact, we need to fix this logic
179 # to remove this hack. Even uglier, we need to store the error status
195 # to remove this hack. Even uglier, we need to store the error status
180 # here, because in the main loop, the logic that sets it is being
196 # here, because in the main loop, the logic that sets it is being
181 # skipped because runlines swallows the exceptions.
197 # skipped because runlines swallows the exceptions.
182 exc_content[u'status'] = u'error'
198 exc_content[u'status'] = u'error'
183 self._reply_content = exc_content
199 self._reply_content = exc_content
184 # /FIXME
200 # /FIXME
185
201
186 return exc_content
202 return exc_content
187
203
188 #------------------------------------------------------------------------
204 #------------------------------------------------------------------------
189 # Magic overrides
205 # Magic overrides
190 #------------------------------------------------------------------------
206 #------------------------------------------------------------------------
191 # Once the base class stops inheriting from magic, this code needs to be
207 # Once the base class stops inheriting from magic, this code needs to be
192 # moved into a separate machinery as well. For now, at least isolate here
208 # moved into a separate machinery as well. For now, at least isolate here
193 # the magics which this class needs to implement differently from the base
209 # the magics which this class needs to implement differently from the base
194 # class, or that are unique to it.
210 # class, or that are unique to it.
195
211
196 def magic_doctest_mode(self,parameter_s=''):
212 def magic_doctest_mode(self,parameter_s=''):
197 """Toggle doctest mode on and off.
213 """Toggle doctest mode on and off.
198
214
199 This mode is intended to make IPython behave as much as possible like a
215 This mode is intended to make IPython behave as much as possible like a
200 plain Python shell, from the perspective of how its prompts, exceptions
216 plain Python shell, from the perspective of how its prompts, exceptions
201 and output look. This makes it easy to copy and paste parts of a
217 and output look. This makes it easy to copy and paste parts of a
202 session into doctests. It does so by:
218 session into doctests. It does so by:
203
219
204 - Changing the prompts to the classic ``>>>`` ones.
220 - Changing the prompts to the classic ``>>>`` ones.
205 - Changing the exception reporting mode to 'Plain'.
221 - Changing the exception reporting mode to 'Plain'.
206 - Disabling pretty-printing of output.
222 - Disabling pretty-printing of output.
207
223
208 Note that IPython also supports the pasting of code snippets that have
224 Note that IPython also supports the pasting of code snippets that have
209 leading '>>>' and '...' prompts in them. This means that you can paste
225 leading '>>>' and '...' prompts in them. This means that you can paste
210 doctests from files or docstrings (even if they have leading
226 doctests from files or docstrings (even if they have leading
211 whitespace), and the code will execute correctly. You can then use
227 whitespace), and the code will execute correctly. You can then use
212 '%history -t' to see the translated history; this will give you the
228 '%history -t' to see the translated history; this will give you the
213 input after removal of all the leading prompts and whitespace, which
229 input after removal of all the leading prompts and whitespace, which
214 can be pasted back into an editor.
230 can be pasted back into an editor.
215
231
216 With these features, you can switch into this mode easily whenever you
232 With these features, you can switch into this mode easily whenever you
217 need to do testing and changes to doctests, without having to leave
233 need to do testing and changes to doctests, without having to leave
218 your existing IPython session.
234 your existing IPython session.
219 """
235 """
220
236
221 from IPython.utils.ipstruct import Struct
237 from IPython.utils.ipstruct import Struct
222
238
223 # Shorthands
239 # Shorthands
224 shell = self.shell
240 shell = self.shell
225 disp_formatter = self.shell.display_formatter
241 disp_formatter = self.shell.display_formatter
226 ptformatter = disp_formatter.formatters['text/plain']
242 ptformatter = disp_formatter.formatters['text/plain']
227 # dstore is a data store kept in the instance metadata bag to track any
243 # dstore is a data store kept in the instance metadata bag to track any
228 # changes we make, so we can undo them later.
244 # changes we make, so we can undo them later.
229 dstore = shell.meta.setdefault('doctest_mode', Struct())
245 dstore = shell.meta.setdefault('doctest_mode', Struct())
230 save_dstore = dstore.setdefault
246 save_dstore = dstore.setdefault
231
247
232 # save a few values we'll need to recover later
248 # save a few values we'll need to recover later
233 mode = save_dstore('mode', False)
249 mode = save_dstore('mode', False)
234 save_dstore('rc_pprint', ptformatter.pprint)
250 save_dstore('rc_pprint', ptformatter.pprint)
235 save_dstore('rc_plain_text_only',disp_formatter.plain_text_only)
251 save_dstore('rc_plain_text_only',disp_formatter.plain_text_only)
236 save_dstore('xmode', shell.InteractiveTB.mode)
252 save_dstore('xmode', shell.InteractiveTB.mode)
237
253
238 if mode == False:
254 if mode == False:
239 # turn on
255 # turn on
240 ptformatter.pprint = False
256 ptformatter.pprint = False
241 disp_formatter.plain_text_only = True
257 disp_formatter.plain_text_only = True
242 shell.magic_xmode('Plain')
258 shell.magic_xmode('Plain')
243 else:
259 else:
244 # turn off
260 # turn off
245 ptformatter.pprint = dstore.rc_pprint
261 ptformatter.pprint = dstore.rc_pprint
246 disp_formatter.plain_text_only = dstore.rc_plain_text_only
262 disp_formatter.plain_text_only = dstore.rc_plain_text_only
247 shell.magic_xmode(dstore.xmode)
263 shell.magic_xmode(dstore.xmode)
248
264
249 # Store new mode and inform on console
265 # Store new mode and inform on console
250 dstore.mode = bool(1-int(mode))
266 dstore.mode = bool(1-int(mode))
251 mode_label = ['OFF','ON'][dstore.mode]
267 mode_label = ['OFF','ON'][dstore.mode]
252 print('Doctest mode is:', mode_label)
268 print('Doctest mode is:', mode_label)
253
269
254 # Send the payload back so that clients can modify their prompt display
270 # Send the payload back so that clients can modify their prompt display
255 payload = dict(
271 payload = dict(
256 source='IPython.zmq.zmqshell.ZMQInteractiveShell.magic_doctest_mode',
272 source='IPython.zmq.zmqshell.ZMQInteractiveShell.magic_doctest_mode',
257 mode=dstore.mode)
273 mode=dstore.mode)
258 self.payload_manager.write_payload(payload)
274 self.payload_manager.write_payload(payload)
259
275
260 @skip_doctest
276 @skip_doctest
261 def magic_edit(self,parameter_s='',last_call=['','']):
277 def magic_edit(self,parameter_s='',last_call=['','']):
262 """Bring up an editor and execute the resulting code.
278 """Bring up an editor and execute the resulting code.
263
279
264 Usage:
280 Usage:
265 %edit [options] [args]
281 %edit [options] [args]
266
282
267 %edit runs an external text editor. You will need to set the command for
283 %edit runs an external text editor. You will need to set the command for
268 this editor via the ``TerminalInteractiveShell.editor`` option in your
284 this editor via the ``TerminalInteractiveShell.editor`` option in your
269 configuration file before it will work.
285 configuration file before it will work.
270
286
271 This command allows you to conveniently edit multi-line code right in
287 This command allows you to conveniently edit multi-line code right in
272 your IPython session.
288 your IPython session.
273
289
274 If called without arguments, %edit opens up an empty editor with a
290 If called without arguments, %edit opens up an empty editor with a
275 temporary file and will execute the contents of this file when you
291 temporary file and will execute the contents of this file when you
276 close it (don't forget to save it!).
292 close it (don't forget to save it!).
277
293
278
294
279 Options:
295 Options:
280
296
281 -n <number>: open the editor at a specified line number. By default,
297 -n <number>: open the editor at a specified line number. By default,
282 the IPython editor hook uses the unix syntax 'editor +N filename', but
298 the IPython editor hook uses the unix syntax 'editor +N filename', but
283 you can configure this by providing your own modified hook if your
299 you can configure this by providing your own modified hook if your
284 favorite editor supports line-number specifications with a different
300 favorite editor supports line-number specifications with a different
285 syntax.
301 syntax.
286
302
287 -p: this will call the editor with the same data as the previous time
303 -p: this will call the editor with the same data as the previous time
288 it was used, regardless of how long ago (in your current session) it
304 it was used, regardless of how long ago (in your current session) it
289 was.
305 was.
290
306
291 -r: use 'raw' input. This option only applies to input taken from the
307 -r: use 'raw' input. This option only applies to input taken from the
292 user's history. By default, the 'processed' history is used, so that
308 user's history. By default, the 'processed' history is used, so that
293 magics are loaded in their transformed version to valid Python. If
309 magics are loaded in their transformed version to valid Python. If
294 this option is given, the raw input as typed as the command line is
310 this option is given, the raw input as typed as the command line is
295 used instead. When you exit the editor, it will be executed by
311 used instead. When you exit the editor, it will be executed by
296 IPython's own processor.
312 IPython's own processor.
297
313
298 -x: do not execute the edited code immediately upon exit. This is
314 -x: do not execute the edited code immediately upon exit. This is
299 mainly useful if you are editing programs which need to be called with
315 mainly useful if you are editing programs which need to be called with
300 command line arguments, which you can then do using %run.
316 command line arguments, which you can then do using %run.
301
317
302
318
303 Arguments:
319 Arguments:
304
320
305 If arguments are given, the following possibilites exist:
321 If arguments are given, the following possibilites exist:
306
322
307 - The arguments are numbers or pairs of colon-separated numbers (like
323 - The arguments are numbers or pairs of colon-separated numbers (like
308 1 4:8 9). These are interpreted as lines of previous input to be
324 1 4:8 9). These are interpreted as lines of previous input to be
309 loaded into the editor. The syntax is the same of the %macro command.
325 loaded into the editor. The syntax is the same of the %macro command.
310
326
311 - If the argument doesn't start with a number, it is evaluated as a
327 - If the argument doesn't start with a number, it is evaluated as a
312 variable and its contents loaded into the editor. You can thus edit
328 variable and its contents loaded into the editor. You can thus edit
313 any string which contains python code (including the result of
329 any string which contains python code (including the result of
314 previous edits).
330 previous edits).
315
331
316 - If the argument is the name of an object (other than a string),
332 - If the argument is the name of an object (other than a string),
317 IPython will try to locate the file where it was defined and open the
333 IPython will try to locate the file where it was defined and open the
318 editor at the point where it is defined. You can use `%edit function`
334 editor at the point where it is defined. You can use `%edit function`
319 to load an editor exactly at the point where 'function' is defined,
335 to load an editor exactly at the point where 'function' is defined,
320 edit it and have the file be executed automatically.
336 edit it and have the file be executed automatically.
321
337
322 If the object is a macro (see %macro for details), this opens up your
338 If the object is a macro (see %macro for details), this opens up your
323 specified editor with a temporary file containing the macro's data.
339 specified editor with a temporary file containing the macro's data.
324 Upon exit, the macro is reloaded with the contents of the file.
340 Upon exit, the macro is reloaded with the contents of the file.
325
341
326 Note: opening at an exact line is only supported under Unix, and some
342 Note: opening at an exact line is only supported under Unix, and some
327 editors (like kedit and gedit up to Gnome 2.8) do not understand the
343 editors (like kedit and gedit up to Gnome 2.8) do not understand the
328 '+NUMBER' parameter necessary for this feature. Good editors like
344 '+NUMBER' parameter necessary for this feature. Good editors like
329 (X)Emacs, vi, jed, pico and joe all do.
345 (X)Emacs, vi, jed, pico and joe all do.
330
346
331 - If the argument is not found as a variable, IPython will look for a
347 - If the argument is not found as a variable, IPython will look for a
332 file with that name (adding .py if necessary) and load it into the
348 file with that name (adding .py if necessary) and load it into the
333 editor. It will execute its contents with execfile() when you exit,
349 editor. It will execute its contents with execfile() when you exit,
334 loading any code in the file into your interactive namespace.
350 loading any code in the file into your interactive namespace.
335
351
336 After executing your code, %edit will return as output the code you
352 After executing your code, %edit will return as output the code you
337 typed in the editor (except when it was an existing file). This way
353 typed in the editor (except when it was an existing file). This way
338 you can reload the code in further invocations of %edit as a variable,
354 you can reload the code in further invocations of %edit as a variable,
339 via _<NUMBER> or Out[<NUMBER>], where <NUMBER> is the prompt number of
355 via _<NUMBER> or Out[<NUMBER>], where <NUMBER> is the prompt number of
340 the output.
356 the output.
341
357
342 Note that %edit is also available through the alias %ed.
358 Note that %edit is also available through the alias %ed.
343
359
344 This is an example of creating a simple function inside the editor and
360 This is an example of creating a simple function inside the editor and
345 then modifying it. First, start up the editor:
361 then modifying it. First, start up the editor:
346
362
347 In [1]: ed
363 In [1]: ed
348 Editing... done. Executing edited code...
364 Editing... done. Executing edited code...
349 Out[1]: 'def foo():n print "foo() was defined in an editing session"n'
365 Out[1]: 'def foo():n print "foo() was defined in an editing session"n'
350
366
351 We can then call the function foo():
367 We can then call the function foo():
352
368
353 In [2]: foo()
369 In [2]: foo()
354 foo() was defined in an editing session
370 foo() was defined in an editing session
355
371
356 Now we edit foo. IPython automatically loads the editor with the
372 Now we edit foo. IPython automatically loads the editor with the
357 (temporary) file where foo() was previously defined:
373 (temporary) file where foo() was previously defined:
358
374
359 In [3]: ed foo
375 In [3]: ed foo
360 Editing... done. Executing edited code...
376 Editing... done. Executing edited code...
361
377
362 And if we call foo() again we get the modified version:
378 And if we call foo() again we get the modified version:
363
379
364 In [4]: foo()
380 In [4]: foo()
365 foo() has now been changed!
381 foo() has now been changed!
366
382
367 Here is an example of how to edit a code snippet successive
383 Here is an example of how to edit a code snippet successive
368 times. First we call the editor:
384 times. First we call the editor:
369
385
370 In [5]: ed
386 In [5]: ed
371 Editing... done. Executing edited code...
387 Editing... done. Executing edited code...
372 hello
388 hello
373 Out[5]: "print 'hello'n"
389 Out[5]: "print 'hello'n"
374
390
375 Now we call it again with the previous output (stored in _):
391 Now we call it again with the previous output (stored in _):
376
392
377 In [6]: ed _
393 In [6]: ed _
378 Editing... done. Executing edited code...
394 Editing... done. Executing edited code...
379 hello world
395 hello world
380 Out[6]: "print 'hello world'n"
396 Out[6]: "print 'hello world'n"
381
397
382 Now we call it with the output #8 (stored in _8, also as Out[8]):
398 Now we call it with the output #8 (stored in _8, also as Out[8]):
383
399
384 In [7]: ed _8
400 In [7]: ed _8
385 Editing... done. Executing edited code...
401 Editing... done. Executing edited code...
386 hello again
402 hello again
387 Out[7]: "print 'hello again'n"
403 Out[7]: "print 'hello again'n"
388 """
404 """
389
405
390 opts,args = self.parse_options(parameter_s,'prn:')
406 opts,args = self.parse_options(parameter_s,'prn:')
391
407
392 try:
408 try:
393 filename, lineno, _ = self._find_edit_target(args, opts, last_call)
409 filename, lineno, _ = self._find_edit_target(args, opts, last_call)
394 except MacroToEdit as e:
410 except MacroToEdit as e:
395 # TODO: Implement macro editing over 2 processes.
411 # TODO: Implement macro editing over 2 processes.
396 print("Macro editing not yet implemented in 2-process model.")
412 print("Macro editing not yet implemented in 2-process model.")
397 return
413 return
398
414
399 # Make sure we send to the client an absolute path, in case the working
415 # Make sure we send to the client an absolute path, in case the working
400 # directory of client and kernel don't match
416 # directory of client and kernel don't match
401 filename = os.path.abspath(filename)
417 filename = os.path.abspath(filename)
402
418
403 payload = {
419 payload = {
404 'source' : 'IPython.zmq.zmqshell.ZMQInteractiveShell.edit_magic',
420 'source' : 'IPython.zmq.zmqshell.ZMQInteractiveShell.edit_magic',
405 'filename' : filename,
421 'filename' : filename,
406 'line_number' : lineno
422 'line_number' : lineno
407 }
423 }
408 self.payload_manager.write_payload(payload)
424 self.payload_manager.write_payload(payload)
409
425
410 # A few magics that are adapted to the specifics of using pexpect and a
426 # A few magics that are adapted to the specifics of using pexpect and a
411 # remote terminal
427 # remote terminal
412
428
413 def magic_clear(self, arg_s):
429 def magic_clear(self, arg_s):
414 """Clear the terminal."""
430 """Clear the terminal."""
415 if os.name == 'posix':
431 if os.name == 'posix':
416 self.shell.system("clear")
432 self.shell.system("clear")
417 else:
433 else:
418 self.shell.system("cls")
434 self.shell.system("cls")
419
435
420 if os.name == 'nt':
436 if os.name == 'nt':
421 # This is the usual name in windows
437 # This is the usual name in windows
422 magic_cls = magic_clear
438 magic_cls = magic_clear
423
439
424 # Terminal pagers won't work over pexpect, but we do have our own pager
440 # Terminal pagers won't work over pexpect, but we do have our own pager
425
441
426 def magic_less(self, arg_s):
442 def magic_less(self, arg_s):
427 """Show a file through the pager.
443 """Show a file through the pager.
428
444
429 Files ending in .py are syntax-highlighted."""
445 Files ending in .py are syntax-highlighted."""
430 cont = open(arg_s).read()
446 cont = open(arg_s).read()
431 if arg_s.endswith('.py'):
447 if arg_s.endswith('.py'):
432 cont = self.shell.pycolorize(cont)
448 cont = self.shell.pycolorize(cont)
433 page.page(cont)
449 page.page(cont)
434
450
435 magic_more = magic_less
451 magic_more = magic_less
436
452
437 # Man calls a pager, so we also need to redefine it
453 # Man calls a pager, so we also need to redefine it
438 if os.name == 'posix':
454 if os.name == 'posix':
439 def magic_man(self, arg_s):
455 def magic_man(self, arg_s):
440 """Find the man page for the given command and display in pager."""
456 """Find the man page for the given command and display in pager."""
441 page.page(self.shell.getoutput('man %s | col -b' % arg_s,
457 page.page(self.shell.getoutput('man %s | col -b' % arg_s,
442 split=False))
458 split=False))
443
459
444 # FIXME: this is specific to the GUI, so we should let the gui app load
460 # FIXME: this is specific to the GUI, so we should let the gui app load
445 # magics at startup that are only for the gui. Once the gui app has proper
461 # magics at startup that are only for the gui. Once the gui app has proper
446 # profile and configuration management, we can have it initialize a kernel
462 # profile and configuration management, we can have it initialize a kernel
447 # with a special config file that provides these.
463 # with a special config file that provides these.
448 def magic_guiref(self, arg_s):
464 def magic_guiref(self, arg_s):
449 """Show a basic reference about the GUI console."""
465 """Show a basic reference about the GUI console."""
450 from IPython.core.usage import gui_reference
466 from IPython.core.usage import gui_reference
451 page.page(gui_reference, auto_html=True)
467 page.page(gui_reference, auto_html=True)
452
468
453 def magic_connect_info(self, arg_s):
469 def magic_connect_info(self, arg_s):
454 """Print information for connecting other clients to this kernel
470 """Print information for connecting other clients to this kernel
455
471
456 It will print the contents of this session's connection file, as well as
472 It will print the contents of this session's connection file, as well as
457 shortcuts for local clients.
473 shortcuts for local clients.
458
474
459 In the simplest case, when called from the most recently launched kernel,
475 In the simplest case, when called from the most recently launched kernel,
460 secondary clients can be connected, simply with:
476 secondary clients can be connected, simply with:
461
477
462 $> ipython <app> --existing
478 $> ipython <app> --existing
463
479
464 """
480 """
465
481
466 from IPython.core.application import BaseIPythonApplication as BaseIPApp
482 from IPython.core.application import BaseIPythonApplication as BaseIPApp
467
483
468 if BaseIPApp.initialized():
484 if BaseIPApp.initialized():
469 app = BaseIPApp.instance()
485 app = BaseIPApp.instance()
470 security_dir = app.profile_dir.security_dir
486 security_dir = app.profile_dir.security_dir
471 profile = app.profile
487 profile = app.profile
472 else:
488 else:
473 profile = 'default'
489 profile = 'default'
474 security_dir = ''
490 security_dir = ''
475
491
476 try:
492 try:
477 connection_file = get_connection_file()
493 connection_file = get_connection_file()
478 info = get_connection_info(unpack=False)
494 info = get_connection_info(unpack=False)
479 except Exception as e:
495 except Exception as e:
480 error("Could not get connection info: %r" % e)
496 error("Could not get connection info: %r" % e)
481 return
497 return
482
498
483 # add profile flag for non-default profile
499 # add profile flag for non-default profile
484 profile_flag = "--profile %s" % profile if profile != 'default' else ""
500 profile_flag = "--profile %s" % profile if profile != 'default' else ""
485
501
486 # if it's in the security dir, truncate to basename
502 # if it's in the security dir, truncate to basename
487 if security_dir == os.path.dirname(connection_file):
503 if security_dir == os.path.dirname(connection_file):
488 connection_file = os.path.basename(connection_file)
504 connection_file = os.path.basename(connection_file)
489
505
490
506
491 print (info + '\n')
507 print (info + '\n')
492 print ("Paste the above JSON into a file, and connect with:\n"
508 print ("Paste the above JSON into a file, and connect with:\n"
493 " $> ipython <app> --existing <file>\n"
509 " $> ipython <app> --existing <file>\n"
494 "or, if you are local, you can connect with just:\n"
510 "or, if you are local, you can connect with just:\n"
495 " $> ipython <app> --existing {0} {1}\n"
511 " $> ipython <app> --existing {0} {1}\n"
496 "or even just:\n"
512 "or even just:\n"
497 " $> ipython <app> --existing {1}\n"
513 " $> ipython <app> --existing {1}\n"
498 "if this is the most recent IPython session you have started.".format(
514 "if this is the most recent IPython session you have started.".format(
499 connection_file, profile_flag
515 connection_file, profile_flag
500 )
516 )
501 )
517 )
502
518
503 def magic_qtconsole(self, arg_s):
519 def magic_qtconsole(self, arg_s):
504 """Open a qtconsole connected to this kernel.
520 """Open a qtconsole connected to this kernel.
505
521
506 Useful for connecting a qtconsole to running notebooks, for better
522 Useful for connecting a qtconsole to running notebooks, for better
507 debugging.
523 debugging.
508 """
524 """
509 try:
525 try:
510 p = connect_qtconsole(argv=arg_split(arg_s, os.name=='posix'))
526 p = connect_qtconsole(argv=arg_split(arg_s, os.name=='posix'))
511 except Exception as e:
527 except Exception as e:
512 error("Could not start qtconsole: %r" % e)
528 error("Could not start qtconsole: %r" % e)
513 return
529 return
514
530
515 def set_next_input(self, text):
531 def set_next_input(self, text):
516 """Send the specified text to the frontend to be presented at the next
532 """Send the specified text to the frontend to be presented at the next
517 input cell."""
533 input cell."""
518 payload = dict(
534 payload = dict(
519 source='IPython.zmq.zmqshell.ZMQInteractiveShell.set_next_input',
535 source='IPython.zmq.zmqshell.ZMQInteractiveShell.set_next_input',
520 text=text
536 text=text
521 )
537 )
522 self.payload_manager.write_payload(payload)
538 self.payload_manager.write_payload(payload)
523
539
524
540
525 InteractiveShellABC.register(ZMQInteractiveShell)
541 InteractiveShellABC.register(ZMQInteractiveShell)
@@ -1,137 +1,137 b''
1 .. _parallel_db:
1 .. _parallel_db:
2
2
3 =======================
3 =======================
4 IPython's Task Database
4 IPython's Task Database
5 =======================
5 =======================
6
6
7 The IPython Hub stores all task requests and results in a database. Currently supported backends
7 The IPython Hub stores all task requests and results in a database. Currently supported backends
8 are: MongoDB, SQLite (the default), and an in-memory DictDB. The most common use case for
8 are: MongoDB, SQLite (the default), and an in-memory DictDB. The most common use case for
9 this is clients requesting results for tasks they did not submit, via:
9 this is clients requesting results for tasks they did not submit, via:
10
10
11 .. sourcecode:: ipython
11 .. sourcecode:: ipython
12
12
13 In [1]: rc.get_result(task_id)
13 In [1]: rc.get_result(task_id)
14
14
15 However, since we have this DB backend, we provide a direct query method in the :class:`client`
15 However, since we have this DB backend, we provide a direct query method in the :class:`client`
16 for users who want deeper introspection into their task history. The :meth:`db_query` method of
16 for users who want deeper introspection into their task history. The :meth:`db_query` method of
17 the Client is modeled after MongoDB queries, so if you have used MongoDB it should look
17 the Client is modeled after MongoDB queries, so if you have used MongoDB it should look
18 familiar. In fact, when the MongoDB backend is in use, the query is relayed directly. However,
18 familiar. In fact, when the MongoDB backend is in use, the query is relayed directly. However,
19 when using other backends, the interface is emulated and only a subset of queries is possible.
19 when using other backends, the interface is emulated and only a subset of queries is possible.
20
20
21 .. seealso::
21 .. seealso::
22
22
23 MongoDB query docs: http://www.mongodb.org/display/DOCS/Querying
23 MongoDB query docs: http://www.mongodb.org/display/DOCS/Querying
24
24
25 :meth:`Client.db_query` takes a dictionary query object, with keys from the TaskRecord key list,
25 :meth:`Client.db_query` takes a dictionary query object, with keys from the TaskRecord key list,
26 and values of either exact values to test, or MongoDB queries, which are dicts of The form:
26 and values of either exact values to test, or MongoDB queries, which are dicts of The form:
27 ``{'operator' : 'argument(s)'}``. There is also an optional `keys` argument, that specifies
27 ``{'operator' : 'argument(s)'}``. There is also an optional `keys` argument, that specifies
28 which subset of keys should be retrieved. The default is to retrieve all keys excluding the
28 which subset of keys should be retrieved. The default is to retrieve all keys excluding the
29 request and result buffers. :meth:`db_query` returns a list of TaskRecord dicts. Also like
29 request and result buffers. :meth:`db_query` returns a list of TaskRecord dicts. Also like
30 MongoDB, the `msg_id` key will always be included, whether requested or not.
30 MongoDB, the `msg_id` key will always be included, whether requested or not.
31
31
32 TaskRecord keys:
32 TaskRecord keys:
33
33
34 =============== =============== =============
34 =============== =============== =============
35 Key Type Description
35 Key Type Description
36 =============== =============== =============
36 =============== =============== =============
37 msg_id uuid(bytes) The msg ID
37 msg_id uuid(ascii) The msg ID
38 header dict The request header
38 header dict The request header
39 content dict The request content (likely empty)
39 content dict The request content (likely empty)
40 buffers list(bytes) buffers containing serialized request objects
40 buffers list(bytes) buffers containing serialized request objects
41 submitted datetime timestamp for time of submission (set by client)
41 submitted datetime timestamp for time of submission (set by client)
42 client_uuid uuid(bytes) IDENT of client's socket
42 client_uuid uuid(bytes) IDENT of client's socket
43 engine_uuid uuid(bytes) IDENT of engine's socket
43 engine_uuid uuid(bytes) IDENT of engine's socket
44 started datetime time task began execution on engine
44 started datetime time task began execution on engine
45 completed datetime time task finished execution (success or failure) on engine
45 completed datetime time task finished execution (success or failure) on engine
46 resubmitted datetime time of resubmission (if applicable)
46 resubmitted uuid(ascii) msg_id of resubmitted task (if applicable)
47 result_header dict header for result
47 result_header dict header for result
48 result_content dict content for result
48 result_content dict content for result
49 result_buffers list(bytes) buffers containing serialized request objects
49 result_buffers list(bytes) buffers containing serialized request objects
50 queue bytes The name of the queue for the task ('mux' or 'task')
50 queue bytes The name of the queue for the task ('mux' or 'task')
51 pyin <unused> Python input (unused)
51 pyin <unused> Python input (unused)
52 pyout <unused> Python output (unused)
52 pyout <unused> Python output (unused)
53 pyerr <unused> Python traceback (unused)
53 pyerr <unused> Python traceback (unused)
54 stdout str Stream of stdout data
54 stdout str Stream of stdout data
55 stderr str Stream of stderr data
55 stderr str Stream of stderr data
56
56
57 =============== =============== =============
57 =============== =============== =============
58
58
59 MongoDB operators we emulate on all backends:
59 MongoDB operators we emulate on all backends:
60
60
61 ========== =================
61 ========== =================
62 Operator Python equivalent
62 Operator Python equivalent
63 ========== =================
63 ========== =================
64 '$in' in
64 '$in' in
65 '$nin' not in
65 '$nin' not in
66 '$eq' ==
66 '$eq' ==
67 '$ne' !=
67 '$ne' !=
68 '$ge' >
68 '$ge' >
69 '$gte' >=
69 '$gte' >=
70 '$le' <
70 '$le' <
71 '$lte' <=
71 '$lte' <=
72 ========== =================
72 ========== =================
73
73
74
74
75 The DB Query is useful for two primary cases:
75 The DB Query is useful for two primary cases:
76
76
77 1. deep polling of task status or metadata
77 1. deep polling of task status or metadata
78 2. selecting a subset of tasks, on which to perform a later operation (e.g. wait on result, purge records, resubmit,...)
78 2. selecting a subset of tasks, on which to perform a later operation (e.g. wait on result, purge records, resubmit,...)
79
79
80 Example Queries
80 Example Queries
81 ===============
81 ===============
82
82
83
83
84 To get all msg_ids that are not completed, only retrieving their ID and start time:
84 To get all msg_ids that are not completed, only retrieving their ID and start time:
85
85
86 .. sourcecode:: ipython
86 .. sourcecode:: ipython
87
87
88 In [1]: incomplete = rc.db_query({'complete' : None}, keys=['msg_id', 'started'])
88 In [1]: incomplete = rc.db_query({'complete' : None}, keys=['msg_id', 'started'])
89
89
90 All jobs started in the last hour by me:
90 All jobs started in the last hour by me:
91
91
92 .. sourcecode:: ipython
92 .. sourcecode:: ipython
93
93
94 In [1]: from datetime import datetime, timedelta
94 In [1]: from datetime import datetime, timedelta
95
95
96 In [2]: hourago = datetime.now() - timedelta(1./24)
96 In [2]: hourago = datetime.now() - timedelta(1./24)
97
97
98 In [3]: recent = rc.db_query({'started' : {'$gte' : hourago },
98 In [3]: recent = rc.db_query({'started' : {'$gte' : hourago },
99 'client_uuid' : rc.session.session})
99 'client_uuid' : rc.session.session})
100
100
101 All jobs started more than an hour ago, by clients *other than me*:
101 All jobs started more than an hour ago, by clients *other than me*:
102
102
103 .. sourcecode:: ipython
103 .. sourcecode:: ipython
104
104
105 In [3]: recent = rc.db_query({'started' : {'$le' : hourago },
105 In [3]: recent = rc.db_query({'started' : {'$le' : hourago },
106 'client_uuid' : {'$ne' : rc.session.session}})
106 'client_uuid' : {'$ne' : rc.session.session}})
107
107
108 Result headers for all jobs on engine 3 or 4:
108 Result headers for all jobs on engine 3 or 4:
109
109
110 .. sourcecode:: ipython
110 .. sourcecode:: ipython
111
111
112 In [1]: uuids = map(rc._engines.get, (3,4))
112 In [1]: uuids = map(rc._engines.get, (3,4))
113
113
114 In [2]: hist34 = rc.db_query({'engine_uuid' : {'$in' : uuids }, keys='result_header')
114 In [2]: hist34 = rc.db_query({'engine_uuid' : {'$in' : uuids }, keys='result_header')
115
115
116
116
117 Cost
117 Cost
118 ====
118 ====
119
119
120 The advantage of the database backends is, of course, that large amounts of
120 The advantage of the database backends is, of course, that large amounts of
121 data can be stored that won't fit in memory. The default 'backend' is actually
121 data can be stored that won't fit in memory. The default 'backend' is actually
122 to just store all of this information in a Python dictionary. This is very fast,
122 to just store all of this information in a Python dictionary. This is very fast,
123 but will run out of memory quickly if you move a lot of data around, or your
123 but will run out of memory quickly if you move a lot of data around, or your
124 cluster is to run for a long time.
124 cluster is to run for a long time.
125
125
126 Unfortunately, the DB backends (SQLite and MongoDB) right now are rather slow,
126 Unfortunately, the DB backends (SQLite and MongoDB) right now are rather slow,
127 and can still consume large amounts of resources, particularly if large tasks
127 and can still consume large amounts of resources, particularly if large tasks
128 or results are being created at a high frequency.
128 or results are being created at a high frequency.
129
129
130 For this reason, we have added :class:`~.NoDB`,a dummy backend that doesn't
130 For this reason, we have added :class:`~.NoDB`,a dummy backend that doesn't
131 actually store any information. When you use this database, nothing is stored,
131 actually store any information. When you use this database, nothing is stored,
132 and any request for results will result in a KeyError. This obviously prevents
132 and any request for results will result in a KeyError. This obviously prevents
133 later requests for results and task resubmission from functioning, but
133 later requests for results and task resubmission from functioning, but
134 sometimes those nice features are not as useful as keeping Hub memory under
134 sometimes those nice features are not as useful as keeping Hub memory under
135 control.
135 control.
136
136
137
137
@@ -1,414 +0,0 b''
1 """
2 Kernel adapted from kernel.py to use ZMQ Streams
3
4 Authors:
5
6 * Min RK
7 * Brian Granger
8 * Fernando Perez
9 * Evan Patterson
10 """
11 #-----------------------------------------------------------------------------
12 # Copyright (C) 2010-2011 The IPython Development Team
13 #
14 # Distributed under the terms of the BSD License. The full license is in
15 # the file COPYING, distributed as part of this software.
16 #-----------------------------------------------------------------------------
17
18 #-----------------------------------------------------------------------------
19 # Imports
20 #-----------------------------------------------------------------------------
21
22 # Standard library imports.
23 from __future__ import print_function
24
25 import sys
26 import time
27
28 from code import CommandCompiler
29 from datetime import datetime
30 from pprint import pprint
31
32 # System library imports.
33 import zmq
34 from zmq.eventloop import ioloop, zmqstream
35
36 # Local imports.
37 from IPython.utils.traitlets import Instance, List, Integer, Dict, Set, Unicode, CBytes
38 from IPython.zmq.completer import KernelCompleter
39
40 from IPython.parallel.error import wrap_exception
41 from IPython.parallel.factory import SessionFactory
42 from IPython.parallel.util import serialize_object, unpack_apply_message, asbytes
43
44 def printer(*args):
45 pprint(args, stream=sys.__stdout__)
46
47
48 class _Passer(zmqstream.ZMQStream):
49 """Empty class that implements `send()` that does nothing.
50
51 Subclass ZMQStream for Session typechecking
52
53 """
54 def __init__(self, *args, **kwargs):
55 pass
56
57 def send(self, *args, **kwargs):
58 pass
59 send_multipart = send
60
61
62 #-----------------------------------------------------------------------------
63 # Main kernel class
64 #-----------------------------------------------------------------------------
65
66 class Kernel(SessionFactory):
67
68 #---------------------------------------------------------------------------
69 # Kernel interface
70 #---------------------------------------------------------------------------
71
72 # kwargs:
73 exec_lines = List(Unicode, config=True,
74 help="List of lines to execute")
75
76 # identities:
77 int_id = Integer(-1)
78 bident = CBytes()
79 ident = Unicode()
80 def _ident_changed(self, name, old, new):
81 self.bident = asbytes(new)
82
83 user_ns = Dict(config=True, help="""Set the user's namespace of the Kernel""")
84
85 control_stream = Instance(zmqstream.ZMQStream)
86 task_stream = Instance(zmqstream.ZMQStream)
87 iopub_stream = Instance(zmqstream.ZMQStream)
88 client = Instance('IPython.parallel.Client')
89
90 # internals
91 shell_streams = List()
92 compiler = Instance(CommandCompiler, (), {})
93 completer = Instance(KernelCompleter)
94
95 aborted = Set()
96 shell_handlers = Dict()
97 control_handlers = Dict()
98
99 def _set_prefix(self):
100 self.prefix = "engine.%s"%self.int_id
101
102 def _connect_completer(self):
103 self.completer = KernelCompleter(self.user_ns)
104
105 def __init__(self, **kwargs):
106 super(Kernel, self).__init__(**kwargs)
107 self._set_prefix()
108 self._connect_completer()
109
110 self.on_trait_change(self._set_prefix, 'id')
111 self.on_trait_change(self._connect_completer, 'user_ns')
112
113 # Build dict of handlers for message types
114 for msg_type in ['execute_request', 'complete_request', 'apply_request',
115 'clear_request']:
116 self.shell_handlers[msg_type] = getattr(self, msg_type)
117
118 for msg_type in ['shutdown_request', 'abort_request']+self.shell_handlers.keys():
119 self.control_handlers[msg_type] = getattr(self, msg_type)
120
121 self._initial_exec_lines()
122
123 def _wrap_exception(self, method=None):
124 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
125 content=wrap_exception(e_info)
126 return content
127
128 def _initial_exec_lines(self):
129 s = _Passer()
130 content = dict(silent=True, user_variable=[],user_expressions=[])
131 for line in self.exec_lines:
132 self.log.debug("executing initialization: %s"%line)
133 content.update({'code':line})
134 msg = self.session.msg('execute_request', content)
135 self.execute_request(s, [], msg)
136
137
138 #-------------------- control handlers -----------------------------
139 def abort_queues(self):
140 for stream in self.shell_streams:
141 if stream:
142 self.abort_queue(stream)
143
144 def abort_queue(self, stream):
145 while True:
146 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
147 if msg is None:
148 return
149
150 self.log.info("Aborting:")
151 self.log.info(str(msg))
152 msg_type = msg['header']['msg_type']
153 reply_type = msg_type.split('_')[0] + '_reply'
154 # reply_msg = self.session.msg(reply_type, {'status' : 'aborted'}, msg)
155 # self.reply_socket.send(ident,zmq.SNDMORE)
156 # self.reply_socket.send_json(reply_msg)
157 reply_msg = self.session.send(stream, reply_type,
158 content={'status' : 'aborted'}, parent=msg, ident=idents)
159 self.log.debug(str(reply_msg))
160 # We need to wait a bit for requests to come in. This can probably
161 # be set shorter for true asynchronous clients.
162 time.sleep(0.05)
163
164 def abort_request(self, stream, ident, parent):
165 """abort a specifig msg by id"""
166 msg_ids = parent['content'].get('msg_ids', None)
167 if isinstance(msg_ids, basestring):
168 msg_ids = [msg_ids]
169 if not msg_ids:
170 self.abort_queues()
171 for mid in msg_ids:
172 self.aborted.add(str(mid))
173
174 content = dict(status='ok')
175 reply_msg = self.session.send(stream, 'abort_reply', content=content,
176 parent=parent, ident=ident)
177 self.log.debug(str(reply_msg))
178
179 def shutdown_request(self, stream, ident, parent):
180 """kill ourself. This should really be handled in an external process"""
181 try:
182 self.abort_queues()
183 except:
184 content = self._wrap_exception('shutdown')
185 else:
186 content = dict(parent['content'])
187 content['status'] = 'ok'
188 msg = self.session.send(stream, 'shutdown_reply',
189 content=content, parent=parent, ident=ident)
190 self.log.debug(str(msg))
191 dc = ioloop.DelayedCallback(lambda : sys.exit(0), 1000, self.loop)
192 dc.start()
193
194 def dispatch_control(self, msg):
195 idents,msg = self.session.feed_identities(msg, copy=False)
196 try:
197 msg = self.session.unserialize(msg, content=True, copy=False)
198 except:
199 self.log.error("Invalid Message", exc_info=True)
200 return
201 else:
202 self.log.debug("Control received, %s", msg)
203
204 header = msg['header']
205 msg_id = header['msg_id']
206 msg_type = header['msg_type']
207
208 handler = self.control_handlers.get(msg_type, None)
209 if handler is None:
210 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r"%msg_type)
211 else:
212 handler(self.control_stream, idents, msg)
213
214
215 #-------------------- queue helpers ------------------------------
216
217 def check_dependencies(self, dependencies):
218 if not dependencies:
219 return True
220 if len(dependencies) == 2 and dependencies[0] in 'any all'.split():
221 anyorall = dependencies[0]
222 dependencies = dependencies[1]
223 else:
224 anyorall = 'all'
225 results = self.client.get_results(dependencies,status_only=True)
226 if results['status'] != 'ok':
227 return False
228
229 if anyorall == 'any':
230 if not results['completed']:
231 return False
232 else:
233 if results['pending']:
234 return False
235
236 return True
237
238 def check_aborted(self, msg_id):
239 return msg_id in self.aborted
240
241 #-------------------- queue handlers -----------------------------
242
243 def clear_request(self, stream, idents, parent):
244 """Clear our namespace."""
245 self.user_ns = {}
246 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
247 content = dict(status='ok'))
248 self._initial_exec_lines()
249
250 def execute_request(self, stream, ident, parent):
251 self.log.debug('execute request %s'%parent)
252 try:
253 code = parent[u'content'][u'code']
254 except:
255 self.log.error("Got bad msg: %s"%parent, exc_info=True)
256 return
257 self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent,
258 ident=asbytes('%s.pyin'%self.prefix))
259 started = datetime.now()
260 try:
261 comp_code = self.compiler(code, '<zmq-kernel>')
262 # allow for not overriding displayhook
263 if hasattr(sys.displayhook, 'set_parent'):
264 sys.displayhook.set_parent(parent)
265 sys.stdout.set_parent(parent)
266 sys.stderr.set_parent(parent)
267 exec comp_code in self.user_ns, self.user_ns
268 except:
269 exc_content = self._wrap_exception('execute')
270 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
271 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
272 ident=asbytes('%s.pyerr'%self.prefix))
273 reply_content = exc_content
274 else:
275 reply_content = {'status' : 'ok'}
276
277 reply_msg = self.session.send(stream, u'execute_reply', reply_content, parent=parent,
278 ident=ident, subheader = dict(started=started))
279 self.log.debug(str(reply_msg))
280 if reply_msg['content']['status'] == u'error':
281 self.abort_queues()
282
283 def complete_request(self, stream, ident, parent):
284 matches = {'matches' : self.complete(parent),
285 'status' : 'ok'}
286 completion_msg = self.session.send(stream, 'complete_reply',
287 matches, parent, ident)
288 # print >> sys.__stdout__, completion_msg
289
290 def complete(self, msg):
291 return self.completer.complete(msg.content.line, msg.content.text)
292
293 def apply_request(self, stream, ident, parent):
294 # flush previous reply, so this request won't block it
295 stream.flush(zmq.POLLOUT)
296 try:
297 content = parent[u'content']
298 bufs = parent[u'buffers']
299 msg_id = parent['header']['msg_id']
300 # bound = parent['header'].get('bound', False)
301 except:
302 self.log.error("Got bad msg: %s"%parent, exc_info=True)
303 return
304 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
305 # self.iopub_stream.send(pyin_msg)
306 # self.session.send(self.iopub_stream, u'pyin', {u'code':code},parent=parent)
307 sub = {'dependencies_met' : True, 'engine' : self.ident,
308 'started': datetime.now()}
309 try:
310 # allow for not overriding displayhook
311 if hasattr(sys.displayhook, 'set_parent'):
312 sys.displayhook.set_parent(parent)
313 sys.stdout.set_parent(parent)
314 sys.stderr.set_parent(parent)
315 # exec "f(*args,**kwargs)" in self.user_ns, self.user_ns
316 working = self.user_ns
317 # suffix =
318 prefix = "_"+str(msg_id).replace("-","")+"_"
319
320 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
321 # if bound:
322 # bound_ns = Namespace(working)
323 # args = [bound_ns]+list(args)
324
325 fname = getattr(f, '__name__', 'f')
326
327 fname = prefix+"f"
328 argname = prefix+"args"
329 kwargname = prefix+"kwargs"
330 resultname = prefix+"result"
331
332 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
333 # print ns
334 working.update(ns)
335 code = "%s=%s(*%s,**%s)"%(resultname, fname, argname, kwargname)
336 try:
337 exec code in working,working
338 result = working.get(resultname)
339 finally:
340 for key in ns.iterkeys():
341 working.pop(key)
342 # if bound:
343 # working.update(bound_ns)
344
345 packed_result,buf = serialize_object(result)
346 result_buf = [packed_result]+buf
347 except:
348 exc_content = self._wrap_exception('apply')
349 # exc_msg = self.session.msg(u'pyerr', exc_content, parent)
350 self.session.send(self.iopub_stream, u'pyerr', exc_content, parent=parent,
351 ident=asbytes('%s.pyerr'%self.prefix))
352 reply_content = exc_content
353 result_buf = []
354
355 if exc_content['ename'] == 'UnmetDependency':
356 sub['dependencies_met'] = False
357 else:
358 reply_content = {'status' : 'ok'}
359
360 # put 'ok'/'error' status in header, for scheduler introspection:
361 sub['status'] = reply_content['status']
362
363 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
364 parent=parent, ident=ident,buffers=result_buf, subheader=sub)
365
366 # flush i/o
367 # should this be before reply_msg is sent, like in the single-kernel code,
368 # or should nothing get in the way of real results?
369 sys.stdout.flush()
370 sys.stderr.flush()
371
372 def dispatch_queue(self, stream, msg):
373 self.control_stream.flush()
374 idents,msg = self.session.feed_identities(msg, copy=False)
375 try:
376 msg = self.session.unserialize(msg, content=True, copy=False)
377 except:
378 self.log.error("Invalid Message", exc_info=True)
379 return
380 else:
381 self.log.debug("Message received, %s", msg)
382
383
384 header = msg['header']
385 msg_id = header['msg_id']
386 msg_type = msg['header']['msg_type']
387 if self.check_aborted(msg_id):
388 self.aborted.remove(msg_id)
389 # is it safe to assume a msg_id will not be resubmitted?
390 reply_type = msg_type.split('_')[0] + '_reply'
391 status = {'status' : 'aborted'}
392 reply_msg = self.session.send(stream, reply_type, subheader=status,
393 content=status, parent=msg, ident=idents)
394 return
395 handler = self.shell_handlers.get(msg_type, None)
396 if handler is None:
397 self.log.error("UNKNOWN MESSAGE TYPE: %r"%msg_type)
398 else:
399 handler(stream, idents, msg)
400
401 def start(self):
402 #### stream mode:
403 if self.control_stream:
404 self.control_stream.on_recv(self.dispatch_control, copy=False)
405
406 def make_dispatcher(stream):
407 def dispatcher(msg):
408 return self.dispatch_queue(stream, msg)
409 return dispatcher
410
411 for s in self.shell_streams:
412 s.on_recv(make_dispatcher(s), copy=False)
413
414
@@ -1,282 +0,0 b''
1 #!/usr/bin/env python
2 """A simple interactive kernel that talks to a frontend over 0MQ.
3
4 Things to do:
5
6 * Implement `set_parent` logic. Right before doing exec, the Kernel should
7 call set_parent on all the PUB objects with the message about to be executed.
8 * Implement random port and security key logic.
9 * Implement control messages.
10 * Implement event loop and poll version.
11 """
12
13 #-----------------------------------------------------------------------------
14 # Imports
15 #-----------------------------------------------------------------------------
16
17 # Standard library imports.
18 import __builtin__
19 from code import CommandCompiler
20 import sys
21 import time
22 import traceback
23
24 # System library imports.
25 import zmq
26
27 # Local imports.
28 from IPython.utils import py3compat
29 from IPython.utils.traitlets import HasTraits, Instance, Dict, Float
30 from completer import KernelCompleter
31 from entry_point import base_launch_kernel
32 from session import Session, Message
33 from kernelapp import KernelApp
34
35 #-----------------------------------------------------------------------------
36 # Main kernel class
37 #-----------------------------------------------------------------------------
38
class Kernel(HasTraits):
    """A minimal interactive Python kernel speaking 0MQ messages.

    Receives requests on the shell socket, publishes side effects
    (pyin/pyerr) on the iopub socket, and services raw_input via the
    stdin socket.  NOTE: this is Python 2 code (`exec ... in`,
    `unicode`, `__builtin__`).
    """

    # Private interface

    # Time to sleep after flushing the stdout/err buffers in each execute
    # cycle. While this introduces a hard limit on the minimal latency of the
    # execute cycle, it helps prevent output synchronization problems for
    # clients.
    # Units are in seconds. The minimum zmq latency on local host is probably
    # ~150 microseconds, set this to 500us for now. We may need to increase it
    # a little if it's not enough after more interactive testing.
    _execute_sleep = Float(0.0005, config=True)

    # This is a dict of port number that the kernel is listening on. It is set
    # by record_ports and used by connect_request.
    _recorded_ports = Dict()

    #---------------------------------------------------------------------------
    # Kernel interface
    #---------------------------------------------------------------------------

    # wired up by the creator of the Kernel (e.g. the KernelApp)
    session = Instance(Session)
    shell_socket = Instance('zmq.Socket')
    iopub_socket = Instance('zmq.Socket')
    stdin_socket = Instance('zmq.Socket')
    log = Instance('logging.Logger')

    def __init__(self, **kwargs):
        """Set up the execution namespace, compiler, completer and the
        msg_type -> bound-method handler table."""
        super(Kernel, self).__init__(**kwargs)
        self.user_ns = {}
        self.history = []
        self.compiler = CommandCompiler()
        self.completer = KernelCompleter(self.user_ns)

        # Build dict of handlers for message types
        msg_types = [ 'execute_request', 'complete_request',
                      'object_info_request', 'shutdown_request' ]
        self.handlers = {}
        for msg_type in msg_types:
            # handler methods are named exactly after their msg_type
            self.handlers[msg_type] = getattr(self, msg_type)

    def start(self):
        """ Start the kernel main loop.

        Blocks forever, receiving one shell message at a time and
        dispatching it to the matching handler.
        """
        while True:
            ident,msg = self.session.recv(self.shell_socket,0)
            assert ident is not None, "Missing message part."
            omsg = Message(msg)
            self.log.debug(str(omsg))
            handler = self.handlers.get(omsg.msg_type, None)
            if handler is None:
                self.log.error("UNKNOWN MESSAGE TYPE: %s"%omsg)
            else:
                handler(ident, omsg)

    def record_ports(self, ports):
        """Record the ports that this kernel is using.

        The creator of the Kernel instance must call this methods if they
        want the :meth:`connect_request` method to return the port numbers.
        """
        self._recorded_ports = ports

    #---------------------------------------------------------------------------
    # Kernel request handlers
    #---------------------------------------------------------------------------

    def execute_request(self, ident, parent):
        """Compile and exec the code in an execute_request, publishing
        pyin before and pyerr on failure, then send an execute_reply."""
        try:
            code = parent[u'content'][u'code']
        except:
            self.log.error("Got bad msg: %s"%Message(parent))
            return
        # echo the code about to be run on the iopub channel
        pyin_msg = self.session.send(self.iopub_socket, u'pyin',{u'code':code}, parent=parent)

        try:
            comp_code = self.compiler(code, '<zmq-kernel>')

            # Replace raw_input. Note that is not sufficient to replace
            # raw_input in the user namespace.
            raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
            if py3compat.PY3:
                __builtin__.input = raw_input
            else:
                __builtin__.raw_input = raw_input

            # Set the parent message of the display hook and out streams.
            sys.displayhook.set_parent(parent)
            sys.stdout.set_parent(parent)
            sys.stderr.set_parent(parent)

            # Python 2 exec statement; globals and locals are both user_ns
            exec comp_code in self.user_ns, self.user_ns
        except:
            # publish the formatted traceback and reuse it as the reply content
            etype, evalue, tb = sys.exc_info()
            tb = traceback.format_exception(etype, evalue, tb)
            exc_content = {
                u'status' : u'error',
                u'traceback' : tb,
                u'ename' : unicode(etype.__name__),
                u'evalue' : unicode(evalue)
            }
            exc_msg = self.session.send(self.iopub_socket, u'pyerr', exc_content, parent)
            reply_content = exc_content
        else:
            reply_content = { 'status' : 'ok', 'payload' : {} }

        # Flush output before sending the reply.
        sys.stderr.flush()
        sys.stdout.flush()
        # FIXME: on rare occasions, the flush doesn't seem to make it to the
        # clients... This seems to mitigate the problem, but we definitely need
        # to better understand what's going on.
        if self._execute_sleep:
            time.sleep(self._execute_sleep)

        # Send the reply.
        reply_msg = self.session.send(self.shell_socket, u'execute_reply', reply_content, parent, ident=ident)
        self.log.debug(Message(reply_msg))
        if reply_msg['content']['status'] == u'error':
            # drain and abort any queued requests after a failed execute
            self._abort_queue()

    def complete_request(self, ident, parent):
        """Answer a tab-completion request with the matches for the
        line/text in the parent message."""
        matches = {'matches' : self._complete(parent),
                   'status' : 'ok'}
        completion_msg = self.session.send(self.shell_socket, 'complete_reply',
                                           matches, parent, ident)
        self.log.debug(completion_msg)

    def object_info_request(self, ident, parent):
        """Reply with introspection info (currently just the docstring)
        for the dotted name in the request's 'oname'."""
        context = parent['content']['oname'].split('.')
        object_info = self._object_info(context)
        msg = self.session.send(self.shell_socket, 'object_info_reply',
                                object_info, parent, ident)
        self.log.debug(msg)

    def shutdown_request(self, ident, parent):
        """Acknowledge shutdown on both shell and iopub, then exit the
        process after a short grace period."""
        content = dict(parent['content'])
        msg = self.session.send(self.shell_socket, 'shutdown_reply',
                                content, parent, ident)
        msg = self.session.send(self.iopub_socket, 'shutdown_reply',
                                content, parent, ident)
        self.log.debug(msg)
        # brief pause so the replies can go out before the process dies
        time.sleep(0.1)
        sys.exit(0)

    #---------------------------------------------------------------------------
    # Protected interface
    #---------------------------------------------------------------------------

    def _abort_queue(self):
        """Reply 'aborted' to every request already queued on the shell
        socket, draining until the socket would block (EAGAIN)."""
        while True:
            ident,msg = self.session.recv(self.shell_socket, zmq.NOBLOCK)
            if msg is None:
                # msg=None on EAGAIN
                break
            else:
                assert ident is not None, "Missing message part."
                self.log.debug("Aborting: %s"%Message(msg))
                msg_type = msg['header']['msg_type']
                # e.g. execute_request -> execute_reply
                reply_type = msg_type.split('_')[0] + '_reply'
                reply_msg = self.session.send(self.shell_socket, reply_type, {'status':'aborted'}, msg, ident=ident)
                self.log.debug(Message(reply_msg))
            # We need to wait a bit for requests to come in. This can probably
            # be set shorter for true asynchronous clients.
            time.sleep(0.1)

    def _raw_input(self, prompt, ident, parent):
        """Forward a raw_input prompt to the frontend over the stdin
        socket and block until a value comes back ('' on a bad reply)."""
        # Flush output before making the request.
        sys.stderr.flush()
        sys.stdout.flush()

        # Send the input request.
        content = dict(prompt=prompt)
        msg = self.session.send(self.stdin_socket, u'input_request', content, parent, ident=ident)

        # Await a response.
        ident,reply = self.session.recv(self.stdin_socket, 0)
        try:
            value = reply['content']['value']
        except:
            # NOTE(review): this logs the originating request (parent), not
            # the malformed reply — confirm that is the intended diagnostic.
            self.log.error("Got bad raw_input reply: %s"%Message(parent))
            value = ''
        return value

    def _complete(self, msg):
        """Delegate completion to the KernelCompleter for the request's
        line and text fields."""
        return self.completer.complete(msg.content.line, msg.content.text)

    def _object_info(self, context):
        """Build the object_info_reply content (docstring only) for a
        dotted-name context; empty docstring when unresolved."""
        symbol, leftover = self._symbol_from_context(context)
        if symbol is not None and not leftover:
            doc = getattr(symbol, '__doc__', '')
        else:
            doc = ''
        object_info = dict(docstring = doc)
        return object_info

    def _symbol_from_context(self, context):
        """Resolve a dotted name (as a list of parts) against user_ns,
        then builtins, walking attributes as far as possible.

        Returns (symbol, remaining_parts); symbol is None when even the
        first part cannot be found.
        """
        if not context:
            return None, context

        base_symbol_string = context[0]
        symbol = self.user_ns.get(base_symbol_string, None)
        if symbol is None:
            # fall back to builtins for names like `len`
            symbol = __builtin__.__dict__.get(base_symbol_string, None)
        if symbol is None:
            return None, context

        context = context[1:]
        for i, name in enumerate(context):
            new_symbol = getattr(symbol, name, None)
            if new_symbol is None:
                # stop at the last resolvable attribute; report the rest
                return symbol, context[i:]
            else:
                symbol = new_symbol

        return symbol, []
255
256 #-----------------------------------------------------------------------------
257 # Kernel main and launch functions
258 #-----------------------------------------------------------------------------
259
def launch_kernel(*args, **kwargs):
    """Launch a simple Python kernel, binding to the specified ports.

    Thin wrapper around ``entry_point.base_launch_kernel``: supplies the
    command line that starts a pykernel and forwards everything else.
    See ``base_launch_kernel`` for the accepted arguments.

    Returns
    -------
    A tuple of form:
        (kernel_process, xrep_port, pub_port, req_port, hb_port)
    where kernel_process is a Popen object and the ports are integers.
    """
    first_command = 'from IPython.zmq.pykernel import main; main()'
    return base_launch_kernel(first_command, *args, **kwargs)
274
def main():
    """Run a PyKernel as an application.

    Obtains the KernelApp singleton, initializes it, and enters its
    event loop.
    """
    kernel_app = KernelApp.instance()
    kernel_app.initialize()
    kernel_app.start()
280
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()
General Comments 0
You need to be logged in to leave comments. Login now