From 48fb9d900a691f50b1b932bf8143bce9acf2df87 2008-06-09 20:02:25 From: Brian E Granger Date: 2008-06-09 20:02:25 Subject: [PATCH] Merge of the ipython-ipython1a branch into the ipython trunk. This merge represents the first merging of the things in ipython1-dev into ipython. More specifically, this merge includes the basic ipython1 kernel and a few related subpackages. Most importantly, the setup.py script and friends have been refactored. --- diff --git a/IPython/Release.py b/IPython/Release.py index cc9e745..cc62c70 100644 --- a/IPython/Release.py +++ b/IPython/Release.py @@ -32,7 +32,7 @@ else: version = '0.8.4' -description = "An enhanced interactive Python shell." +description = "Tools for interactive development in Python." long_description = \ """ @@ -77,13 +77,19 @@ license = 'BSD' authors = {'Fernando' : ('Fernando Perez','fperez@colorado.edu'), 'Janko' : ('Janko Hauser','jhauser@zscout.de'), 'Nathan' : ('Nathaniel Gray','n8gray@caltech.edu'), - 'Ville' : ('Ville Vainio','vivainio@gmail.com') + 'Ville' : ('Ville Vainio','vivainio@gmail.com'), + 'Brian' : ('Brian E Granger', 'ellisonbg@gmail.com'), + 'Min' : ('Min Ragan-Kelley', 'benjaminrk@gmail.com') } +author = 'The IPython Development Team' + +author_email = 'ipython-dev@scipy.org' + url = 'http://ipython.scipy.org' download_url = 'http://ipython.scipy.org/dist' platforms = ['Linux','Mac OSX','Windows XP/2000/NT','Windows 95/98/ME'] -keywords = ['Interactive','Interpreter','Shell'] +keywords = ['Interactive','Interpreter','Shell','Parallel','Distributed'] diff --git a/IPython/config/__init__.py b/IPython/config/__init__.py new file mode 100644 index 0000000..876b4f3 --- /dev/null +++ b/IPython/config/__init__.py @@ -0,0 +1,14 @@ +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- \ No newline at end of file diff --git a/IPython/config/api.py b/IPython/config/api.py new file mode 100644 index 0000000..1fca60e --- /dev/null +++ b/IPython/config/api.py @@ -0,0 +1,99 @@ +# encoding: utf-8 + +"""This is the official entry point to IPython's configuration system. """ + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os +from IPython.config.cutils import get_home_dir, get_ipython_dir +from IPython.external.configobj import ConfigObj + +class ConfigObjManager(object): + + def __init__(self, configObj, filename): + self.current = configObj + self.current.indent_type = ' ' + self.filename = filename + # self.write_default_config_file() + + def get_config_obj(self): + return self.current + + def update_config_obj(self, newConfig): + self.current.merge(newConfig) + + def update_config_obj_from_file(self, filename): + newConfig = ConfigObj(filename, file_error=False) + self.current.merge(newConfig) + + def update_config_obj_from_default_file(self, ipythondir=None): + fname = self.resolve_file_path(self.filename, ipythondir) + self.update_config_obj_from_file(fname) + + def 
write_config_obj_to_file(self, filename): + f = open(filename, 'w') + self.current.write(f) + f.close() + + def write_default_config_file(self): + ipdir = get_ipython_dir() + fname = ipdir + '/' + self.filename + if not os.path.isfile(fname): + print "Writing the configuration file to: " + fname + self.write_config_obj_to_file(fname) + + def _import(self, key): + package = '.'.join(key.split('.')[0:-1]) + obj = key.split('.')[-1] + execString = 'from %s import %s' % (package, obj) + exec execString + exec 'temp = %s' % obj + return temp + + def resolve_file_path(self, filename, ipythondir = None): + """Resolve filenames into absolute paths. + + This function looks in the following directories in order: + + 1. In the current working directory or by absolute path with ~ expanded + 2. In ipythondir if that is set + 3. In the IPYTHONDIR environment variable if it exists + 4. In the ~/.ipython directory + + Note: The IPYTHONDIR is also used by the trunk version of IPython so + changing it will also affect it was well. + """ + + # In cwd or by absolute path with ~ expanded + trythis = os.path.expanduser(filename) + if os.path.isfile(trythis): + return trythis + + # In ipythondir if it is set + if ipythondir is not None: + trythis = ipythondir + '/' + filename + if os.path.isfile(trythis): + return trythis + + trythis = get_ipython_dir() + '/' + filename + if os.path.isfile(trythis): + return trythis + + return None + + + + + + diff --git a/IPython/config/cutils.py b/IPython/config/cutils.py new file mode 100644 index 0000000..ad3cfc4 --- /dev/null +++ b/IPython/config/cutils.py @@ -0,0 +1,99 @@ +# encoding: utf-8 + +"""Configuration-related utilities for all IPython.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os +import sys + +#--------------------------------------------------------------------------- +# Normal code begins +#--------------------------------------------------------------------------- + +class HomeDirError(Exception): + pass + +def get_home_dir(): + """Return the closest possible equivalent to a 'home' directory. + + We first try $HOME. Absent that, on NT it's $HOMEDRIVE\$HOMEPATH. + + Currently only Posix and NT are implemented, a HomeDirError exception is + raised for all other OSes. """ + + isdir = os.path.isdir + env = os.environ + try: + homedir = env['HOME'] + if not isdir(homedir): + # in case a user stuck some string which does NOT resolve to a + # valid path, it's as good as if we hadn't foud it + raise KeyError + return homedir + except KeyError: + if os.name == 'posix': + raise HomeDirError,'undefined $HOME, IPython can not proceed.' + elif os.name == 'nt': + # For some strange reason, win9x returns 'nt' for os.name. + try: + homedir = os.path.join(env['HOMEDRIVE'],env['HOMEPATH']) + if not isdir(homedir): + homedir = os.path.join(env['USERPROFILE']) + if not isdir(homedir): + raise HomeDirError + return homedir + except: + try: + # Use the registry to get the 'My Documents' folder. + import _winreg as wreg + key = wreg.OpenKey(wreg.HKEY_CURRENT_USER, + "Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders") + homedir = wreg.QueryValueEx(key,'Personal')[0] + key.Close() + if not isdir(homedir): + e = ('Invalid "Personal" folder registry key ' + 'typically "My Documents".\n' + 'Value: %s\n' + 'This is not a valid directory on your system.' 
% + homedir) + raise HomeDirError(e) + return homedir + except HomeDirError: + raise + except: + return 'C:\\' + elif os.name == 'dos': + # Desperate, may do absurd things in classic MacOS. May work under DOS. + return 'C:\\' + else: + raise HomeDirError,'support for your operating system not implemented.' + +def get_ipython_dir(): + ipdir_def = '.ipython' + home_dir = get_home_dir() + ipdir = os.path.abspath(os.environ.get('IPYTHONDIR', + os.path.join(home_dir,ipdir_def))) + return ipdir + +def import_item(key): + """ + Import and return bar given the string foo.bar. + """ + package = '.'.join(key.split('.')[0:-1]) + obj = key.split('.')[-1] + execString = 'from %s import %s' % (package, obj) + exec execString + exec 'temp = %s' % obj + return temp diff --git a/IPython/config/sconfig.py b/IPython/config/sconfig.py new file mode 100644 index 0000000..886d201 --- /dev/null +++ b/IPython/config/sconfig.py @@ -0,0 +1,622 @@ +# encoding: utf-8 + +"""Mix of ConfigObj and Struct-like access. + +Provides: + +- Coupling a Struct object to a ConfigObj one, so that changes to the Traited + instance propagate back into the ConfigObj. + +- A declarative interface for describing configurations that automatically maps + to valid ConfigObj representations. + +- From these descriptions, valid .conf files can be auto-generated, with class + docstrings and traits information used for initial auto-documentation. + +- Hierarchical inclusion of files, so that a base config can be overridden only + in specific spots. + + +Notes: + +The file creation policy is: + +1. Creating a SConfigManager(FooConfig,'missingfile.conf') will work +fine, and 'missingfile.conf' will be created empty. + +2. Creating SConfigManager(FooConfig,'OKfile.conf') where OKfile.conf has + +include = 'missingfile.conf' + +conks out with IOError. 
+ +My rationale is that creating top-level empty files is a common and +reasonable need, but that having invalid include statements should +raise an error right away, so people know immediately that their files +have gone stale. + + +TODO: + + - Turn the currently interactive tests into proper doc/unit tests. Complete + docstrings. + + - Write the real ipython1 config system using this. That one is more + complicated than either the MPL one or the fake 'ipythontest' that I wrote + here, and it requires solving the issue of declaring references to other + objects inside the config files. + + - [Low priority] Write a custom TraitsUI view so that hierarchical + configurations provide nicer interactive editing. The automatic system is + remarkably good, but for very complex configurations having a nicely + organized view would be nice. +""" + +__docformat__ = "restructuredtext en" +__license__ = 'BSD' + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +############################################################################ +# Stdlib imports +############################################################################ +from cStringIO import StringIO +from inspect import isclass + +import os +import textwrap + +############################################################################ +# External imports +############################################################################ + +from IPython.external import configobj + +############################################################################ +# Utility functions +############################################################################ + +def get_split_ind(seq, N): + """seq is a list of words. Return the index into seq such that + len(' '.join(seq[:ind])<=N + """ + + sLen = 0 + # todo: use Alex's xrange pattern from the cbook for efficiency + for (word, ind) in zip(seq, range(len(seq))): + sLen += len(word) + 1 # +1 to account for the len(' ') + if sLen>=N: return ind + return len(seq) + +def wrap(prefix, text, cols, max_lines=6): + """'wrap text with prefix at length cols""" + pad = ' '*len(prefix.expandtabs()) + available = cols - len(pad) + + seq = text.split(' ') + Nseq = len(seq) + ind = 0 + lines = [] + while ind num_lines-abbr_end-1: + ret += pad + ' '.join(lines[i]) + '\n' + else: + if not lines_skipped: + lines_skipped = True + ret += ' <...snipped %d lines...> \n' % (num_lines-max_lines) +# for line in lines[1:]: +# ret += pad + ' '.join(line) + '\n' + return ret[:-1] + +def dedent(txt): + """A modified version of textwrap.dedent, specialized for docstrings. 
+ + This version doesn't get confused by the first line of text having + inconsistent indentation from the rest, which happens a lot in docstrings. + + :Examples: + + >>> s = ''' + ... First line. + ... More... + ... End''' + + >>> print dedent(s) + First line. + More... + End + + >>> s = '''First line + ... More... + ... End''' + + >>> print dedent(s) + First line + More... + End + """ + out = [textwrap.dedent(t) for t in txt.split('\n',1) + if t and not t.isspace()] + return '\n'.join(out) + + +def comment(strng,indent=''): + """return an input string, commented out""" + template = indent + '# %s' + lines = [template % s for s in strng.splitlines(True)] + return ''.join(lines) + + +def configobj2str(cobj): + """Dump a Configobj instance to a string.""" + outstr = StringIO() + cobj.write(outstr) + return outstr.getvalue() + +def get_config_filename(conf): + """Find the filename attribute of a ConfigObj given a sub-section object. + """ + depth = conf.depth + for d in range(depth): + conf = conf.parent + return conf.filename + +def sconf2File(sconf,fname,force=False): + """Write a SConfig instance to a given filename. + + :Keywords: + + force : bool (False) + If true, force writing even if the file exists. 
+ """ + + if os.path.isfile(fname) and not force: + raise IOError("File %s already exists, use force=True to overwrite" % + fname) + + txt = repr(sconf) + + fobj = open(fname,'w') + fobj.write(txt) + fobj.close() + +def filter_scalars(sc): + """ input sc MUST be sorted!!!""" + scalars = [] + maxi = len(sc)-1 + i = 0 + while i close an open quote for stupid emacs + +#***************************************************************************** +# +# Copyright (c) 2001 Ka-Ping Yee +# +# +# Published under the terms of the MIT license, hereby reproduced: +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. 
+# +#***************************************************************************** + +__author__ = 'Ka-Ping Yee ' +__license__ = 'MIT' + +import string +import sys +from tokenize import tokenprog +from types import StringType + +class ItplError(ValueError): + def __init__(self, text, pos): + self.text = text + self.pos = pos + def __str__(self): + return "unfinished expression in %s at char %d" % ( + repr(self.text), self.pos) + +def matchorfail(text, pos): + match = tokenprog.match(text, pos) + if match is None: + raise ItplError(text, pos) + return match, match.end() + +class Itpl: + """Class representing a string with interpolation abilities. + + Upon creation, an instance works out what parts of the format + string are literal and what parts need to be evaluated. The + evaluation and substitution happens in the namespace of the + caller when str(instance) is called.""" + + def __init__(self, format,codec='utf_8',encoding_errors='backslashreplace'): + """The single mandatory argument to this constructor is a format + string. + + The format string is parsed according to the following rules: + + 1. A dollar sign and a name, possibly followed by any of: + - an open-paren, and anything up to the matching paren + - an open-bracket, and anything up to the matching bracket + - a period and a name + any number of times, is evaluated as a Python expression. + + 2. A dollar sign immediately followed by an open-brace, and + anything up to the matching close-brace, is evaluated as + a Python expression. + + 3. Outside of the expressions described in the above two rules, + two dollar signs in a row give you one literal dollar sign. + + Optional arguments: + + - codec('utf_8'): a string containing the name of a valid Python + codec. + + - encoding_errors('backslashreplace'): a string with a valid error handling + policy. See the codecs module documentation for details. 
+ + These are used to encode the format string if a call to str() fails on + the expanded result.""" + + if not isinstance(format,basestring): + raise TypeError, "needs string initializer" + self.format = format + self.codec = codec + self.encoding_errors = encoding_errors + + namechars = "abcdefghijklmnopqrstuvwxyz" \ + "ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_"; + chunks = [] + pos = 0 + + while 1: + dollar = string.find(format, "$", pos) + if dollar < 0: break + nextchar = format[dollar+1] + + if nextchar == "{": + chunks.append((0, format[pos:dollar])) + pos, level = dollar+2, 1 + while level: + match, pos = matchorfail(format, pos) + tstart, tend = match.regs[3] + token = format[tstart:tend] + if token == "{": level = level+1 + elif token == "}": level = level-1 + chunks.append((1, format[dollar+2:pos-1])) + + elif nextchar in namechars: + chunks.append((0, format[pos:dollar])) + match, pos = matchorfail(format, dollar+1) + while pos < len(format): + if format[pos] == "." and \ + pos+1 < len(format) and format[pos+1] in namechars: + match, pos = matchorfail(format, pos+1) + elif format[pos] in "([": + pos, level = pos+1, 1 + while level: + match, pos = matchorfail(format, pos) + tstart, tend = match.regs[3] + token = format[tstart:tend] + if token[0] in "([": level = level+1 + elif token[0] in ")]": level = level-1 + else: break + chunks.append((1, format[dollar+1:pos])) + + else: + chunks.append((0, format[pos:dollar+1])) + pos = dollar + 1 + (nextchar == "$") + + if pos < len(format): chunks.append((0, format[pos:])) + self.chunks = chunks + + def __repr__(self): + return "" % repr(self.format) + + def _str(self,glob,loc): + """Evaluate to a string in the given globals/locals. 
+ + The final output is built by calling str(), but if this fails, the + result is encoded with the instance's codec and error handling policy, + via a call to out.encode(self.codec,self.encoding_errors)""" + result = [] + app = result.append + for live, chunk in self.chunks: + if live: app(str(eval(chunk,glob,loc))) + else: app(chunk) + out = ''.join(result) + try: + return str(out) + except UnicodeError: + return out.encode(self.codec,self.encoding_errors) + + def __str__(self): + """Evaluate and substitute the appropriate parts of the string.""" + + # We need to skip enough frames to get to the actual caller outside of + # Itpl. + frame = sys._getframe(1) + while frame.f_globals["__name__"] == __name__: frame = frame.f_back + loc, glob = frame.f_locals, frame.f_globals + + return self._str(glob,loc) + +class ItplNS(Itpl): + """Class representing a string with interpolation abilities. + + This inherits from Itpl, but at creation time a namespace is provided + where the evaluation will occur. The interpolation becomes a bit more + efficient, as no traceback needs to be extracte. It also allows the + caller to supply a different namespace for the interpolation to occur than + its own.""" + + def __init__(self, format,globals,locals=None, + codec='utf_8',encoding_errors='backslashreplace'): + """ItplNS(format,globals[,locals]) -> interpolating string instance. + + This constructor, besides a format string, takes a globals dictionary + and optionally a locals (which defaults to globals if not provided). 
+ + For further details, see the Itpl constructor.""" + + if locals is None: + locals = globals + self.globals = globals + self.locals = locals + Itpl.__init__(self,format,codec,encoding_errors) + + def __str__(self): + """Evaluate and substitute the appropriate parts of the string.""" + return self._str(self.globals,self.locals) + + def __repr__(self): + return "" % repr(self.format) + +# utilities for fast printing +def itpl(text): return str(Itpl(text)) +def printpl(text): print itpl(text) +# versions with namespace +def itplns(text,globals,locals=None): return str(ItplNS(text,globals,locals)) +def printplns(text,globals,locals=None): print itplns(text,globals,locals) + +class ItplFile: + """A file object that filters each write() through an interpolator.""" + def __init__(self, file): self.file = file + def __repr__(self): return "" + def __getattr__(self, attr): return getattr(self.file, attr) + def write(self, text): self.file.write(str(Itpl(text))) + +def filter(file=sys.stdout): + """Return an ItplFile that filters writes to the given file object. + + 'file = filter(file)' replaces 'file' with a filtered object that + has a write() method. When called with no argument, this creates + a filter to sys.stdout.""" + return ItplFile(file) + +def unfilter(ifile=None): + """Return the original file that corresponds to the given ItplFile. + + 'file = unfilter(file)' undoes the effect of 'file = filter(file)'. + 'sys.stdout = unfilter()' undoes the effect of 'sys.stdout = filter()'.""" + return ifile and ifile.file or sys.stdout.file diff --git a/IPython/external/configobj.py b/IPython/external/configobj.py new file mode 100644 index 0000000..9e64f18 --- /dev/null +++ b/IPython/external/configobj.py @@ -0,0 +1,2501 @@ +# configobj.py +# A config file reader/writer that supports nested sections in config files. 
+# Copyright (C) 2005-2008 Michael Foord, Nicola Larosa +# E-mail: fuzzyman AT voidspace DOT org DOT uk +# nico AT tekNico DOT net + +# ConfigObj 4 +# http://www.voidspace.org.uk/python/configobj.html + +# Released subject to the BSD License +# Please see http://www.voidspace.org.uk/python/license.shtml + +# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml +# For information about bugfixes, updates and support, please join the +# ConfigObj mailing list: +# http://lists.sourceforge.net/lists/listinfo/configobj-develop +# Comments, suggestions and bug reports welcome. + +from __future__ import generators + +import sys +INTP_VER = sys.version_info[:2] +if INTP_VER < (2, 2): + raise RuntimeError("Python v.2.2 or later needed") + +import os, re +compiler = None +try: + import compiler +except ImportError: + # for IronPython + pass +from types import StringTypes +from warnings import warn +try: + from codecs import BOM_UTF8, BOM_UTF16, BOM_UTF16_BE, BOM_UTF16_LE +except ImportError: + # Python 2.2 does not have these + # UTF-8 + BOM_UTF8 = '\xef\xbb\xbf' + # UTF-16, little endian + BOM_UTF16_LE = '\xff\xfe' + # UTF-16, big endian + BOM_UTF16_BE = '\xfe\xff' + if sys.byteorder == 'little': + # UTF-16, native endianness + BOM_UTF16 = BOM_UTF16_LE + else: + # UTF-16, native endianness + BOM_UTF16 = BOM_UTF16_BE + +# A dictionary mapping BOM to +# the encoding to decode with, and what to set the +# encoding attribute to. +BOMS = { + BOM_UTF8: ('utf_8', None), + BOM_UTF16_BE: ('utf16_be', 'utf_16'), + BOM_UTF16_LE: ('utf16_le', 'utf_16'), + BOM_UTF16: ('utf_16', 'utf_16'), + } +# All legal variants of the BOM codecs. +# TODO: the list of aliases is not meant to be exhaustive, is there a +# better way ? 
+BOM_LIST = { + 'utf_16': 'utf_16', + 'u16': 'utf_16', + 'utf16': 'utf_16', + 'utf-16': 'utf_16', + 'utf16_be': 'utf16_be', + 'utf_16_be': 'utf16_be', + 'utf-16be': 'utf16_be', + 'utf16_le': 'utf16_le', + 'utf_16_le': 'utf16_le', + 'utf-16le': 'utf16_le', + 'utf_8': 'utf_8', + 'u8': 'utf_8', + 'utf': 'utf_8', + 'utf8': 'utf_8', + 'utf-8': 'utf_8', + } + +# Map of encodings to the BOM to write. +BOM_SET = { + 'utf_8': BOM_UTF8, + 'utf_16': BOM_UTF16, + 'utf16_be': BOM_UTF16_BE, + 'utf16_le': BOM_UTF16_LE, + None: BOM_UTF8 + } + + +def match_utf8(encoding): + return BOM_LIST.get(encoding.lower()) == 'utf_8' + + +# Quote strings used for writing values +squot = "'%s'" +dquot = '"%s"' +noquot = "%s" +wspace_plus = ' \r\t\n\v\t\'"' +tsquot = '"""%s"""' +tdquot = "'''%s'''" + +try: + enumerate +except NameError: + def enumerate(obj): + """enumerate for Python 2.2.""" + i = -1 + for item in obj: + i += 1 + yield i, item + +try: + True, False +except NameError: + True, False = 1, 0 + + +__version__ = '4.5.2' + +__revision__ = '$Id: configobj.py 156 2006-01-31 14:57:08Z fuzzyman $' + +__docformat__ = "restructuredtext en" + +__all__ = ( + '__version__', + 'DEFAULT_INDENT_TYPE', + 'DEFAULT_INTERPOLATION', + 'ConfigObjError', + 'NestingError', + 'ParseError', + 'DuplicateError', + 'ConfigspecError', + 'ConfigObj', + 'SimpleVal', + 'InterpolationError', + 'InterpolationLoopError', + 'MissingInterpolationOption', + 'RepeatSectionError', + 'ReloadError', + 'UnreprError', + 'UnknownType', + '__docformat__', + 'flatten_errors', +) + +DEFAULT_INTERPOLATION = 'configparser' +DEFAULT_INDENT_TYPE = ' ' +MAX_INTERPOL_DEPTH = 10 + +OPTION_DEFAULTS = { + 'interpolation': True, + 'raise_errors': False, + 'list_values': True, + 'create_empty': False, + 'file_error': False, + 'configspec': None, + 'stringify': True, + # option may be set to one of ('', ' ', '\t') + 'indent_type': None, + 'encoding': None, + 'default_encoding': None, + 'unrepr': False, + 'write_empty_values': False, +} + + + 
+def getObj(s): + s = "a=" + s + if compiler is None: + raise ImportError('compiler module not available') + p = compiler.parse(s) + return p.getChildren()[1].getChildren()[0].getChildren()[1] + + +class UnknownType(Exception): + pass + + +class Builder(object): + + def build(self, o): + m = getattr(self, 'build_' + o.__class__.__name__, None) + if m is None: + raise UnknownType(o.__class__.__name__) + return m(o) + + def build_List(self, o): + return map(self.build, o.getChildren()) + + def build_Const(self, o): + return o.value + + def build_Dict(self, o): + d = {} + i = iter(map(self.build, o.getChildren())) + for el in i: + d[el] = i.next() + return d + + def build_Tuple(self, o): + return tuple(self.build_List(o)) + + def build_Name(self, o): + if o.name == 'None': + return None + if o.name == 'True': + return True + if o.name == 'False': + return False + + # An undefined Name + raise UnknownType('Undefined Name') + + def build_Add(self, o): + real, imag = map(self.build_Const, o.getChildren()) + try: + real = float(real) + except TypeError: + raise UnknownType('Add') + if not isinstance(imag, complex) or imag.real != 0.0: + raise UnknownType('Add') + return real+imag + + def build_Getattr(self, o): + parent = self.build(o.expr) + return getattr(parent, o.attrname) + + def build_UnarySub(self, o): + return -self.build_Const(o.getChildren()[0]) + + def build_UnaryAdd(self, o): + return self.build_Const(o.getChildren()[0]) + + +_builder = Builder() + + +def unrepr(s): + if not s: + return s + return _builder.build(getObj(s)) + + + +class ConfigObjError(SyntaxError): + """ + This is the base class for all errors that ConfigObj raises. + It is a subclass of SyntaxError. + """ + def __init__(self, message='', line_number=None, line=''): + self.line = line + self.line_number = line_number + self.message = message + SyntaxError.__init__(self, message) + + +class NestingError(ConfigObjError): + """ + This error indicates a level of nesting that doesn't match. 
+ """ + + +class ParseError(ConfigObjError): + """ + This error indicates that a line is badly written. + It is neither a valid ``key = value`` line, + nor a valid section marker line. + """ + + +class ReloadError(IOError): + """ + A 'reload' operation failed. + This exception is a subclass of ``IOError``. + """ + def __init__(self): + IOError.__init__(self, 'reload failed, filename is not set.') + + +class DuplicateError(ConfigObjError): + """ + The keyword or section specified already exists. + """ + + +class ConfigspecError(ConfigObjError): + """ + An error occured whilst parsing a configspec. + """ + + +class InterpolationError(ConfigObjError): + """Base class for the two interpolation errors.""" + + +class InterpolationLoopError(InterpolationError): + """Maximum interpolation depth exceeded in string interpolation.""" + + def __init__(self, option): + InterpolationError.__init__( + self, + 'interpolation loop detected in value "%s".' % option) + + +class RepeatSectionError(ConfigObjError): + """ + This error indicates additional sections in a section with a + ``__many__`` (repeated) section. + """ + + +class MissingInterpolationOption(InterpolationError): + """A value specified for interpolation was missing.""" + + def __init__(self, option): + InterpolationError.__init__( + self, + 'missing option "%s" in interpolation.' % option) + + +class UnreprError(ConfigObjError): + """An error parsing in unrepr mode.""" + + + +class InterpolationEngine(object): + """ + A helper class to help perform string interpolation. + + This class is an abstract base class; its descendants perform + the actual work. + """ + + # compiled regexp to use in self.interpolate() + _KEYCRE = re.compile(r"%\(([^)]*)\)s") + + def __init__(self, section): + # the Section instance that "owns" this engine + self.section = section + + + def interpolate(self, key, value): + def recursive_interpolate(key, value, section, backtrail): + """The function that does the actual work. 
+ + ``value``: the string we're trying to interpolate. + ``section``: the section in which that string was found + ``backtrail``: a dict to keep track of where we've been, + to detect and prevent infinite recursion loops + + This is similar to a depth-first-search algorithm. + """ + # Have we been here already? + if backtrail.has_key((key, section.name)): + # Yes - infinite loop detected + raise InterpolationLoopError(key) + # Place a marker on our backtrail so we won't come back here again + backtrail[(key, section.name)] = 1 + + # Now start the actual work + match = self._KEYCRE.search(value) + while match: + # The actual parsing of the match is implementation-dependent, + # so delegate to our helper function + k, v, s = self._parse_match(match) + if k is None: + # That's the signal that no further interpolation is needed + replacement = v + else: + # Further interpolation may be needed to obtain final value + replacement = recursive_interpolate(k, v, s, backtrail) + # Replace the matched string with its final value + start, end = match.span() + value = ''.join((value[:start], replacement, value[end:])) + new_search_start = start + len(replacement) + # Pick up the next interpolation key, if any, for next time + # through the while loop + match = self._KEYCRE.search(value, new_search_start) + + # Now safe to come back here again; remove marker from backtrail + del backtrail[(key, section.name)] + + return value + + # Back in interpolate(), all we have to do is kick off the recursive + # function with appropriate starting values + value = recursive_interpolate(key, value, self.section, {}) + return value + + + def _fetch(self, key): + """Helper function to fetch values from owning section. + + Returns a 2-tuple: the value, and the section where it was found. + """ + # switch off interpolation before we try and fetch anything ! 
+ save_interp = self.section.main.interpolation + self.section.main.interpolation = False + + # Start at section that "owns" this InterpolationEngine + current_section = self.section + while True: + # try the current section first + val = current_section.get(key) + if val is not None: + break + # try "DEFAULT" next + val = current_section.get('DEFAULT', {}).get(key) + if val is not None: + break + # move up to parent and try again + # top-level's parent is itself + if current_section.parent is current_section: + # reached top level, time to give up + break + current_section = current_section.parent + + # restore interpolation to previous value before returning + self.section.main.interpolation = save_interp + if val is None: + raise MissingInterpolationOption(key) + return val, current_section + + + def _parse_match(self, match): + """Implementation-dependent helper function. + + Will be passed a match object corresponding to the interpolation + key we just found (e.g., "%(foo)s" or "$foo"). Should look up that + key in the appropriate config file section (using the ``_fetch()`` + helper function) and return a 3-tuple: (key, value, section) + + ``key`` is the name of the key we're looking for + ``value`` is the value found for that key + ``section`` is a reference to the section where it was found + + ``key`` and ``section`` should be None if no further + interpolation should be performed on the resulting value + (e.g., if we interpolated "$$" and returned "$"). 
+ """ + raise NotImplementedError() + + + +class ConfigParserInterpolation(InterpolationEngine): + """Behaves like ConfigParser.""" + _KEYCRE = re.compile(r"%\(([^)]*)\)s") + + def _parse_match(self, match): + key = match.group(1) + value, section = self._fetch(key) + return key, value, section + + + +class TemplateInterpolation(InterpolationEngine): + """Behaves like string.Template.""" + _delimiter = '$' + _KEYCRE = re.compile(r""" + \$(?: + (?P\$) | # Two $ signs + (?P[_a-z][_a-z0-9]*) | # $name format + {(?P[^}]*)} # ${name} format + ) + """, re.IGNORECASE | re.VERBOSE) + + def _parse_match(self, match): + # Valid name (in or out of braces): fetch value from section + key = match.group('named') or match.group('braced') + if key is not None: + value, section = self._fetch(key) + return key, value, section + # Escaped delimiter (e.g., $$): return single delimiter + if match.group('escaped') is not None: + # Return None for key and section to indicate it's time to stop + return None, self._delimiter, None + # Anything else: ignore completely, just return it unchanged + return None, match.group(), None + + +interpolation_engines = { + 'configparser': ConfigParserInterpolation, + 'template': TemplateInterpolation, +} + + + +class Section(dict): + """ + A dictionary-like object that represents a section in a config file. + + It does string interpolation if the 'interpolation' attribute + of the 'main' object is set to True. + + Interpolation is tried first from this object, then from the 'DEFAULT' + section of this object, next from the parent and its 'DEFAULT' section, + and so on until the main object is reached. + + A Section will behave like an ordered dictionary - following the + order of the ``scalars`` and ``sections`` attributes. + You can use this to change the order of members. + + Iteration follows the order: scalars, then sections. 
+ """ + + def __init__(self, parent, depth, main, indict=None, name=None): + """ + * parent is the section above + * depth is the depth level of this section + * main is the main ConfigObj + * indict is a dictionary to initialise the section with + """ + if indict is None: + indict = {} + dict.__init__(self) + # used for nesting level *and* interpolation + self.parent = parent + # used for the interpolation attribute + self.main = main + # level of nesting depth of this Section + self.depth = depth + # purely for information + self.name = name + # + self._initialise() + # we do this explicitly so that __setitem__ is used properly + # (rather than just passing to ``dict.__init__``) + for entry, value in indict.iteritems(): + self[entry] = value + + + def _initialise(self): + # the sequence of scalar values in this Section + self.scalars = [] + # the sequence of sections in this Section + self.sections = [] + # for comments :-) + self.comments = {} + self.inline_comments = {} + # for the configspec + self.configspec = {} + self._order = [] + self._configspec_comments = {} + self._configspec_inline_comments = {} + self._cs_section_comments = {} + self._cs_section_inline_comments = {} + # for defaults + self.defaults = [] + self.default_values = {} + + + def _interpolate(self, key, value): + try: + # do we already have an interpolation engine? + engine = self._interpolation_engine + except AttributeError: + # not yet: first time running _interpolate(), so pick the engine + name = self.main.interpolation + if name == True: # note that "if name:" would be incorrect here + # backwards-compatibility: interpolation=True means use default + name = DEFAULT_INTERPOLATION + name = name.lower() # so that "Template", "template", etc. 
all work + class_ = interpolation_engines.get(name, None) + if class_ is None: + # invalid value for self.main.interpolation + self.main.interpolation = False + return value + else: + # save reference to engine so we don't have to do this again + engine = self._interpolation_engine = class_(self) + # let the engine do the actual work + return engine.interpolate(key, value) + + + def __getitem__(self, key): + """Fetch the item and do string interpolation.""" + val = dict.__getitem__(self, key) + if self.main.interpolation and isinstance(val, StringTypes): + return self._interpolate(key, val) + return val + + + def __setitem__(self, key, value, unrepr=False): + """ + Correctly set a value. + + Making dictionary values Section instances. + (We have to special case 'Section' instances - which are also dicts) + + Keys must be strings. + Values need only be strings (or lists of strings) if + ``main.stringify`` is set. + + `unrepr`` must be set when setting a value to a dictionary, without + creating a new sub-section. + """ + if not isinstance(key, StringTypes): + raise ValueError('The key "%s" is not a string.' 
% key) + + # add the comment + if not self.comments.has_key(key): + self.comments[key] = [] + self.inline_comments[key] = '' + # remove the entry from defaults + if key in self.defaults: + self.defaults.remove(key) + # + if isinstance(value, Section): + if not self.has_key(key): + self.sections.append(key) + dict.__setitem__(self, key, value) + elif isinstance(value, dict) and not unrepr: + # First create the new depth level, + # then create the section + if not self.has_key(key): + self.sections.append(key) + new_depth = self.depth + 1 + dict.__setitem__( + self, + key, + Section( + self, + new_depth, + self.main, + indict=value, + name=key)) + else: + if not self.has_key(key): + self.scalars.append(key) + if not self.main.stringify: + if isinstance(value, StringTypes): + pass + elif isinstance(value, (list, tuple)): + for entry in value: + if not isinstance(entry, StringTypes): + raise TypeError('Value is not a string "%s".' % entry) + else: + raise TypeError('Value is not a string "%s".' % value) + dict.__setitem__(self, key, value) + + + def __delitem__(self, key): + """Remove items from the sequence when deleting.""" + dict. __delitem__(self, key) + if key in self.scalars: + self.scalars.remove(key) + else: + self.sections.remove(key) + del self.comments[key] + del self.inline_comments[key] + + + def get(self, key, default=None): + """A version of ``get`` that doesn't bypass string interpolation.""" + try: + return self[key] + except KeyError: + return default + + + def update(self, indict): + """ + A version of update that uses our ``__setitem__``. + """ + for entry in indict: + self[entry] = indict[entry] + + + def pop(self, key, *args): + """ + 'D.pop(k[,d]) -> v, remove specified key and return the corresponding value. 
+ If key is not found, d is returned if given, otherwise KeyError is raised' + """ + val = dict.pop(self, key, *args) + if key in self.scalars: + del self.comments[key] + del self.inline_comments[key] + self.scalars.remove(key) + elif key in self.sections: + del self.comments[key] + del self.inline_comments[key] + self.sections.remove(key) + if self.main.interpolation and isinstance(val, StringTypes): + return self._interpolate(key, val) + return val + + + def popitem(self): + """Pops the first (key,val)""" + sequence = (self.scalars + self.sections) + if not sequence: + raise KeyError(": 'popitem(): dictionary is empty'") + key = sequence[0] + val = self[key] + del self[key] + return key, val + + + def clear(self): + """ + A version of clear that also affects scalars/sections + Also clears comments and configspec. + + Leaves other attributes alone : + depth/main/parent are not affected + """ + dict.clear(self) + self.scalars = [] + self.sections = [] + self.comments = {} + self.inline_comments = {} + self.configspec = {} + + + def setdefault(self, key, default=None): + """A version of setdefault that sets sequence if appropriate.""" + try: + return self[key] + except KeyError: + self[key] = default + return self[key] + + + def items(self): + """D.items() -> list of D's (key, value) pairs, as 2-tuples""" + return zip((self.scalars + self.sections), self.values()) + + + def keys(self): + """D.keys() -> list of D's keys""" + return (self.scalars + self.sections) + + + def values(self): + """D.values() -> list of D's values""" + return [self[key] for key in (self.scalars + self.sections)] + + + def iteritems(self): + """D.iteritems() -> an iterator over the (key, value) items of D""" + return iter(self.items()) + + + def iterkeys(self): + """D.iterkeys() -> an iterator over the keys of D""" + return iter((self.scalars + self.sections)) + + __iter__ = iterkeys + + + def itervalues(self): + """D.itervalues() -> an iterator over the values of D""" + return 
iter(self.values()) + + + def __repr__(self): + """x.__repr__() <==> repr(x)""" + return '{%s}' % ', '.join([('%s: %s' % (repr(key), repr(self[key]))) + for key in (self.scalars + self.sections)]) + + __str__ = __repr__ + __str__.__doc__ = "x.__str__() <==> str(x)" + + + # Extra methods - not in a normal dictionary + + def dict(self): + """ + Return a deepcopy of self as a dictionary. + + All members that are ``Section`` instances are recursively turned to + ordinary dictionaries - by calling their ``dict`` method. + + >>> n = a.dict() + >>> n == a + 1 + >>> n is a + 0 + """ + newdict = {} + for entry in self: + this_entry = self[entry] + if isinstance(this_entry, Section): + this_entry = this_entry.dict() + elif isinstance(this_entry, list): + # create a copy rather than a reference + this_entry = list(this_entry) + elif isinstance(this_entry, tuple): + # create a copy rather than a reference + this_entry = tuple(this_entry) + newdict[entry] = this_entry + return newdict + + + def merge(self, indict): + """ + A recursive update - useful for merging config files. + + >>> a = '''[section1] + ... option1 = True + ... [[subsection]] + ... more_options = False + ... # end of file'''.splitlines() + >>> b = '''# File is user.ini + ... [section1] + ... option1 = False + ... # end of file'''.splitlines() + >>> c1 = ConfigObj(b) + >>> c2 = ConfigObj(a) + >>> c2.merge(c1) + >>> c2 + {'section1': {'option1': 'False', 'subsection': {'more_options': 'False'}}} + """ + for key, val in indict.items(): + if (key in self and isinstance(self[key], dict) and + isinstance(val, dict)): + self[key].merge(val) + else: + self[key] = val + + + def rename(self, oldkey, newkey): + """ + Change a keyname to another, without changing position in sequence. + + Implemented so that transformations can be made on keys, + as well as on values. (used by encode and decode) + + Also renames comments. 
+ """ + if oldkey in self.scalars: + the_list = self.scalars + elif oldkey in self.sections: + the_list = self.sections + else: + raise KeyError('Key "%s" not found.' % oldkey) + pos = the_list.index(oldkey) + # + val = self[oldkey] + dict.__delitem__(self, oldkey) + dict.__setitem__(self, newkey, val) + the_list.remove(oldkey) + the_list.insert(pos, newkey) + comm = self.comments[oldkey] + inline_comment = self.inline_comments[oldkey] + del self.comments[oldkey] + del self.inline_comments[oldkey] + self.comments[newkey] = comm + self.inline_comments[newkey] = inline_comment + + + def walk(self, function, raise_errors=True, + call_on_sections=False, **keywargs): + """ + Walk every member and call a function on the keyword and value. + + Return a dictionary of the return values + + If the function raises an exception, raise the errror + unless ``raise_errors=False``, in which case set the return value to + ``False``. + + Any unrecognised keyword arguments you pass to walk, will be pased on + to the function you pass in. + + Note: if ``call_on_sections`` is ``True`` then - on encountering a + subsection, *first* the function is called for the *whole* subsection, + and then recurses into it's members. This means your function must be + able to handle strings, dictionaries and lists. This allows you + to change the key of subsections as well as for ordinary members. The + return value when called on the whole subsection has to be discarded. + + See the encode and decode methods for examples, including functions. + + .. caution:: + + You can use ``walk`` to transform the names of members of a section + but you mustn't add or delete members. + + >>> config = '''[XXXXsection] + ... XXXXkey = XXXXvalue'''.splitlines() + >>> cfg = ConfigObj(config) + >>> cfg + {'XXXXsection': {'XXXXkey': 'XXXXvalue'}} + >>> def transform(section, key): + ... val = section[key] + ... newkey = key.replace('XXXX', 'CLIENT1') + ... section.rename(key, newkey) + ... 
if isinstance(val, (tuple, list, dict)): + ... pass + ... else: + ... val = val.replace('XXXX', 'CLIENT1') + ... section[newkey] = val + >>> cfg.walk(transform, call_on_sections=True) + {'CLIENT1section': {'CLIENT1key': None}} + >>> cfg + {'CLIENT1section': {'CLIENT1key': 'CLIENT1value'}} + """ + out = {} + # scalars first + for i in range(len(self.scalars)): + entry = self.scalars[i] + try: + val = function(self, entry, **keywargs) + # bound again in case name has changed + entry = self.scalars[i] + out[entry] = val + except Exception: + if raise_errors: + raise + else: + entry = self.scalars[i] + out[entry] = False + # then sections + for i in range(len(self.sections)): + entry = self.sections[i] + if call_on_sections: + try: + function(self, entry, **keywargs) + except Exception: + if raise_errors: + raise + else: + entry = self.sections[i] + out[entry] = False + # bound again in case name has changed + entry = self.sections[i] + # previous result is discarded + out[entry] = self[entry].walk( + function, + raise_errors=raise_errors, + call_on_sections=call_on_sections, + **keywargs) + return out + + + def decode(self, encoding): + """ + Decode all strings and values to unicode, using the specified encoding. + + Works with subsections and list values. + + Uses the ``walk`` method. + + Testing ``encode`` and ``decode``. + >>> m = ConfigObj(a) + >>> m.decode('ascii') + >>> def testuni(val): + ... for entry in val: + ... if not isinstance(entry, unicode): + ... print >> sys.stderr, type(entry) + ... raise AssertionError, 'decode failed.' + ... if isinstance(val[entry], dict): + ... testuni(val[entry]) + ... elif not isinstance(val[entry], unicode): + ... raise AssertionError, 'decode failed.' 
+ >>> testuni(m) + >>> m.encode('ascii') + >>> a == m + 1 + """ + warn('use of ``decode`` is deprecated.', DeprecationWarning) + def decode(section, key, encoding=encoding, warn=True): + """ """ + val = section[key] + if isinstance(val, (list, tuple)): + newval = [] + for entry in val: + newval.append(entry.decode(encoding)) + elif isinstance(val, dict): + newval = val + else: + newval = val.decode(encoding) + newkey = key.decode(encoding) + section.rename(key, newkey) + section[newkey] = newval + # using ``call_on_sections`` allows us to modify section names + self.walk(decode, call_on_sections=True) + + + def encode(self, encoding): + """ + Encode all strings and values from unicode, + using the specified encoding. + + Works with subsections and list values. + Uses the ``walk`` method. + """ + warn('use of ``encode`` is deprecated.', DeprecationWarning) + def encode(section, key, encoding=encoding): + """ """ + val = section[key] + if isinstance(val, (list, tuple)): + newval = [] + for entry in val: + newval.append(entry.encode(encoding)) + elif isinstance(val, dict): + newval = val + else: + newval = val.encode(encoding) + newkey = key.encode(encoding) + section.rename(key, newkey) + section[newkey] = newval + self.walk(encode, call_on_sections=True) + + + def istrue(self, key): + """A deprecated version of ``as_bool``.""" + warn('use of ``istrue`` is deprecated. Use ``as_bool`` method ' + 'instead.', DeprecationWarning) + return self.as_bool(key) + + + def as_bool(self, key): + """ + Accepts a key as input. The corresponding value must be a string or + the objects (``True`` or 1) or (``False`` or 0). We allow 0 and 1 to + retain compatibility with Python 2.2. + + If the string is one of ``True``, ``On``, ``Yes``, or ``1`` it returns + ``True``. + + If the string is one of ``False``, ``Off``, ``No``, or ``0`` it returns + ``False``. + + ``as_bool`` is not case sensitive. + + Any other input will raise a ``ValueError``. 
+ + >>> a = ConfigObj() + >>> a['a'] = 'fish' + >>> a.as_bool('a') + Traceback (most recent call last): + ValueError: Value "fish" is neither True nor False + >>> a['b'] = 'True' + >>> a.as_bool('b') + 1 + >>> a['b'] = 'off' + >>> a.as_bool('b') + 0 + """ + val = self[key] + if val == True: + return True + elif val == False: + return False + else: + try: + if not isinstance(val, StringTypes): + # TODO: Why do we raise a KeyError here? + raise KeyError() + else: + return self.main._bools[val.lower()] + except KeyError: + raise ValueError('Value "%s" is neither True nor False' % val) + + + def as_int(self, key): + """ + A convenience method which coerces the specified value to an integer. + + If the value is an invalid literal for ``int``, a ``ValueError`` will + be raised. + + >>> a = ConfigObj() + >>> a['a'] = 'fish' + >>> a.as_int('a') + Traceback (most recent call last): + ValueError: invalid literal for int(): fish + >>> a['b'] = '1' + >>> a.as_int('b') + 1 + >>> a['b'] = '3.2' + >>> a.as_int('b') + Traceback (most recent call last): + ValueError: invalid literal for int(): 3.2 + """ + return int(self[key]) + + + def as_float(self, key): + """ + A convenience method which coerces the specified value to a float. + + If the value is an invalid literal for ``float``, a ``ValueError`` will + be raised. + + >>> a = ConfigObj() + >>> a['a'] = 'fish' + >>> a.as_float('a') + Traceback (most recent call last): + ValueError: invalid literal for float(): fish + >>> a['b'] = '1' + >>> a.as_float('b') + 1.0 + >>> a['b'] = '3.2' + >>> a.as_float('b') + 3.2000000000000002 + """ + return float(self[key]) + + + def restore_default(self, key): + """ + Restore (and return) default value for the specified key. + + This method will only work for a ConfigObj that was created + with a configspec and has been validated. + + If there is no default value for this key, ``KeyError`` is raised. 
+ """ + default = self.default_values[key] + dict.__setitem__(self, key, default) + if key not in self.defaults: + self.defaults.append(key) + return default + + + def restore_defaults(self): + """ + Recursively restore default values to all members + that have them. + + This method will only work for a ConfigObj that was created + with a configspec and has been validated. + + It doesn't delete or modify entries without default values. + """ + for key in self.default_values: + self.restore_default(key) + + for section in self.sections: + self[section].restore_defaults() + + +class ConfigObj(Section): + """An object to read, create, and write config files.""" + + _keyword = re.compile(r'''^ # line start + (\s*) # indentation + ( # keyword + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'"=].*?) # no quotes + ) + \s*=\s* # divider + (.*) # value (including list values and comments) + $ # line end + ''', + re.VERBOSE) + + _sectionmarker = re.compile(r'''^ + (\s*) # 1: indentation + ((?:\[\s*)+) # 2: section marker open + ( # 3: section name open + (?:"\s*\S.*?\s*")| # at least one non-space with double quotes + (?:'\s*\S.*?\s*')| # at least one non-space with single quotes + (?:[^'"\s].*?) # at least one non-space unquoted + ) # section name close + ((?:\s*\])+) # 4: section marker close + \s*(\#.*)? # 5: optional comment + $''', + re.VERBOSE) + + # this regexp pulls list values out as a single string + # or single values and comments + # FIXME: this regex adds a '' to the end of comma terminated lists + # workaround in ``_handle_value`` + _valueexp = re.compile(r'''^ + (?: + (?: + ( + (?: + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\#][^,\#]*?) # unquoted + ) + \s*,\s* # comma + )* # match all list items ending in a comma (if any) + ) + ( + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\#\s][^,]*?)| # unquoted + (?:(? 
1: + msg = "Parsing failed with several errors.\nFirst error %s" % info + error = ConfigObjError(msg) + else: + error = self._errors[0] + # set the errors attribute; it's a list of tuples: + # (error_type, message, line_number) + error.errors = self._errors + # set the config attribute + error.config = self + raise error + # delete private attributes + del self._errors + + if configspec is None: + self.configspec = None + else: + self._handle_configspec(configspec) + + + def _initialise(self, options=None): + if options is None: + options = OPTION_DEFAULTS + + # initialise a few variables + self.filename = None + self._errors = [] + self.raise_errors = options['raise_errors'] + self.interpolation = options['interpolation'] + self.list_values = options['list_values'] + self.create_empty = options['create_empty'] + self.file_error = options['file_error'] + self.stringify = options['stringify'] + self.indent_type = options['indent_type'] + self.encoding = options['encoding'] + self.default_encoding = options['default_encoding'] + self.BOM = False + self.newlines = None + self.write_empty_values = options['write_empty_values'] + self.unrepr = options['unrepr'] + + self.initial_comment = [] + self.final_comment = [] + self.configspec = {} + + # Clear section attributes as well + Section._initialise(self) + + + def __repr__(self): + return ('ConfigObj({%s})' % + ', '.join([('%s: %s' % (repr(key), repr(self[key]))) + for key in (self.scalars + self.sections)])) + + + def _handle_bom(self, infile): + """ + Handle any BOM, and decode if necessary. + + If an encoding is specified, that *must* be used - but the BOM should + still be removed (and the BOM attribute set). + + (If the encoding is wrongly specified, then a BOM for an alternative + encoding won't be discovered or removed.) + + If an encoding is not specified, UTF8 or UTF16 BOM will be detected and + removed. The BOM attribute will be set. UTF16 will be decoded to + unicode. 
+ + NOTE: This method must not be called with an empty ``infile``. + + Specifying the *wrong* encoding is likely to cause a + ``UnicodeDecodeError``. + + ``infile`` must always be returned as a list of lines, but may be + passed in as a single string. + """ + if ((self.encoding is not None) and + (self.encoding.lower() not in BOM_LIST)): + # No need to check for a BOM + # the encoding specified doesn't have one + # just decode + return self._decode(infile, self.encoding) + + if isinstance(infile, (list, tuple)): + line = infile[0] + else: + line = infile + if self.encoding is not None: + # encoding explicitly supplied + # And it could have an associated BOM + # TODO: if encoding is just UTF16 - we ought to check for both + # TODO: big endian and little endian versions. + enc = BOM_LIST[self.encoding.lower()] + if enc == 'utf_16': + # For UTF16 we try big endian and little endian + for BOM, (encoding, final_encoding) in BOMS.items(): + if not final_encoding: + # skip UTF8 + continue + if infile.startswith(BOM): + ### BOM discovered + ##self.BOM = True + # Don't need to remove BOM + return self._decode(infile, encoding) + + # If we get this far, will *probably* raise a DecodeError + # As it doesn't appear to start with a BOM + return self._decode(infile, self.encoding) + + # Must be UTF8 + BOM = BOM_SET[enc] + if not line.startswith(BOM): + return self._decode(infile, self.encoding) + + newline = line[len(BOM):] + + # BOM removed + if isinstance(infile, (list, tuple)): + infile[0] = newline + else: + infile = newline + self.BOM = True + return self._decode(infile, self.encoding) + + # No encoding specified - so we need to check for UTF8/UTF16 + for BOM, (encoding, final_encoding) in BOMS.items(): + if not line.startswith(BOM): + continue + else: + # BOM discovered + self.encoding = final_encoding + if not final_encoding: + self.BOM = True + # UTF8 + # remove BOM + newline = line[len(BOM):] + if isinstance(infile, (list, tuple)): + infile[0] = newline + else: + infile 
= newline + # UTF8 - don't decode + if isinstance(infile, StringTypes): + return infile.splitlines(True) + else: + return infile + # UTF16 - have to decode + return self._decode(infile, encoding) + + # No BOM discovered and no encoding specified, just return + if isinstance(infile, StringTypes): + # infile read from a file will be a single string + return infile.splitlines(True) + return infile + + + def _a_to_u(self, aString): + """Decode ASCII strings to unicode if a self.encoding is specified.""" + if self.encoding: + return aString.decode('ascii') + else: + return aString + + + def _decode(self, infile, encoding): + """ + Decode infile to unicode. Using the specified encoding. + + if is a string, it also needs converting to a list. + """ + if isinstance(infile, StringTypes): + # can't be unicode + # NOTE: Could raise a ``UnicodeDecodeError`` + return infile.decode(encoding).splitlines(True) + for i, line in enumerate(infile): + if not isinstance(line, unicode): + # NOTE: The isinstance test here handles mixed lists of unicode/string + # NOTE: But the decode will break on any non-string values + # NOTE: Or could raise a ``UnicodeDecodeError`` + infile[i] = line.decode(encoding) + return infile + + + def _decode_element(self, line): + """Decode element to unicode if necessary.""" + if not self.encoding: + return line + if isinstance(line, str) and self.default_encoding: + return line.decode(self.default_encoding) + return line + + + def _str(self, value): + """ + Used by ``stringify`` within validate, to turn non-string values + into strings. 
+ """ + if not isinstance(value, StringTypes): + return str(value) + else: + return value + + + def _parse(self, infile): + """Actually parse the config file.""" + temp_list_values = self.list_values + if self.unrepr: + self.list_values = False + + comment_list = [] + done_start = False + this_section = self + maxline = len(infile) - 1 + cur_index = -1 + reset_comment = False + + while cur_index < maxline: + if reset_comment: + comment_list = [] + cur_index += 1 + line = infile[cur_index] + sline = line.strip() + # do we have anything on the line ? + if not sline or sline.startswith('#'): + reset_comment = False + comment_list.append(line) + continue + + if not done_start: + # preserve initial comment + self.initial_comment = comment_list + comment_list = [] + done_start = True + + reset_comment = True + # first we check if it's a section marker + mat = self._sectionmarker.match(line) + if mat is not None: + # is a section line + (indent, sect_open, sect_name, sect_close, comment) = mat.groups() + if indent and (self.indent_type is None): + self.indent_type = indent + cur_depth = sect_open.count('[') + if cur_depth != sect_close.count(']'): + self._handle_error("Cannot compute the section depth at line %s.", + NestingError, infile, cur_index) + continue + + if cur_depth < this_section.depth: + # the new section is dropping back to a previous level + try: + parent = self._match_depth(this_section, + cur_depth).parent + except SyntaxError: + self._handle_error("Cannot compute nesting level at line %s.", + NestingError, infile, cur_index) + continue + elif cur_depth == this_section.depth: + # the new section is a sibling of the current section + parent = this_section.parent + elif cur_depth == this_section.depth + 1: + # the new section is a child the current section + parent = this_section + else: + self._handle_error("Section too nested at line %s.", + NestingError, infile, cur_index) + + sect_name = self._unquote(sect_name) + if parent.has_key(sect_name): + 
self._handle_error('Duplicate section name at line %s.', + DuplicateError, infile, cur_index) + continue + + # create the new section + this_section = Section( + parent, + cur_depth, + self, + name=sect_name) + parent[sect_name] = this_section + parent.inline_comments[sect_name] = comment + parent.comments[sect_name] = comment_list + continue + # + # it's not a section marker, + # so it should be a valid ``key = value`` line + mat = self._keyword.match(line) + if mat is None: + # it neither matched as a keyword + # or a section marker + self._handle_error( + 'Invalid line at line "%s".', + ParseError, infile, cur_index) + else: + # is a keyword value + # value will include any inline comment + (indent, key, value) = mat.groups() + if indent and (self.indent_type is None): + self.indent_type = indent + # check for a multiline value + if value[:3] in ['"""', "'''"]: + try: + (value, comment, cur_index) = self._multiline( + value, infile, cur_index, maxline) + except SyntaxError: + self._handle_error( + 'Parse error in value at line %s.', + ParseError, infile, cur_index) + continue + else: + if self.unrepr: + comment = '' + try: + value = unrepr(value) + except Exception, e: + if type(e) == UnknownType: + msg = 'Unknown name or type in value at line %s.' + else: + msg = 'Parse error in value at line %s.' + self._handle_error(msg, UnreprError, infile, + cur_index) + continue + else: + if self.unrepr: + comment = '' + try: + value = unrepr(value) + except Exception, e: + if isinstance(e, UnknownType): + msg = 'Unknown name or type in value at line %s.' + else: + msg = 'Parse error in value at line %s.' 
+ self._handle_error(msg, UnreprError, infile, + cur_index) + continue + else: + # extract comment and lists + try: + (value, comment) = self._handle_value(value) + except SyntaxError: + self._handle_error( + 'Parse error in value at line %s.', + ParseError, infile, cur_index) + continue + # + key = self._unquote(key) + if this_section.has_key(key): + self._handle_error( + 'Duplicate keyword name at line %s.', + DuplicateError, infile, cur_index) + continue + # add the key. + # we set unrepr because if we have got this far we will never + # be creating a new section + this_section.__setitem__(key, value, unrepr=True) + this_section.inline_comments[key] = comment + this_section.comments[key] = comment_list + continue + # + if self.indent_type is None: + # no indentation used, set the type accordingly + self.indent_type = '' + + # preserve the final comment + if not self and not self.initial_comment: + self.initial_comment = comment_list + elif not reset_comment: + self.final_comment = comment_list + self.list_values = temp_list_values + + + def _match_depth(self, sect, depth): + """ + Given a section and a depth level, walk back through the sections + parents to see if the depth level matches a previous section. + + Return a reference to the right section, + or raise a SyntaxError. + """ + while depth < sect.depth: + if sect is sect.parent: + # we've reached the top level already + raise SyntaxError() + sect = sect.parent + if sect.depth == depth: + return sect + # shouldn't get here + raise SyntaxError() + + + def _handle_error(self, text, ErrorClass, infile, cur_index): + """ + Handle an error according to the error settings. + + Either raise the error or store it. 
+ The error will have occured at ``cur_index`` + """ + line = infile[cur_index] + cur_index += 1 + message = text % cur_index + error = ErrorClass(message, cur_index, line) + if self.raise_errors: + # raise the error - parsing stops here + raise error + # store the error + # reraise when parsing has finished + self._errors.append(error) + + + def _unquote(self, value): + """Return an unquoted version of a value""" + if (value[0] == value[-1]) and (value[0] in ('"', "'")): + value = value[1:-1] + return value + + + def _quote(self, value, multiline=True): + """ + Return a safely quoted version of a value. + + Raise a ConfigObjError if the value cannot be safely quoted. + If multiline is ``True`` (default) then use triple quotes + if necessary. + + Don't quote values that don't need it. + Recursively quote members of a list and return a comma joined list. + Multiline is ``False`` for lists. + Obey list syntax for empty and single member lists. + + If ``list_values=False`` then the value is only quoted if it contains + a ``\n`` (is multiline) or '#'. + + If ``write_empty_values`` is set, and the value is an empty string, it + won't be quoted. + """ + if multiline and self.write_empty_values and value == '': + # Only if multiline is set, so that it is used for values not + # keys, and not values that are part of a list + return '' + + if multiline and isinstance(value, (list, tuple)): + if not value: + return ',' + elif len(value) == 1: + return self._quote(value[0], multiline=False) + ',' + return ', '.join([self._quote(val, multiline=False) + for val in value]) + if not isinstance(value, StringTypes): + if self.stringify: + value = str(value) + else: + raise TypeError('Value "%s" is not a string.' 
% value) + + if not value: + return '""' + + no_lists_no_quotes = not self.list_values and '\n' not in value and '#' not in value + need_triple = multiline and ((("'" in value) and ('"' in value)) or ('\n' in value )) + hash_triple_quote = multiline and not need_triple and ("'" in value) and ('"' in value) and ('#' in value) + check_for_single = (no_lists_no_quotes or not need_triple) and not hash_triple_quote + + if check_for_single: + if not self.list_values: + # we don't quote if ``list_values=False`` + quot = noquot + # for normal values either single or double quotes will do + elif '\n' in value: + # will only happen if multiline is off - e.g. '\n' in key + raise ConfigObjError('Value "%s" cannot be safely quoted.' % value) + elif ((value[0] not in wspace_plus) and + (value[-1] not in wspace_plus) and + (',' not in value)): + quot = noquot + else: + quot = self._get_single_quote(value) + else: + # if value has '\n' or "'" *and* '"', it will need triple quotes + quot = self._get_triple_quote(value) + + if quot == noquot and '#' in value and self.list_values: + quot = self._get_single_quote(value) + + return quot % value + + + def _get_single_quote(self, value): + if ("'" in value) and ('"' in value): + raise ConfigObjError('Value "%s" cannot be safely quoted.' % value) + elif '"' in value: + quot = squot + else: + quot = dquot + return quot + + + def _get_triple_quote(self, value): + if (value.find('"""') != -1) and (value.find("'''") != -1): + raise ConfigObjError('Value "%s" cannot be safely quoted.' % value) + if value.find('"""') == -1: + quot = tdquot + else: + quot = tsquot + return quot + + + def _handle_value(self, value): + """ + Given a value string, unquote, remove comment, + handle lists. (including empty and single member lists) + """ + # do we look for lists in values ? 
+ if not self.list_values: + mat = self._nolistvalue.match(value) + if mat is None: + raise SyntaxError() + # NOTE: we don't unquote here + return mat.groups() + # + mat = self._valueexp.match(value) + if mat is None: + # the value is badly constructed, probably badly quoted, + # or an invalid list + raise SyntaxError() + (list_values, single, empty_list, comment) = mat.groups() + if (list_values == '') and (single is None): + # change this if you want to accept empty values + raise SyntaxError() + # NOTE: note there is no error handling from here if the regex + # is wrong: then incorrect values will slip through + if empty_list is not None: + # the single comma - meaning an empty list + return ([], comment) + if single is not None: + # handle empty values + if list_values and not single: + # FIXME: the '' is a workaround because our regex now matches + # '' at the end of a list if it has a trailing comma + single = None + else: + single = single or '""' + single = self._unquote(single) + if list_values == '': + # not a list value + return (single, comment) + the_list = self._listvalueexp.findall(list_values) + the_list = [self._unquote(val) for val in the_list] + if single is not None: + the_list += [single] + return (the_list, comment) + + + def _multiline(self, value, infile, cur_index, maxline): + """Extract the value, where we are in a multiline situation.""" + quot = value[:3] + newvalue = value[3:] + single_line = self._triple_quote[quot][0] + multi_line = self._triple_quote[quot][1] + mat = single_line.match(value) + if mat is not None: + retval = list(mat.groups()) + retval.append(cur_index) + return retval + elif newvalue.find(quot) != -1: + # somehow the triple quote is missing + raise SyntaxError() + # + while cur_index < maxline: + cur_index += 1 + newvalue += '\n' + line = infile[cur_index] + if line.find(quot) == -1: + newvalue += line + else: + # end of multiline, process it + break + else: + # we've got to the end of the config, oops... 
+ raise SyntaxError() + mat = multi_line.match(line) + if mat is None: + # a badly formed line + raise SyntaxError() + (value, comment) = mat.groups() + return (newvalue + value, comment, cur_index) + + + def _handle_configspec(self, configspec): + """Parse the configspec.""" + # FIXME: Should we check that the configspec was created with the + # correct settings ? (i.e. ``list_values=False``) + if not isinstance(configspec, ConfigObj): + try: + configspec = ConfigObj(configspec, + raise_errors=True, + file_error=True, + list_values=False) + except ConfigObjError, e: + # FIXME: Should these errors have a reference + # to the already parsed ConfigObj ? + raise ConfigspecError('Parsing configspec failed: %s' % e) + except IOError, e: + raise IOError('Reading configspec failed: %s' % e) + + self._set_configspec_value(configspec, self) + + + def _set_configspec_value(self, configspec, section): + """Used to recursively set configspec values.""" + if '__many__' in configspec.sections: + section.configspec['__many__'] = configspec['__many__'] + if len(configspec.sections) > 1: + # FIXME: can we supply any useful information here ? 
+ raise RepeatSectionError() + + if hasattr(configspec, 'initial_comment'): + section._configspec_initial_comment = configspec.initial_comment + section._configspec_final_comment = configspec.final_comment + section._configspec_encoding = configspec.encoding + section._configspec_BOM = configspec.BOM + section._configspec_newlines = configspec.newlines + section._configspec_indent_type = configspec.indent_type + + for entry in configspec.scalars: + section._configspec_comments[entry] = configspec.comments[entry] + section._configspec_inline_comments[entry] = configspec.inline_comments[entry] + section.configspec[entry] = configspec[entry] + section._order.append(entry) + + for entry in configspec.sections: + if entry == '__many__': + continue + + section._cs_section_comments[entry] = configspec.comments[entry] + section._cs_section_inline_comments[entry] = configspec.inline_comments[entry] + if not section.has_key(entry): + section[entry] = {} + self._set_configspec_value(configspec[entry], section[entry]) + + + def _handle_repeat(self, section, configspec): + """Dynamically assign configspec for repeated section.""" + try: + section_keys = configspec.sections + scalar_keys = configspec.scalars + except AttributeError: + section_keys = [entry for entry in configspec + if isinstance(configspec[entry], dict)] + scalar_keys = [entry for entry in configspec + if not isinstance(configspec[entry], dict)] + + if '__many__' in section_keys and len(section_keys) > 1: + # FIXME: can we supply any useful information here ? 
+ raise RepeatSectionError() + + scalars = {} + sections = {} + for entry in scalar_keys: + val = configspec[entry] + scalars[entry] = val + for entry in section_keys: + val = configspec[entry] + if entry == '__many__': + scalars[entry] = val + continue + sections[entry] = val + + section.configspec = scalars + for entry in sections: + if not section.has_key(entry): + section[entry] = {} + self._handle_repeat(section[entry], sections[entry]) + + + def _write_line(self, indent_string, entry, this_entry, comment): + """Write an individual line, for the write method""" + # NOTE: the calls to self._quote here handles non-StringType values. + if not self.unrepr: + val = self._decode_element(self._quote(this_entry)) + else: + val = repr(this_entry) + return '%s%s%s%s%s' % (indent_string, + self._decode_element(self._quote(entry, multiline=False)), + self._a_to_u(' = '), + val, + self._decode_element(comment)) + + + def _write_marker(self, indent_string, depth, entry, comment): + """Write a section marker line""" + return '%s%s%s%s%s' % (indent_string, + self._a_to_u('[' * depth), + self._quote(self._decode_element(entry), multiline=False), + self._a_to_u(']' * depth), + self._decode_element(comment)) + + + def _handle_comment(self, comment): + """Deal with a comment.""" + if not comment: + return '' + start = self.indent_type + if not comment.startswith('#'): + start += self._a_to_u(' # ') + return (start + comment) + + + # Public methods + + def write(self, outfile=None, section=None): + """ + Write the current ConfigObj as a file + + tekNico: FIXME: use StringIO instead of real files + + >>> filename = a.filename + >>> a.filename = 'test.ini' + >>> a.write() + >>> a.filename = filename + >>> a == ConfigObj('test.ini', raise_errors=True) + 1 + """ + if self.indent_type is None: + # this can be true if initialised from a dictionary + self.indent_type = DEFAULT_INDENT_TYPE + + out = [] + cs = self._a_to_u('#') + csp = self._a_to_u('# ') + if section is None: + int_val = 
self.interpolation + self.interpolation = False + section = self + for line in self.initial_comment: + line = self._decode_element(line) + stripped_line = line.strip() + if stripped_line and not stripped_line.startswith(cs): + line = csp + line + out.append(line) + + indent_string = self.indent_type * section.depth + for entry in (section.scalars + section.sections): + if entry in section.defaults: + # don't write out default values + continue + for comment_line in section.comments[entry]: + comment_line = self._decode_element(comment_line.lstrip()) + if comment_line and not comment_line.startswith(cs): + comment_line = csp + comment_line + out.append(indent_string + comment_line) + this_entry = section[entry] + comment = self._handle_comment(section.inline_comments[entry]) + + if isinstance(this_entry, dict): + # a section + out.append(self._write_marker( + indent_string, + this_entry.depth, + entry, + comment)) + out.extend(self.write(section=this_entry)) + else: + out.append(self._write_line( + indent_string, + entry, + this_entry, + comment)) + + if section is self: + for line in self.final_comment: + line = self._decode_element(line) + stripped_line = line.strip() + if stripped_line and not stripped_line.startswith(cs): + line = csp + line + out.append(line) + self.interpolation = int_val + + if section is not self: + return out + + if (self.filename is None) and (outfile is None): + # output a list of lines + # might need to encode + # NOTE: This will *screw* UTF16, each line will start with the BOM + if self.encoding: + out = [l.encode(self.encoding) for l in out] + if (self.BOM and ((self.encoding is None) or + (BOM_LIST.get(self.encoding.lower()) == 'utf_8'))): + # Add the UTF8 BOM + if not out: + out.append('') + out[0] = BOM_UTF8 + out[0] + return out + + # Turn the list to a string, joined with correct newlines + newline = self.newlines or os.linesep + output = self._a_to_u(newline).join(out) + if self.encoding: + output = output.encode(self.encoding) + 
if self.BOM and ((self.encoding is None) or match_utf8(self.encoding)): + # Add the UTF8 BOM + output = BOM_UTF8 + output + + if not output.endswith(newline): + output += newline + if outfile is not None: + outfile.write(output) + else: + h = open(self.filename, 'wb') + h.write(output) + h.close() + + + def validate(self, validator, preserve_errors=False, copy=False, + section=None): + """ + Test the ConfigObj against a configspec. + + It uses the ``validator`` object from *validate.py*. + + To run ``validate`` on the current ConfigObj, call: :: + + test = config.validate(validator) + + (Normally having previously passed in the configspec when the ConfigObj + was created - you can dynamically assign a dictionary of checks to the + ``configspec`` attribute of a section though). + + It returns ``True`` if everything passes, or a dictionary of + pass/fails (True/False). If every member of a subsection passes, it + will just have the value ``True``. (It also returns ``False`` if all + members fail). + + In addition, it converts the values from strings to their native + types if their checks pass (and ``stringify`` is set). + + If ``preserve_errors`` is ``True`` (``False`` is default) then instead + of a marking a fail with a ``False``, it will preserve the actual + exception object. This can contain info about the reason for failure. + For example the ``VdtValueTooSmallError`` indicates that the value + supplied was too small. If a value (or section) is missing it will + still be marked as ``False``. + + You must have the validate module to use ``preserve_errors=True``. + + You can then use the ``flatten_errors`` function to turn your nested + results dictionary into a flattened list of failures - useful for + displaying meaningful error messages. 
+ """ + if section is None: + if self.configspec is None: + raise ValueError('No configspec supplied.') + if preserve_errors: + # We do this once to remove a top level dependency on the validate module + # Which makes importing configobj faster + from validate import VdtMissingValue + self._vdtMissingValue = VdtMissingValue + section = self + # + spec_section = section.configspec + if copy and hasattr(section, '_configspec_initial_comment'): + section.initial_comment = section._configspec_initial_comment + section.final_comment = section._configspec_final_comment + section.encoding = section._configspec_encoding + section.BOM = section._configspec_BOM + section.newlines = section._configspec_newlines + section.indent_type = section._configspec_indent_type + + if '__many__' in section.configspec: + many = spec_section['__many__'] + # dynamically assign the configspecs + # for the sections below + for entry in section.sections: + self._handle_repeat(section[entry], many) + # + out = {} + ret_true = True + ret_false = True + order = [k for k in section._order if k in spec_section] + order += [k for k in spec_section if k not in order] + for entry in order: + if entry == '__many__': + continue + if (not entry in section.scalars) or (entry in section.defaults): + # missing entries + # or entries from defaults + missing = True + val = None + if copy and not entry in section.scalars: + # copy comments + section.comments[entry] = ( + section._configspec_comments.get(entry, [])) + section.inline_comments[entry] = ( + section._configspec_inline_comments.get(entry, '')) + # + else: + missing = False + val = section[entry] + try: + check = validator.check(spec_section[entry], + val, + missing=missing + ) + except validator.baseErrorClass, e: + if not preserve_errors or isinstance(e, self._vdtMissingValue): + out[entry] = False + else: + # preserve the error + out[entry] = e + ret_false = False + ret_true = False + else: + try: + section.default_values.pop(entry, None) + except 
AttributeError: + # For Python 2.2 compatibility + try: + del section.default_values[entry] + except KeyError: + pass + + if hasattr(validator, 'get_default_value'): + try: + section.default_values[entry] = validator.get_default_value(spec_section[entry]) + except KeyError: + # No default + pass + + ret_false = False + out[entry] = True + if self.stringify or missing: + # if we are doing type conversion + # or the value is a supplied default + if not self.stringify: + if isinstance(check, (list, tuple)): + # preserve lists + check = [self._str(item) for item in check] + elif missing and check is None: + # convert the None from a default to a '' + check = '' + else: + check = self._str(check) + if (check != val) or missing: + section[entry] = check + if not copy and missing and entry not in section.defaults: + section.defaults.append(entry) + # Missing sections will have been created as empty ones when the + # configspec was read. + for entry in section.sections: + # FIXME: this means DEFAULT is not copied in copy mode + if section is self and entry == 'DEFAULT': + continue + if copy: + section.comments[entry] = section._cs_section_comments[entry] + section.inline_comments[entry] = ( + section._cs_section_inline_comments[entry]) + check = self.validate(validator, preserve_errors=preserve_errors, + copy=copy, section=section[entry]) + out[entry] = check + if check == False: + ret_true = False + elif check == True: + ret_false = False + else: + ret_true = False + ret_false = False + # + if ret_true: + return True + elif ret_false: + return False + return out + + + def reset(self): + """Clear ConfigObj instance and restore to 'freshly created' state.""" + self.clear() + self._initialise() + # FIXME: Should be done by '_initialise', but ConfigObj constructor (and reload) + # requires an empty dictionary + self.configspec = None + # Just to be sure ;-) + self._original_configspec = None + + + def reload(self): + """ + Reload a ConfigObj from file. 
+ + This method raises a ``ReloadError`` if the ConfigObj doesn't have + a filename attribute pointing to a file. + """ + if not isinstance(self.filename, StringTypes): + raise ReloadError() + + filename = self.filename + current_options = {} + for entry in OPTION_DEFAULTS: + if entry == 'configspec': + continue + current_options[entry] = getattr(self, entry) + + configspec = self._original_configspec + current_options['configspec'] = configspec + + self.clear() + self._initialise(current_options) + self._load(filename, configspec) + + + +class SimpleVal(object): + """ + A simple validator. + Can be used to check that all members expected are present. + + To use it, provide a configspec with all your members in (the value given + will be ignored). Pass an instance of ``SimpleVal`` to the ``validate`` + method of your ``ConfigObj``. ``validate`` will return ``True`` if all + members are present, or a dictionary with True/False meaning + present/missing. (Whole missing sections will be replaced with ``False``) + """ + + def __init__(self): + self.baseErrorClass = ConfigObjError + + def check(self, check, member, missing=False): + """A dummy check method, always returns the value unchanged.""" + if missing: + raise self.baseErrorClass() + return member + + +# Check / processing functions for options +def flatten_errors(cfg, res, levels=None, results=None): + """ + An example function that will turn a nested dictionary of results + (as returned by ``ConfigObj.validate``) into a flat list. + + ``cfg`` is the ConfigObj instance being checked, ``res`` is the results + dictionary returned by ``validate``. + + (This is a recursive function, so you shouldn't use the ``levels`` or + ``results`` arguments - they are used by the function. + + Returns a list of keys that failed. Each member of the list is a tuple : + :: + + ([list of sections...], key, result) + + If ``validate`` was called with ``preserve_errors=False`` (the default) + then ``result`` will always be ``False``. 
+ + *list of sections* is a flattened list of sections that the key was found + in. + + If the section was missing then key will be ``None``. + + If the value (or section) was missing then ``result`` will be ``False``. + + If ``validate`` was called with ``preserve_errors=True`` and a value + was present, but failed the check, then ``result`` will be the exception + object returned. You can use this as a string that describes the failure. + + For example *The value "3" is of the wrong type*. + + >>> import validate + >>> vtor = validate.Validator() + >>> my_ini = ''' + ... option1 = True + ... [section1] + ... option1 = True + ... [section2] + ... another_option = Probably + ... [section3] + ... another_option = True + ... [[section3b]] + ... value = 3 + ... value2 = a + ... value3 = 11 + ... ''' + >>> my_cfg = ''' + ... option1 = boolean() + ... option2 = boolean() + ... option3 = boolean(default=Bad_value) + ... [section1] + ... option1 = boolean() + ... option2 = boolean() + ... option3 = boolean(default=Bad_value) + ... [section2] + ... another_option = boolean() + ... [section3] + ... another_option = boolean() + ... [[section3b]] + ... value = integer + ... value2 = integer + ... value3 = integer(0, 10) + ... [[[section3b-sub]]] + ... value = string + ... [section4] + ... another_option = boolean() + ... ''' + >>> cs = my_cfg.split('\\n') + >>> ini = my_ini.split('\\n') + >>> cfg = ConfigObj(ini, configspec=cs) + >>> res = cfg.validate(vtor, preserve_errors=True) + >>> errors = [] + >>> for entry in flatten_errors(cfg, res): + ... section_list, key, error = entry + ... section_list.insert(0, '[root]') + ... if key is not None: + ... section_list.append(key) + ... else: + ... section_list.append('[missing]') + ... section_string = ', '.join(section_list) + ... errors.append((section_string, ' = ', error)) + >>> errors.sort() + >>> for entry in errors: + ... 
print entry[0], entry[1], (entry[2] or 0) + [root], option2 = 0 + [root], option3 = the value "Bad_value" is of the wrong type. + [root], section1, option2 = 0 + [root], section1, option3 = the value "Bad_value" is of the wrong type. + [root], section2, another_option = the value "Probably" is of the wrong type. + [root], section3, section3b, section3b-sub, [missing] = 0 + [root], section3, section3b, value2 = the value "a" is of the wrong type. + [root], section3, section3b, value3 = the value "11" is too big. + [root], section4, [missing] = 0 + """ + if levels is None: + # first time called + levels = [] + results = [] + if res is True: + return results + if res is False: + results.append((levels[:], None, False)) + if levels: + levels.pop() + return results + for (key, val) in res.items(): + if val == True: + continue + if isinstance(cfg.get(key), dict): + # Go down one level + levels.append(key) + flatten_errors(cfg[key], val, levels, results) + continue + results.append((levels[:], key, val)) + # + # Go up one level + if levels: + levels.pop() + # + return results + + +"""*A programming language is a medium of expression.* - Paul Graham""" diff --git a/IPython/external/guid.py b/IPython/external/guid.py new file mode 100644 index 0000000..da1a226 --- /dev/null +++ b/IPython/external/guid.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python +# encoding: utf-8 + +# GUID.py +# Version 2.6 +# +# Copyright (c) 2006 Conan C. 
Albrecht +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is furnished +# to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR +# PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE +# FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR +# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +# DEALINGS IN THE SOFTWARE. + + + +################################################################################################## +### A globally-unique identifier made up of time and ip and 8 digits for a counter: +### each GUID is 40 characters wide +### +### A globally unique identifier that combines ip, time, and a counter. Since the +### time is listed first, you can sort records by guid. You can also extract the time +### and ip if needed. +### +### Since the counter has eight hex characters, you can create up to +### 0xffffffff (4294967295) GUIDs every millisecond. If your processor +### is somehow fast enough to create more than that in a millisecond (looking +### toward the future, of course), the function will wait until the next +### millisecond to return. +### +### GUIDs make wonderful database keys. 
They require no access to the +### database (to get the max index number), they are extremely unique, and they sort +### automatically by time. GUIDs prevent key clashes when merging +### two databases together, combining data, or generating keys in distributed +### systems. +### +### There is an Internet Draft for UUIDs, but this module does not implement it. +### If the draft catches on, perhaps I'll conform the module to it. +### + + +# Changelog +# Sometime, 1997 Created the Java version of GUID +# Went through many versions in Java +# Sometime, 2002 Created the Python version of GUID, mirroring the Java version +# November 24, 2003 Changed Python version to be more pythonic, took out object and made just a module +# December 2, 2003 Fixed duplicating GUIDs. Sometimes they duplicate if multiples are created +# in the same millisecond (it checks the last 100 GUIDs now and has a larger random part) +# December 9, 2003 Fixed MAX_RANDOM, which was going over sys.maxint +# June 12, 2004 Allowed a custom IP address to be sent in rather than always using the +# local IP address. +# November 4, 2005 Changed the random part to a counter variable. Now GUIDs are totally +# unique and more efficient, as long as they are created by only +# on runtime on a given machine. The counter part is after the time +# part so it sorts correctly. +# November 8, 2005 The counter variable now starts at a random long now and cycles +# around. This is in case two guids are created on the same +# machine at the same millisecond (by different processes). Even though +# it is possible the GUID can be created, this makes it highly unlikely +# since the counter will likely be different. +# November 11, 2005 Fixed a bug in the new IP getting algorithm. Also, use IPv6 range +# for IP when we make it up (when it's no accessible) +# November 21, 2005 Added better IP-finding code. It finds IP address better now. 
+# January 5, 2006 Fixed a small bug caused in old versions of python (random module use) + +import math +import socket +import random +import sys +import time +import threading + + + +############################# +### global module variables + +#Makes a hex IP from a decimal dot-separated ip (eg: 127.0.0.1) +make_hexip = lambda ip: ''.join(["%04x" % long(i) for i in ip.split('.')]) # leave space for ip v6 (65K in each sub) + +MAX_COUNTER = 0xfffffffe +counter = 0L +firstcounter = MAX_COUNTER +lasttime = 0 +ip = '' +lock = threading.RLock() +try: # only need to get the IP addresss once + ip = socket.getaddrinfo(socket.gethostname(),0)[-1][-1][0] + hexip = make_hexip(ip) +except: # if we don't have an ip, default to someting in the 10.x.x.x private range + ip = '10' + rand = random.Random() + for i in range(3): + ip += '.' + str(rand.randrange(1, 0xffff)) # might as well use IPv6 range if we're making it up + hexip = make_hexip(ip) + + +################################# +### Public module functions + +def generate(ip=None): + '''Generates a new guid. A guid is unique in space and time because it combines + the machine IP with the current time in milliseconds. Be careful about sending in + a specified IP address because the ip makes it unique in space. You could send in + the same IP address that is created on another machine. + ''' + global counter, firstcounter, lasttime + lock.acquire() # can't generate two guids at the same time + try: + parts = [] + + # do we need to wait for the next millisecond (are we out of counters?) 
+ now = long(time.time() * 1000) + while lasttime == now and counter == firstcounter: + time.sleep(.01) + now = long(time.time() * 1000) + + # time part + parts.append("%016x" % now) + + # counter part + if lasttime != now: # time to start counter over since we have a different millisecond + firstcounter = long(random.uniform(1, MAX_COUNTER)) # start at random position + counter = firstcounter + counter += 1 + if counter > MAX_COUNTER: + counter = 0 + lasttime = now + parts.append("%08x" % (counter)) + + # ip part + parts.append(hexip) + + # put them all together + return ''.join(parts) + finally: + lock.release() + + +def extract_time(guid): + '''Extracts the time portion out of the guid and returns the + number of seconds since the epoch as a float''' + return float(long(guid[0:16], 16)) / 1000.0 + + +def extract_counter(guid): + '''Extracts the counter from the guid (returns the bits in decimal)''' + return int(guid[16:24], 16) + + +def extract_ip(guid): + '''Extracts the ip portion out of the guid and returns it + as a string like 10.10.10.10''' + # there's probably a more elegant way to do this + thisip = [] + for index in range(24, 40, 4): + thisip.append(str(int(guid[index: index + 4], 16))) + return '.'.join(thisip) + diff --git a/IPython/external/validate.py b/IPython/external/validate.py new file mode 100644 index 0000000..764b18a --- /dev/null +++ b/IPython/external/validate.py @@ -0,0 +1,1414 @@ +# validate.py +# A Validator object +# Copyright (C) 2005 Michael Foord, Mark Andrews, Nicola Larosa +# E-mail: fuzzyman AT voidspace DOT org DOT uk +# mark AT la-la DOT com +# nico AT tekNico DOT net + +# This software is licensed under the terms of the BSD license. +# http://www.voidspace.org.uk/python/license.shtml +# Basically you're free to copy, modify, distribute and relicense it, +# So long as you keep a copy of the license with it. 
+ +# Scripts maintained at http://www.voidspace.org.uk/python/index.shtml +# For information about bugfixes, updates and support, please join the +# ConfigObj mailing list: +# http://lists.sourceforge.net/lists/listinfo/configobj-develop +# Comments, suggestions and bug reports welcome. + +""" + The Validator object is used to check that supplied values + conform to a specification. + + The value can be supplied as a string - e.g. from a config file. + In this case the check will also *convert* the value to + the required type. This allows you to add validation + as a transparent layer to access data stored as strings. + The validation checks that the data is correct *and* + converts it to the expected type. + + Some standard checks are provided for basic data types. + Additional checks are easy to write. They can be + provided when the ``Validator`` is instantiated or + added afterwards. + + The standard functions work with the following basic data types : + + * integers + * floats + * booleans + * strings + * ip_addr + + plus lists of these datatypes + + Adding additional checks is done through coding simple functions. + + The full set of standard checks are : + + * 'integer': matches integer values (including negative) + Takes optional 'min' and 'max' arguments : :: + + integer() + integer(3, 9) # any value from 3 to 9 + integer(min=0) # any positive value + integer(max=9) + + * 'float': matches float values + Has the same parameters as the integer check. + + * 'boolean': matches boolean values - ``True`` or ``False`` + Acceptable string values for True are : + true, on, yes, 1 + Acceptable string values for False are : + false, off, no, 0 + + Any other value raises an error. + + * 'ip_addr': matches an Internet Protocol address, v.4, represented + by a dotted-quad string, i.e. '1.2.3.4'. + + * 'string': matches any string. + Takes optional keyword args 'min' and 'max' + to specify min and max lengths of the string. + + * 'list': matches any list. 
+ Takes optional keyword args 'min', and 'max' to specify min and + max sizes of the list. (Always returns a list.) + + * 'tuple': matches any tuple. + Takes optional keyword args 'min', and 'max' to specify min and + max sizes of the tuple. (Always returns a tuple.) + + * 'int_list': Matches a list of integers. + Takes the same arguments as list. + + * 'float_list': Matches a list of floats. + Takes the same arguments as list. + + * 'bool_list': Matches a list of boolean values. + Takes the same arguments as list. + + * 'ip_addr_list': Matches a list of IP addresses. + Takes the same arguments as list. + + * 'string_list': Matches a list of strings. + Takes the same arguments as list. + + * 'mixed_list': Matches a list with different types in + specific positions. List size must match + the number of arguments. + + Each position can be one of : + 'integer', 'float', 'ip_addr', 'string', 'boolean' + + So to specify a list with two strings followed + by two integers, you write the check as : :: + + mixed_list('string', 'string', 'integer', 'integer') + + * 'pass': This check matches everything ! It never fails + and the value is unchanged. + + It is also the default if no check is specified. + + * 'option': This check matches any from a list of options. + You specify this check with : :: + + option('option 1', 'option 2', 'option 3') + + You can supply a default value (returned if no value is supplied) + using the default keyword argument. + + You specify a list argument for default using a list constructor syntax in + the check : :: + + checkname(arg1, arg2, default=list('val 1', 'val 2', 'val 3')) + + A badly formatted set of arguments will raise a ``VdtParamError``. 
+""" + +__docformat__ = "restructuredtext en" + +__version__ = '0.3.2' + +__revision__ = '$Id: validate.py 123 2005-09-08 08:54:28Z fuzzyman $' + +__all__ = ( + '__version__', + 'dottedQuadToNum', + 'numToDottedQuad', + 'ValidateError', + 'VdtUnknownCheckError', + 'VdtParamError', + 'VdtTypeError', + 'VdtValueError', + 'VdtValueTooSmallError', + 'VdtValueTooBigError', + 'VdtValueTooShortError', + 'VdtValueTooLongError', + 'VdtMissingValue', + 'Validator', + 'is_integer', + 'is_float', + 'is_boolean', + 'is_list', + 'is_tuple', + 'is_ip_addr', + 'is_string', + 'is_int_list', + 'is_bool_list', + 'is_float_list', + 'is_string_list', + 'is_ip_addr_list', + 'is_mixed_list', + 'is_option', + '__docformat__', +) + + +import sys +INTP_VER = sys.version_info[:2] +if INTP_VER < (2, 2): + raise RuntimeError("Python v.2.2 or later needed") + +import re +StringTypes = (str, unicode) + + +_list_arg = re.compile(r''' + (?: + ([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*list\( + ( + (?: + \s* + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s\)][^,\)]*?) # unquoted + ) + \s*,\s* + )* + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s\)][^,\)]*?) # unquoted + )? # last one + ) + \) + ) +''', re.VERBOSE) # two groups + +_list_members = re.compile(r''' + ( + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s=][^,=]*?) # unquoted + ) + (?: + (?:\s*,\s*)|(?:\s*$) # comma + ) +''', re.VERBOSE) # one group + +_paramstring = r''' + (?: + ( + (?: + [a-zA-Z_][a-zA-Z0-9_]*\s*=\s*list\( + (?: + \s* + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s\)][^,\)]*?) # unquoted + ) + \s*,\s* + )* + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s\)][^,\)]*?) # unquoted + )? 
# last one + \) + )| + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s=][^,=]*?)| # unquoted + (?: # keyword argument + [a-zA-Z_][a-zA-Z0-9_]*\s*=\s* + (?: + (?:".*?")| # double quotes + (?:'.*?')| # single quotes + (?:[^'",\s=][^,=]*?) # unquoted + ) + ) + ) + ) + (?: + (?:\s*,\s*)|(?:\s*$) # comma + ) + ) + ''' + +_matchstring = '^%s*' % _paramstring + +# Python pre 2.2.1 doesn't have bool +try: + bool +except NameError: + def bool(val): + """Simple boolean equivalent function. """ + if val: + return 1 + else: + return 0 + + +def dottedQuadToNum(ip): + """ + Convert decimal dotted quad string to long integer + + >>> dottedQuadToNum('1 ') + 1L + >>> dottedQuadToNum(' 1.2') + 16777218L + >>> dottedQuadToNum(' 1.2.3 ') + 16908291L + >>> dottedQuadToNum('1.2.3.4') + 16909060L + >>> dottedQuadToNum('1.2.3. 4') + Traceback (most recent call last): + ValueError: Not a good dotted-quad IP: 1.2.3. 4 + >>> dottedQuadToNum('255.255.255.255') + 4294967295L + >>> dottedQuadToNum('255.255.255.256') + Traceback (most recent call last): + ValueError: Not a good dotted-quad IP: 255.255.255.256 + """ + + # import here to avoid it when ip_addr values are not used + import socket, struct + + try: + return struct.unpack('!L', + socket.inet_aton(ip.strip()))[0] + except socket.error: + # bug in inet_aton, corrected in Python 2.3 + if ip.strip() == '255.255.255.255': + return 0xFFFFFFFFL + else: + raise ValueError('Not a good dotted-quad IP: %s' % ip) + return + + +def numToDottedQuad(num): + """ + Convert long int to dotted quad string + + >>> numToDottedQuad(-1L) + Traceback (most recent call last): + ValueError: Not a good numeric IP: -1 + >>> numToDottedQuad(1L) + '0.0.0.1' + >>> numToDottedQuad(16777218L) + '1.0.0.2' + >>> numToDottedQuad(16908291L) + '1.2.0.3' + >>> numToDottedQuad(16909060L) + '1.2.3.4' + >>> numToDottedQuad(4294967295L) + '255.255.255.255' + >>> numToDottedQuad(4294967296L) + Traceback (most recent call last): + ValueError: Not a good 
numeric IP: 4294967296 + """ + + # import here to avoid it when ip_addr values are not used + import socket, struct + + # no need to intercept here, 4294967295L is fine + try: + return socket.inet_ntoa( + struct.pack('!L', long(num))) + except (socket.error, struct.error, OverflowError): + raise ValueError('Not a good numeric IP: %s' % num) + + +class ValidateError(Exception): + """ + This error indicates that the check failed. + It can be the base class for more specific errors. + + Any check function that fails ought to raise this error. + (or a subclass) + + >>> raise ValidateError + Traceback (most recent call last): + ValidateError + """ + + +class VdtMissingValue(ValidateError): + """No value was supplied to a check that needed one.""" + + +class VdtUnknownCheckError(ValidateError): + """An unknown check function was requested""" + + def __init__(self, value): + """ + >>> raise VdtUnknownCheckError('yoda') + Traceback (most recent call last): + VdtUnknownCheckError: the check "yoda" is unknown. + """ + ValidateError.__init__(self, 'the check "%s" is unknown.' % (value,)) + + +class VdtParamError(SyntaxError): + """An incorrect parameter was passed""" + + def __init__(self, name, value): + """ + >>> raise VdtParamError('yoda', 'jedi') + Traceback (most recent call last): + VdtParamError: passed an incorrect value "jedi" for parameter "yoda". + """ + SyntaxError.__init__(self, 'passed an incorrect value "%s" for parameter "%s".' % (value, name)) + + +class VdtTypeError(ValidateError): + """The value supplied was of the wrong type""" + + def __init__(self, value): + """ + >>> raise VdtTypeError('jedi') + Traceback (most recent call last): + VdtTypeError: the value "jedi" is of the wrong type. + """ + ValidateError.__init__(self, 'the value "%s" is of the wrong type.' 
% (value,)) + + +class VdtValueError(ValidateError): + """The value supplied was of the correct type, but was not an allowed value.""" + + def __init__(self, value): + """ + >>> raise VdtValueError('jedi') + Traceback (most recent call last): + VdtValueError: the value "jedi" is unacceptable. + """ + ValidateError.__init__(self, 'the value "%s" is unacceptable.' % (value,)) + + +class VdtValueTooSmallError(VdtValueError): + """The value supplied was of the correct type, but was too small.""" + + def __init__(self, value): + """ + >>> raise VdtValueTooSmallError('0') + Traceback (most recent call last): + VdtValueTooSmallError: the value "0" is too small. + """ + ValidateError.__init__(self, 'the value "%s" is too small.' % (value,)) + + +class VdtValueTooBigError(VdtValueError): + """The value supplied was of the correct type, but was too big.""" + + def __init__(self, value): + """ + >>> raise VdtValueTooBigError('1') + Traceback (most recent call last): + VdtValueTooBigError: the value "1" is too big. + """ + ValidateError.__init__(self, 'the value "%s" is too big.' % (value,)) + + +class VdtValueTooShortError(VdtValueError): + """The value supplied was of the correct type, but was too short.""" + + def __init__(self, value): + """ + >>> raise VdtValueTooShortError('jed') + Traceback (most recent call last): + VdtValueTooShortError: the value "jed" is too short. + """ + ValidateError.__init__( + self, + 'the value "%s" is too short.' % (value,)) + + +class VdtValueTooLongError(VdtValueError): + """The value supplied was of the correct type, but was too long.""" + + def __init__(self, value): + """ + >>> raise VdtValueTooLongError('jedie') + Traceback (most recent call last): + VdtValueTooLongError: the value "jedie" is too long. + """ + ValidateError.__init__(self, 'the value "%s" is too long.' % (value,)) + + +class Validator(object): + """ + Validator is an object that allows you to register a set of 'checks'. 
+ These checks take input and test that it conforms to the check. + + This can also involve converting the value from a string into + the correct datatype. + + The ``check`` method takes an input string which configures which + check is to be used and applies that check to a supplied value. + + An example input string would be: + 'int_range(param1, param2)' + + You would then provide something like: + + >>> def int_range_check(value, min, max): + ... # turn min and max from strings to integers + ... min = int(min) + ... max = int(max) + ... # check that value is of the correct type. + ... # possible valid inputs are integers or strings + ... # that represent integers + ... if not isinstance(value, (int, long, StringTypes)): + ... raise VdtTypeError(value) + ... elif isinstance(value, StringTypes): + ... # if we are given a string + ... # attempt to convert to an integer + ... try: + ... value = int(value) + ... except ValueError: + ... raise VdtValueError(value) + ... # check the value is between our constraints + ... if not min <= value: + ... raise VdtValueTooSmallError(value) + ... if not value <= max: + ... raise VdtValueTooBigError(value) + ... return value + + >>> fdict = {'int_range': int_range_check} + >>> vtr1 = Validator(fdict) + >>> vtr1.check('int_range(20, 40)', '30') + 30 + >>> vtr1.check('int_range(20, 40)', '60') + Traceback (most recent call last): + VdtValueTooBigError: the value "60" is too big. + + New functions can be added with : :: + + >>> vtr2 = Validator() + >>> vtr2.functions['int_range'] = int_range_check + + Or by passing in a dictionary of functions when Validator + is instantiated. + + Your functions *can* use keyword arguments, + but the first argument should always be 'value'. + + If the function doesn't take additional arguments, + the parentheses are optional in the check. 
+ It can be written with either of : :: + + keyword = function_name + keyword = function_name() + + The first program to utilise Validator() was Michael Foord's + ConfigObj, an alternative to ConfigParser which supports lists and + can validate a config file using a config schema. + For more details on using Validator with ConfigObj see: + http://www.voidspace.org.uk/python/configobj.html + """ + + # this regex does the initial parsing of the checks + _func_re = re.compile(r'(.+?)\((.*)\)') + + # this regex takes apart keyword arguments + _key_arg = re.compile(r'^([a-zA-Z_][a-zA-Z0-9_]*)\s*=\s*(.*)$') + + + # this regex finds keyword=list(....) type values + _list_arg = _list_arg + + # this regex takes individual values out of lists - in one pass + _list_members = _list_members + + # These regexes check a set of arguments for validity + # and then pull the members out + _paramfinder = re.compile(_paramstring, re.VERBOSE) + _matchfinder = re.compile(_matchstring, re.VERBOSE) + + + def __init__(self, functions=None): + """ + >>> vtri = Validator() + """ + self.functions = { + '': self._pass, + 'integer': is_integer, + 'float': is_float, + 'boolean': is_boolean, + 'ip_addr': is_ip_addr, + 'string': is_string, + 'list': is_list, + 'tuple': is_tuple, + 'int_list': is_int_list, + 'float_list': is_float_list, + 'bool_list': is_bool_list, + 'ip_addr_list': is_ip_addr_list, + 'string_list': is_string_list, + 'mixed_list': is_mixed_list, + 'pass': self._pass, + 'option': is_option, + } + if functions is not None: + self.functions.update(functions) + # tekNico: for use by ConfigObj + self.baseErrorClass = ValidateError + self._cache = {} + + + def check(self, check, value, missing=False): + """ + Usage: check(check, value) + + Arguments: + check: string representing check to apply (including arguments) + value: object to be checked + Returns value, converted to correct type if necessary + + If the check fails, raises a ``ValidateError`` subclass. 
+ + >>> vtor.check('yoda', '') + Traceback (most recent call last): + VdtUnknownCheckError: the check "yoda" is unknown. + >>> vtor.check('yoda()', '') + Traceback (most recent call last): + VdtUnknownCheckError: the check "yoda" is unknown. + + >>> vtor.check('string(default="")', '', missing=True) + '' + """ + fun_name, fun_args, fun_kwargs, default = self._parse_with_caching(check) + + if missing: + if default is None: + # no information needed here - to be handled by caller + raise VdtMissingValue() + value = self._handle_none(default) + + if value is None: + return None + + return self._check_value(value, fun_name, fun_args, fun_kwargs) + + + def _handle_none(self, value): + if value == 'None': + value = None + elif value in ("'None'", '"None"'): + # Special case a quoted None + value = self._unquote(value) + return value + + + def _parse_with_caching(self, check): + if check in self._cache: + fun_name, fun_args, fun_kwargs, default = self._cache[check] + # We call list and dict below to work with *copies* of the data + # rather than the original (which are mutable of course) + fun_args = list(fun_args) + fun_kwargs = dict(fun_kwargs) + else: + fun_name, fun_args, fun_kwargs, default = self._parse_check(check) + fun_kwargs = dict((str(key), value) for (key, value) in fun_kwargs.items()) + self._cache[check] = fun_name, list(fun_args), dict(fun_kwargs), default + return fun_name, fun_args, fun_kwargs, default + + + def _check_value(self, value, fun_name, fun_args, fun_kwargs): + try: + fun = self.functions[fun_name] + except KeyError: + raise VdtUnknownCheckError(fun_name) + else: + return fun(value, *fun_args, **fun_kwargs) + + + def _parse_check(self, check): + fun_match = self._func_re.match(check) + if fun_match: + fun_name = fun_match.group(1) + arg_string = fun_match.group(2) + arg_match = self._matchfinder.match(arg_string) + if arg_match is None: + # Bad syntax + raise VdtParamError('Bad syntax in check "%s".' 
% check) + fun_args = [] + fun_kwargs = {} + # pull out args of group 2 + for arg in self._paramfinder.findall(arg_string): + # args may need whitespace removing (before removing quotes) + arg = arg.strip() + listmatch = self._list_arg.match(arg) + if listmatch: + key, val = self._list_handle(listmatch) + fun_kwargs[key] = val + continue + keymatch = self._key_arg.match(arg) + if keymatch: + val = keymatch.group(2) + if not val in ("'None'", '"None"'): + # Special case a quoted None + val = self._unquote(val) + fun_kwargs[keymatch.group(1)] = val + continue + + fun_args.append(self._unquote(arg)) + else: + # allows for function names without (args) + return check, (), {}, None + + # Default must be deleted if the value is specified too, + # otherwise the check function will get a spurious "default" keyword arg + try: + default = fun_kwargs.pop('default', None) + except AttributeError: + # Python 2.2 compatibility + default = None + try: + default = fun_kwargs['default'] + del fun_kwargs['default'] + except KeyError: + pass + + return fun_name, fun_args, fun_kwargs, default + + + def _unquote(self, val): + """Unquote a value if necessary.""" + if (len(val) >= 2) and (val[0] in ("'", '"')) and (val[0] == val[-1]): + val = val[1:-1] + return val + + + def _list_handle(self, listmatch): + """Take apart a ``keyword=list('val, 'val')`` type string.""" + out = [] + name = listmatch.group(1) + args = listmatch.group(2) + for arg in self._list_members.findall(args): + out.append(self._unquote(arg)) + return name, out + + + def _pass(self, value): + """ + Dummy check that always passes + + >>> vtor.check('', 0) + 0 + >>> vtor.check('', '0') + '0' + """ + return value + + + def get_default_value(self, check): + """ + Given a check, return the default value for the check + (converted to the right type). + + If the check doesn't specify a default value then a + ``KeyError`` will be raised. 
+ """ + fun_name, fun_args, fun_kwargs, default = self._parse_with_caching(check) + if default is None: + raise KeyError('Check "%s" has no default value.' % check) + value = self._handle_none(default) + if value is None: + return value + return self._check_value(value, fun_name, fun_args, fun_kwargs) + + +def _is_num_param(names, values, to_float=False): + """ + Return numbers from inputs or raise VdtParamError. + + Lets ``None`` pass through. + Pass in keyword argument ``to_float=True`` to + use float for the conversion rather than int. + + >>> _is_num_param(('', ''), (0, 1.0)) + [0, 1] + >>> _is_num_param(('', ''), (0, 1.0), to_float=True) + [0.0, 1.0] + >>> _is_num_param(('a'), ('a')) + Traceback (most recent call last): + VdtParamError: passed an incorrect value "a" for parameter "a". + """ + fun = to_float and float or int + out_params = [] + for (name, val) in zip(names, values): + if val is None: + out_params.append(val) + elif isinstance(val, (int, long, float, StringTypes)): + try: + out_params.append(fun(val)) + except ValueError, e: + raise VdtParamError(name, val) + else: + raise VdtParamError(name, val) + return out_params + + +# built in checks +# you can override these by setting the appropriate name +# in Validator.functions +# note: if the params are specified wrongly in your input string, +# you will also raise errors. + +def is_integer(value, min=None, max=None): + """ + A check that tests that a given value is an integer (int, or long) + and optionally, between bounds. A negative value is accepted, while + a float will fail. + + If the value is a string, then the conversion is done - if possible. + Otherwise a VdtError is raised. + + >>> vtor.check('integer', '-1') + -1 + >>> vtor.check('integer', '0') + 0 + >>> vtor.check('integer', 9) + 9 + >>> vtor.check('integer', 'a') + Traceback (most recent call last): + VdtTypeError: the value "a" is of the wrong type. 
+ >>> vtor.check('integer', '2.2') + Traceback (most recent call last): + VdtTypeError: the value "2.2" is of the wrong type. + >>> vtor.check('integer(10)', '20') + 20 + >>> vtor.check('integer(max=20)', '15') + 15 + >>> vtor.check('integer(10)', '9') + Traceback (most recent call last): + VdtValueTooSmallError: the value "9" is too small. + >>> vtor.check('integer(10)', 9) + Traceback (most recent call last): + VdtValueTooSmallError: the value "9" is too small. + >>> vtor.check('integer(max=20)', '35') + Traceback (most recent call last): + VdtValueTooBigError: the value "35" is too big. + >>> vtor.check('integer(max=20)', 35) + Traceback (most recent call last): + VdtValueTooBigError: the value "35" is too big. + >>> vtor.check('integer(0, 9)', False) + 0 + """ + (min_val, max_val) = _is_num_param(('min', 'max'), (min, max)) + if not isinstance(value, (int, long, StringTypes)): + raise VdtTypeError(value) + if isinstance(value, StringTypes): + # if it's a string - does it represent an integer ? + try: + value = int(value) + except ValueError: + raise VdtTypeError(value) + if (min_val is not None) and (value < min_val): + raise VdtValueTooSmallError(value) + if (max_val is not None) and (value > max_val): + raise VdtValueTooBigError(value) + return value + + +def is_float(value, min=None, max=None): + """ + A check that tests that a given value is a float + (an integer will be accepted), and optionally - that it is between bounds. + + If the value is a string, then the conversion is done - if possible. + Otherwise a VdtError is raised. + + This can accept negative values. + + >>> vtor.check('float', '2') + 2.0 + + From now on we multiply the value to avoid comparing decimals + + >>> vtor.check('float', '-6.8') * 10 + -68.0 + >>> vtor.check('float', '12.2') * 10 + 122.0 + >>> vtor.check('float', 8.4) * 10 + 84.0 + >>> vtor.check('float', 'a') + Traceback (most recent call last): + VdtTypeError: the value "a" is of the wrong type. 
+ >>> vtor.check('float(10.1)', '10.2') * 10 + 102.0 + >>> vtor.check('float(max=20.2)', '15.1') * 10 + 151.0 + >>> vtor.check('float(10.0)', '9.0') + Traceback (most recent call last): + VdtValueTooSmallError: the value "9.0" is too small. + >>> vtor.check('float(max=20.0)', '35.0') + Traceback (most recent call last): + VdtValueTooBigError: the value "35.0" is too big. + """ + (min_val, max_val) = _is_num_param( + ('min', 'max'), (min, max), to_float=True) + if not isinstance(value, (int, long, float, StringTypes)): + raise VdtTypeError(value) + if not isinstance(value, float): + # if it's a string - does it represent a float ? + try: + value = float(value) + except ValueError: + raise VdtTypeError(value) + if (min_val is not None) and (value < min_val): + raise VdtValueTooSmallError(value) + if (max_val is not None) and (value > max_val): + raise VdtValueTooBigError(value) + return value + + +bool_dict = { + True: True, 'on': True, '1': True, 'true': True, 'yes': True, + False: False, 'off': False, '0': False, 'false': False, 'no': False, +} + + +def is_boolean(value): + """ + Check if the value represents a boolean. + + >>> vtor.check('boolean', 0) + 0 + >>> vtor.check('boolean', False) + 0 + >>> vtor.check('boolean', '0') + 0 + >>> vtor.check('boolean', 'off') + 0 + >>> vtor.check('boolean', 'false') + 0 + >>> vtor.check('boolean', 'no') + 0 + >>> vtor.check('boolean', 'nO') + 0 + >>> vtor.check('boolean', 'NO') + 0 + >>> vtor.check('boolean', 1) + 1 + >>> vtor.check('boolean', True) + 1 + >>> vtor.check('boolean', '1') + 1 + >>> vtor.check('boolean', 'on') + 1 + >>> vtor.check('boolean', 'true') + 1 + >>> vtor.check('boolean', 'yes') + 1 + >>> vtor.check('boolean', 'Yes') + 1 + >>> vtor.check('boolean', 'YES') + 1 + >>> vtor.check('boolean', '') + Traceback (most recent call last): + VdtTypeError: the value "" is of the wrong type. + >>> vtor.check('boolean', 'up') + Traceback (most recent call last): + VdtTypeError: the value "up" is of the wrong type. 
+ + """ + if isinstance(value, StringTypes): + try: + return bool_dict[value.lower()] + except KeyError: + raise VdtTypeError(value) + # we do an equality test rather than an identity test + # this ensures Python 2.2 compatibilty + # and allows 0 and 1 to represent True and False + if value == False: + return False + elif value == True: + return True + else: + raise VdtTypeError(value) + + +def is_ip_addr(value): + """ + Check that the supplied value is an Internet Protocol address, v.4, + represented by a dotted-quad string, i.e. '1.2.3.4'. + + >>> vtor.check('ip_addr', '1 ') + '1' + >>> vtor.check('ip_addr', ' 1.2') + '1.2' + >>> vtor.check('ip_addr', ' 1.2.3 ') + '1.2.3' + >>> vtor.check('ip_addr', '1.2.3.4') + '1.2.3.4' + >>> vtor.check('ip_addr', '0.0.0.0') + '0.0.0.0' + >>> vtor.check('ip_addr', '255.255.255.255') + '255.255.255.255' + >>> vtor.check('ip_addr', '255.255.255.256') + Traceback (most recent call last): + VdtValueError: the value "255.255.255.256" is unacceptable. + >>> vtor.check('ip_addr', '1.2.3.4.5') + Traceback (most recent call last): + VdtValueError: the value "1.2.3.4.5" is unacceptable. + >>> vtor.check('ip_addr', '1.2.3. 4') + Traceback (most recent call last): + VdtValueError: the value "1.2.3. 4" is unacceptable. + >>> vtor.check('ip_addr', 0) + Traceback (most recent call last): + VdtTypeError: the value "0" is of the wrong type. + """ + if not isinstance(value, StringTypes): + raise VdtTypeError(value) + value = value.strip() + try: + dottedQuadToNum(value) + except ValueError: + raise VdtValueError(value) + return value + + +def is_list(value, min=None, max=None): + """ + Check that the value is a list of values. + + You can optionally specify the minimum and maximum number of members. + + It does no check on list members. 
+ + >>> vtor.check('list', ()) + [] + >>> vtor.check('list', []) + [] + >>> vtor.check('list', (1, 2)) + [1, 2] + >>> vtor.check('list', [1, 2]) + [1, 2] + >>> vtor.check('list(3)', (1, 2)) + Traceback (most recent call last): + VdtValueTooShortError: the value "(1, 2)" is too short. + >>> vtor.check('list(max=5)', (1, 2, 3, 4, 5, 6)) + Traceback (most recent call last): + VdtValueTooLongError: the value "(1, 2, 3, 4, 5, 6)" is too long. + >>> vtor.check('list(min=3, max=5)', (1, 2, 3, 4)) + [1, 2, 3, 4] + >>> vtor.check('list', 0) + Traceback (most recent call last): + VdtTypeError: the value "0" is of the wrong type. + >>> vtor.check('list', '12') + Traceback (most recent call last): + VdtTypeError: the value "12" is of the wrong type. + """ + (min_len, max_len) = _is_num_param(('min', 'max'), (min, max)) + if isinstance(value, StringTypes): + raise VdtTypeError(value) + try: + num_members = len(value) + except TypeError: + raise VdtTypeError(value) + if min_len is not None and num_members < min_len: + raise VdtValueTooShortError(value) + if max_len is not None and num_members > max_len: + raise VdtValueTooLongError(value) + return list(value) + + +def is_tuple(value, min=None, max=None): + """ + Check that the value is a tuple of values. + + You can optionally specify the minimum and maximum number of members. + + It does no check on members. + + >>> vtor.check('tuple', ()) + () + >>> vtor.check('tuple', []) + () + >>> vtor.check('tuple', (1, 2)) + (1, 2) + >>> vtor.check('tuple', [1, 2]) + (1, 2) + >>> vtor.check('tuple(3)', (1, 2)) + Traceback (most recent call last): + VdtValueTooShortError: the value "(1, 2)" is too short. + >>> vtor.check('tuple(max=5)', (1, 2, 3, 4, 5, 6)) + Traceback (most recent call last): + VdtValueTooLongError: the value "(1, 2, 3, 4, 5, 6)" is too long. 
+ >>> vtor.check('tuple(min=3, max=5)', (1, 2, 3, 4)) + (1, 2, 3, 4) + >>> vtor.check('tuple', 0) + Traceback (most recent call last): + VdtTypeError: the value "0" is of the wrong type. + >>> vtor.check('tuple', '12') + Traceback (most recent call last): + VdtTypeError: the value "12" is of the wrong type. + """ + return tuple(is_list(value, min, max)) + + +def is_string(value, min=None, max=None): + """ + Check that the supplied value is a string. + + You can optionally specify the minimum and maximum number of members. + + >>> vtor.check('string', '0') + '0' + >>> vtor.check('string', 0) + Traceback (most recent call last): + VdtTypeError: the value "0" is of the wrong type. + >>> vtor.check('string(2)', '12') + '12' + >>> vtor.check('string(2)', '1') + Traceback (most recent call last): + VdtValueTooShortError: the value "1" is too short. + >>> vtor.check('string(min=2, max=3)', '123') + '123' + >>> vtor.check('string(min=2, max=3)', '1234') + Traceback (most recent call last): + VdtValueTooLongError: the value "1234" is too long. + """ + if not isinstance(value, StringTypes): + raise VdtTypeError(value) + (min_len, max_len) = _is_num_param(('min', 'max'), (min, max)) + try: + num_members = len(value) + except TypeError: + raise VdtTypeError(value) + if min_len is not None and num_members < min_len: + raise VdtValueTooShortError(value) + if max_len is not None and num_members > max_len: + raise VdtValueTooLongError(value) + return value + + +def is_int_list(value, min=None, max=None): + """ + Check that the value is a list of integers. + + You can optionally specify the minimum and maximum number of members. + + Each list member is checked that it is an integer. + + >>> vtor.check('int_list', ()) + [] + >>> vtor.check('int_list', []) + [] + >>> vtor.check('int_list', (1, 2)) + [1, 2] + >>> vtor.check('int_list', [1, 2]) + [1, 2] + >>> vtor.check('int_list', [1, 'a']) + Traceback (most recent call last): + VdtTypeError: the value "a" is of the wrong type. 
+ """ + return [is_integer(mem) for mem in is_list(value, min, max)] + + +def is_bool_list(value, min=None, max=None): + """ + Check that the value is a list of booleans. + + You can optionally specify the minimum and maximum number of members. + + Each list member is checked that it is a boolean. + + >>> vtor.check('bool_list', ()) + [] + >>> vtor.check('bool_list', []) + [] + >>> check_res = vtor.check('bool_list', (True, False)) + >>> check_res == [True, False] + 1 + >>> check_res = vtor.check('bool_list', [True, False]) + >>> check_res == [True, False] + 1 + >>> vtor.check('bool_list', [True, 'a']) + Traceback (most recent call last): + VdtTypeError: the value "a" is of the wrong type. + """ + return [is_boolean(mem) for mem in is_list(value, min, max)] + + +def is_float_list(value, min=None, max=None): + """ + Check that the value is a list of floats. + + You can optionally specify the minimum and maximum number of members. + + Each list member is checked that it is a float. + + >>> vtor.check('float_list', ()) + [] + >>> vtor.check('float_list', []) + [] + >>> vtor.check('float_list', (1, 2.0)) + [1.0, 2.0] + >>> vtor.check('float_list', [1, 2.0]) + [1.0, 2.0] + >>> vtor.check('float_list', [1, 'a']) + Traceback (most recent call last): + VdtTypeError: the value "a" is of the wrong type. + """ + return [is_float(mem) for mem in is_list(value, min, max)] + + +def is_string_list(value, min=None, max=None): + """ + Check that the value is a list of strings. + + You can optionally specify the minimum and maximum number of members. + + Each list member is checked that it is a string. + + >>> vtor.check('string_list', ()) + [] + >>> vtor.check('string_list', []) + [] + >>> vtor.check('string_list', ('a', 'b')) + ['a', 'b'] + >>> vtor.check('string_list', ['a', 1]) + Traceback (most recent call last): + VdtTypeError: the value "1" is of the wrong type. 
+ >>> vtor.check('string_list', 'hello') + Traceback (most recent call last): + VdtTypeError: the value "hello" is of the wrong type. + """ + if isinstance(value, StringTypes): + raise VdtTypeError(value) + return [is_string(mem) for mem in is_list(value, min, max)] + + +def is_ip_addr_list(value, min=None, max=None): + """ + Check that the value is a list of IP addresses. + + You can optionally specify the minimum and maximum number of members. + + Each list member is checked that it is an IP address. + + >>> vtor.check('ip_addr_list', ()) + [] + >>> vtor.check('ip_addr_list', []) + [] + >>> vtor.check('ip_addr_list', ('1.2.3.4', '5.6.7.8')) + ['1.2.3.4', '5.6.7.8'] + >>> vtor.check('ip_addr_list', ['a']) + Traceback (most recent call last): + VdtValueError: the value "a" is unacceptable. + """ + return [is_ip_addr(mem) for mem in is_list(value, min, max)] + + +fun_dict = { + 'integer': is_integer, + 'float': is_float, + 'ip_addr': is_ip_addr, + 'string': is_string, + 'boolean': is_boolean, +} + + +def is_mixed_list(value, *args): + """ + Check that the value is a list. + Allow specifying the type of each member. + Work on lists of specific lengths. + + You specify each member as a positional argument specifying type + + Each type should be one of the following strings : + 'integer', 'float', 'ip_addr', 'string', 'boolean' + + So you can specify a list of two strings, followed by + two integers as : + + mixed_list('string', 'string', 'integer', 'integer') + + The length of the list must match the number of positional + arguments you supply. 
+ + >>> mix_str = "mixed_list('integer', 'float', 'ip_addr', 'string', 'boolean')" + >>> check_res = vtor.check(mix_str, (1, 2.0, '1.2.3.4', 'a', True)) + >>> check_res == [1, 2.0, '1.2.3.4', 'a', True] + 1 + >>> check_res = vtor.check(mix_str, ('1', '2.0', '1.2.3.4', 'a', 'True')) + >>> check_res == [1, 2.0, '1.2.3.4', 'a', True] + 1 + >>> vtor.check(mix_str, ('b', 2.0, '1.2.3.4', 'a', True)) + Traceback (most recent call last): + VdtTypeError: the value "b" is of the wrong type. + >>> vtor.check(mix_str, (1, 2.0, '1.2.3.4', 'a')) + Traceback (most recent call last): + VdtValueTooShortError: the value "(1, 2.0, '1.2.3.4', 'a')" is too short. + >>> vtor.check(mix_str, (1, 2.0, '1.2.3.4', 'a', 1, 'b')) + Traceback (most recent call last): + VdtValueTooLongError: the value "(1, 2.0, '1.2.3.4', 'a', 1, 'b')" is too long. + >>> vtor.check(mix_str, 0) + Traceback (most recent call last): + VdtTypeError: the value "0" is of the wrong type. + + This test requires an elaborate setup, because of a change in error string + output from the interpreter between Python 2.2 and 2.3 . + + >>> res_seq = ( + ... 'passed an incorrect value "', + ... 'yoda', + ... '" for parameter "mixed_list".', + ... ) + >>> if INTP_VER == (2, 2): + ... res_str = "".join(res_seq) + ... else: + ... res_str = "'".join(res_seq) + >>> try: + ... vtor.check('mixed_list("yoda")', ('a')) + ... except VdtParamError, err: + ... str(err) == res_str + 1 + """ + try: + length = len(value) + except TypeError: + raise VdtTypeError(value) + if length < len(args): + raise VdtValueTooShortError(value) + elif length > len(args): + raise VdtValueTooLongError(value) + try: + return [fun_dict[arg](val) for arg, val in zip(args, value)] + except KeyError, e: + raise VdtParamError('mixed_list', e) + + +def is_option(value, *options): + """ + This check matches the value to any of a set of options. 
+ + >>> vtor.check('option("yoda", "jedi")', 'yoda') + 'yoda' + >>> vtor.check('option("yoda", "jedi")', 'jed') + Traceback (most recent call last): + VdtValueError: the value "jed" is unacceptable. + >>> vtor.check('option("yoda", "jedi")', 0) + Traceback (most recent call last): + VdtTypeError: the value "0" is of the wrong type. + """ + if not isinstance(value, StringTypes): + raise VdtTypeError(value) + if not value in options: + raise VdtValueError(value) + return value + + +def _test(value, *args, **keywargs): + """ + A function that exists for test purposes. + + >>> checks = [ + ... '3, 6, min=1, max=3, test=list(a, b, c)', + ... '3', + ... '3, 6', + ... '3,', + ... 'min=1, test="a b c"', + ... 'min=5, test="a, b, c"', + ... 'min=1, max=3, test="a, b, c"', + ... 'min=-100, test=-99', + ... 'min=1, max=3', + ... '3, 6, test="36"', + ... '3, 6, test="a, b, c"', + ... '3, max=3, test=list("a", "b", "c")', + ... '''3, max=3, test=list("'a'", 'b', "x=(c)")''', + ... "test='x=fish(3)'", + ... ] + >>> v = Validator({'test': _test}) + >>> for entry in checks: + ... 
print v.check(('test(%s)' % entry), 3) + (3, ('3', '6'), {'test': ['a', 'b', 'c'], 'max': '3', 'min': '1'}) + (3, ('3',), {}) + (3, ('3', '6'), {}) + (3, ('3',), {}) + (3, (), {'test': 'a b c', 'min': '1'}) + (3, (), {'test': 'a, b, c', 'min': '5'}) + (3, (), {'test': 'a, b, c', 'max': '3', 'min': '1'}) + (3, (), {'test': '-99', 'min': '-100'}) + (3, (), {'max': '3', 'min': '1'}) + (3, ('3', '6'), {'test': '36'}) + (3, ('3', '6'), {'test': 'a, b, c'}) + (3, ('3',), {'test': ['a', 'b', 'c'], 'max': '3'}) + (3, ('3',), {'test': ["'a'", 'b', 'x=(c)'], 'max': '3'}) + (3, (), {'test': 'x=fish(3)'}) + + >>> v = Validator() + >>> v.check('integer(default=6)', '3') + 3 + >>> v.check('integer(default=6)', None, True) + 6 + >>> v.get_default_value('integer(default=6)') + 6 + >>> v.get_default_value('float(default=6)') + 6.0 + >>> v.get_default_value('pass(default=None)') + >>> v.get_default_value("string(default='None')") + 'None' + >>> v.get_default_value('pass') + Traceback (most recent call last): + KeyError: 'Check "pass" has no default value.' 
+ >>> v.get_default_value('pass(default=list(1, 2, 3, 4))') + ['1', '2', '3', '4'] + + >>> v = Validator() + >>> v.check("pass(default=None)", None, True) + >>> v.check("pass(default='None')", None, True) + 'None' + >>> v.check('pass(default="None")', None, True) + 'None' + >>> v.check('pass(default=list(1, 2, 3, 4))', None, True) + ['1', '2', '3', '4'] + + Bug test for unicode arguments + >>> v = Validator() + >>> v.check(u'string(min=4)', u'test') + u'test' + + >>> v = Validator() + >>> v.get_default_value(u'string(min=4, default="1234")') + u'1234' + >>> v.check(u'string(min=4, default="1234")', u'test') + u'test' + + >>> v = Validator() + >>> default = v.get_default_value('string(default=None)') + >>> default == None + 1 + """ + return (value, args, keywargs) + + +if __name__ == '__main__': + # run the code tests in doctest format + import doctest + m = sys.modules.get('__main__') + globs = m.__dict__.copy() + globs.update({ + 'INTP_VER': INTP_VER, + 'vtor': Validator(), + }) + doctest.testmod(m, globs=globs) diff --git a/IPython/kernel/__init__.py b/IPython/kernel/__init__.py new file mode 100755 index 0000000..ce2ad73 --- /dev/null +++ b/IPython/kernel/__init__.py @@ -0,0 +1,24 @@ +# encoding: utf-8 +"""The IPython1 kernel. + +The IPython kernel actually refers to three things: + + * The IPython Engine + * The IPython Controller + * Clients to the IPython Controller + +The kernel module implements the engine, controller and client and all the +network protocols needed for the various entities to talk to each other. + +An end user should probably begin by looking at the `client.py` module +if they need blocking clients or in `asyncclient.py` if they want asynchronous, +deferred/Twisted using clients. +""" +__docformat__ = "restructuredtext en" +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + \ No newline at end of file diff --git a/IPython/kernel/asyncclient.py b/IPython/kernel/asyncclient.py new file mode 100644 index 0000000..a1542e2 --- /dev/null +++ b/IPython/kernel/asyncclient.py @@ -0,0 +1,41 @@ +# encoding: utf-8 + +"""Asynchronous clients for the IPython controller. + +This module has clients for using the various interfaces of the controller +in a fully asynchronous manner. This means that you will need to run the +Twisted reactor yourself and that all methods of the client classes return +deferreds to the result. + +The main methods are are `get_*_client` and `get_client`. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython.kernel import codeutil +from IPython.kernel.clientconnector import ClientConnector + +# Other things that the user will need +from IPython.kernel.task import Task +from IPython.kernel.error import CompositeError + +#------------------------------------------------------------------------------- +# Code +#------------------------------------------------------------------------------- + +_client_tub = ClientConnector() +get_multiengine_client = _client_tub.get_multiengine_client +get_task_client = _client_tub.get_task_client +get_client = _client_tub.get_client + diff --git a/IPython/kernel/client.py b/IPython/kernel/client.py new file mode 100644 index 0000000..85d677b --- /dev/null +++ b/IPython/kernel/client.py @@ -0,0 +1,96 @@ +# encoding: utf-8 + +"""This module contains blocking clients for the controller interfaces. + +Unlike the clients in `asyncclient.py`, the clients in this module are fully +blocking. This means that methods on the clients return the actual results +rather than a deferred to the result. Also, we manage the Twisted reactor +for you. This is done by running the reactor in a thread. + +The main classes in this module are: + + * MultiEngineClient + * TaskClient + * Task + * CompositeError +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import sys + +# from IPython.tools import growl +# growl.start("IPython1 Client") + + +from twisted.internet import reactor +from IPython.kernel.clientconnector import ClientConnector +from IPython.kernel.twistedutil import ReactorInThread +from IPython.kernel.twistedutil import blockingCallFromThread + +# These enable various things +from IPython.kernel import codeutil +import IPython.kernel.magic + +# Other things that the user will need +from IPython.kernel.task import Task +from IPython.kernel.error import CompositeError + +#------------------------------------------------------------------------------- +# Code +#------------------------------------------------------------------------------- + +_client_tub = ClientConnector() + + +def get_multiengine_client(furl_or_file=''): + """Get the blocking MultiEngine client. + + :Parameters: + furl_or_file : str + A furl or a filename containing a furl. If empty, the + default furl_file will be used + + :Returns: + The connected MultiEngineClient instance + """ + client = blockingCallFromThread(_client_tub.get_multiengine_client, + furl_or_file) + return client.adapt_to_blocking_client() + +def get_task_client(furl_or_file=''): + """Get the blocking Task client. + + :Parameters: + furl_or_file : str + A furl or a filename containing a furl. 
If empty, the + default furl_file will be used + + :Returns: + The connected TaskClient instance + """ + client = blockingCallFromThread(_client_tub.get_task_client, + furl_or_file) + return client.adapt_to_blocking_client() + + +MultiEngineClient = get_multiengine_client +TaskClient = get_task_client + + + +# Now we start the reactor in a thread +rit = ReactorInThread() +rit.setDaemon(True) +rit.start() \ No newline at end of file diff --git a/IPython/kernel/clientconnector.py b/IPython/kernel/clientconnector.py new file mode 100644 index 0000000..595a353 --- /dev/null +++ b/IPython/kernel/clientconnector.py @@ -0,0 +1,150 @@ +# encoding: utf-8 + +"""A class for handling client connections to the controller.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from twisted.internet import defer + +from IPython.kernel.fcutil import Tub, UnauthenticatedTub + +from IPython.kernel.config import config_manager as kernel_config_manager +from IPython.config.cutils import import_item +from IPython.kernel.fcutil import find_furl + +co = kernel_config_manager.get_config_obj() +client_co = co['client'] + +#------------------------------------------------------------------------------- +# The ClientConnector class +#------------------------------------------------------------------------------- + +class ClientConnector(object): + """ + This class gets remote references from furls and returns the wrapped clients. 
+ + This class is also used in `client.py` and `asyncclient.py` to create + a single per client-process Tub. + """ + + def __init__(self): + self._remote_refs = {} + self.tub = Tub() + self.tub.startService() + + def get_reference(self, furl_or_file): + """ + Get a remote reference using a furl or a file containing a furl. + + Remote references are cached locally so once a remote reference + has been retrieved for a given furl, the cached version is + returned. + + :Parameters: + furl_or_file : str + A furl or a filename containing a furl + + :Returns: + A deferred to a remote reference + """ + furl = find_furl(furl_or_file) + if furl in self._remote_refs: + d = defer.succeed(self._remote_refs[furl]) + else: + d = self.tub.getReference(furl) + d.addCallback(self.save_ref, furl) + return d + + def save_ref(self, ref, furl): + """ + Cache a remote reference by its furl. + """ + self._remote_refs[furl] = ref + return ref + + def get_task_client(self, furl_or_file=''): + """ + Get the task controller client. + + This method is a simple wrapper around `get_client` that allow + `furl_or_file` to be empty, in which case, the furls is taken + from the default furl file given in the configuration. + + :Parameters: + furl_or_file : str + A furl or a filename containing a furl. If empty, the + default furl_file will be used + + :Returns: + A deferred to the actual client class + """ + task_co = client_co['client_interfaces']['task'] + if furl_or_file: + ff = furl_or_file + else: + ff = task_co['furl_file'] + return self.get_client(ff) + + def get_multiengine_client(self, furl_or_file=''): + """ + Get the multiengine controller client. + + This method is a simple wrapper around `get_client` that allow + `furl_or_file` to be empty, in which case, the furls is taken + from the default furl file given in the configuration. + + :Parameters: + furl_or_file : str + A furl or a filename containing a furl. 
If empty, the + default furl_file will be used + + :Returns: + A deferred to the actual client class + """ + task_co = client_co['client_interfaces']['multiengine'] + if furl_or_file: + ff = furl_or_file + else: + ff = task_co['furl_file'] + return self.get_client(ff) + + def get_client(self, furl_or_file): + """ + Get a remote reference and wrap it in a client by furl. + + This method first gets a remote reference and then calls its + `get_client_name` method to find the apprpriate client class + that should be used to wrap the remote reference. + + :Parameters: + furl_or_file : str + A furl or a filename containing a furl + + :Returns: + A deferred to the actual client class + """ + furl = find_furl(furl_or_file) + d = self.get_reference(furl) + def wrap_remote_reference(rr): + d = rr.callRemote('get_client_name') + d.addCallback(lambda name: import_item(name)) + def adapt(client_interface): + client = client_interface(rr) + client.tub = self.tub + return client + d.addCallback(adapt) + + return d + d.addCallback(wrap_remote_reference) + return d diff --git a/IPython/kernel/clientinterfaces.py b/IPython/kernel/clientinterfaces.py new file mode 100644 index 0000000..248e511 --- /dev/null +++ b/IPython/kernel/clientinterfaces.py @@ -0,0 +1,32 @@ +# encoding: utf-8 + +"""General client interfaces.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from zope.interface import Interface, implements + +class IFCClientInterfaceProvider(Interface): + + def remote_get_client_name(): + """Return a string giving the class which implements a client-side interface. + + The client side of any foolscap connection initially gets a remote reference. + Some class is needed to adapt that reference to an interface. This... + """ + +class IBlockingClientAdaptor(Interface): + + def adapt_to_blocking_client(): + """""" \ No newline at end of file diff --git a/IPython/kernel/codeutil.py b/IPython/kernel/codeutil.py new file mode 100644 index 0000000..31e0361 --- /dev/null +++ b/IPython/kernel/codeutil.py @@ -0,0 +1,39 @@ +# encoding: utf-8 + +"""Utilities to enable code objects to be pickled. + +Any process that import this module will be able to pickle code objects. This +includes the func_code attribute of any function. Once unpickled, new +functions can be built using new.function(code, globals()). Eventually +we need to automate all of this so that functions themselves can be pickled. + +Reference: A. Tremols, P Cogolo, "Python Cookbook," p 302-305 +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import new, types, copy_reg + +def code_ctor(*args): + return new.code(*args) + +def reduce_code(co): + if co.co_freevars or co.co_cellvars: + raise ValueError("Sorry, cannot pickle code objects with closures") + return code_ctor, (co.co_argcount, co.co_nlocals, co.co_stacksize, + co.co_flags, co.co_code, co.co_consts, co.co_names, + co.co_varnames, co.co_filename, co.co_name, co.co_firstlineno, + co.co_lnotab) + +copy_reg.pickle(types.CodeType, reduce_code) \ No newline at end of file diff --git a/IPython/kernel/config/__init__.py b/IPython/kernel/config/__init__.py new file mode 100644 index 0000000..efc8d0a --- /dev/null +++ b/IPython/kernel/config/__init__.py @@ -0,0 +1,125 @@ +# encoding: utf-8 + +"""Default kernel configuration.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython.external.configobj import ConfigObj +from IPython.config.api import ConfigObjManager +from IPython.config.cutils import get_ipython_dir + +default_kernel_config = ConfigObj() + +try: + ipython_dir = get_ipython_dir() + '/' +except: + # This will defaults to the cwd + ipython_dir = '' + +#------------------------------------------------------------------------------- +# Engine Configuration +#------------------------------------------------------------------------------- + +engine_config = dict( + logfile = '', # Empty means log to stdout + furl_file = ipython_dir + 'ipcontroller-engine.furl' +) + +#------------------------------------------------------------------------------- +# MPI Configuration +#------------------------------------------------------------------------------- + +mpi_config = dict( + mpi4py = """from mpi4py import MPI as mpi +mpi.size = mpi.COMM_WORLD.Get_size() +mpi.rank = mpi.COMM_WORLD.Get_rank() +""", + pytrilinos = """from PyTrilinos import Epetra +class SimpleStruct: + pass +mpi = SimpleStruct() +mpi.rank = 0 +mpi.size = 0 +""", + default = '' +) + +#------------------------------------------------------------------------------- +# Controller Configuration +#------------------------------------------------------------------------------- + +controller_config = dict( + + logfile = '', # Empty means log to stdout + import_statement = '', + + engine_tub = dict( + ip = '', # Empty string means all interfaces + port = 0, # 0 means pick a port for me + location = '', # Empty string means try to set automatically + secure = True, + cert_file = ipython_dir + 'ipcontroller-engine.pem', + ), + engine_fc_interface = 'IPython.kernel.enginefc.IFCControllerBase', + engine_furl_file = 
ipython_dir + 'ipcontroller-engine.furl', + + controller_interfaces = dict( + # multiengine = dict( + # controller_interface = 'IPython.kernel.multiengine.IMultiEngine', + # fc_interface = 'IPython.kernel.multienginefc.IFCMultiEngine', + # furl_file = 'ipcontroller-mec.furl' + # ), + task = dict( + controller_interface = 'IPython.kernel.task.ITaskController', + fc_interface = 'IPython.kernel.taskfc.IFCTaskController', + furl_file = ipython_dir + 'ipcontroller-tc.furl' + ), + multiengine = dict( + controller_interface = 'IPython.kernel.multiengine.IMultiEngine', + fc_interface = 'IPython.kernel.multienginefc.IFCSynchronousMultiEngine', + furl_file = ipython_dir + 'ipcontroller-mec.furl' + ) + ), + + client_tub = dict( + ip = '', # Empty string means all interfaces + port = 0, # 0 means pick a port for me + location = '', # Empty string means try to set automatically + secure = True, + cert_file = ipython_dir + 'ipcontroller-client.pem' + ) +) + +#------------------------------------------------------------------------------- +# Client Configuration +#------------------------------------------------------------------------------- + +client_config = dict( + client_interfaces = dict( + task = dict( + furl_file = ipython_dir + 'ipcontroller-tc.furl' + ), + multiengine = dict( + furl_file = ipython_dir + 'ipcontroller-mec.furl' + ) + ) +) + +default_kernel_config['engine'] = engine_config +default_kernel_config['mpi'] = mpi_config +default_kernel_config['controller'] = controller_config +default_kernel_config['client'] = client_config + + +config_manager = ConfigObjManager(default_kernel_config, 'IPython.kernel.ini') \ No newline at end of file diff --git a/IPython/kernel/contexts.py b/IPython/kernel/contexts.py new file mode 100644 index 0000000..040102b --- /dev/null +++ b/IPython/kernel/contexts.py @@ -0,0 +1,178 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_contexts -*- +"""Context managers for IPython. 
+ +Python 2.5 introduced the `with` statement, which is based on the context +manager protocol. This module offers a few context managers for common cases, +which can also be useful as templates for writing new, application-specific +managers. +""" + +from __future__ import with_statement + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import linecache +import sys + +from twisted.internet.error import ConnectionRefusedError + +from IPython.ultraTB import _fixed_getinnerframes, findsource +from IPython import ipapi + +from IPython.kernel import error + +#--------------------------------------------------------------------------- +# Utility functions needed by all context managers. +#--------------------------------------------------------------------------- + +def remote(): + """Raises a special exception meant to be caught by context managers. + """ + m = 'Special exception to stop local execution of parallel code.' + raise error.StopLocalExecution(m) + + +def strip_whitespace(source,require_remote=True): + """strip leading whitespace from input source. + + :Parameters: + + """ + remote_mark = 'remote()' + # Expand tabs to avoid any confusion. + wsource = [l.expandtabs(4) for l in source] + # Detect the indentation level + done = False + for line in wsource: + if line.isspace(): + continue + for col,char in enumerate(line): + if char != ' ': + done = True + break + if done: + break + # Now we know how much leading space there is in the code. 
Next, we + # extract up to the first line that has less indentation. + # WARNINGS: we skip comments that may be misindented, but we do NOT yet + # detect triple quoted strings that may have flush left text. + for lno,line in enumerate(wsource): + lead = line[:col] + if lead.isspace(): + continue + else: + if not lead.lstrip().startswith('#'): + break + # The real 'with' source is up to lno + src_lines = [l[col:] for l in wsource[:lno+1]] + + # Finally, check that the source's first non-comment line begins with the + # special call 'remote()' + if require_remote: + for nline,line in enumerate(src_lines): + if line.isspace() or line.startswith('#'): + continue + if line.startswith(remote_mark): + break + else: + raise ValueError('%s call missing at the start of code' % + remote_mark) + out_lines = src_lines[nline+1:] + else: + # If the user specified that the remote() call wasn't mandatory + out_lines = src_lines + + # src = ''.join(out_lines) # dbg + #print 'SRC:\n<<<<<<<>>>>>>>\n%s<<<<<>>>>>>' % src # dbg + return ''.join(out_lines) + +class RemoteContextBase(object): + def __init__(self): + self.ip = ipapi.get() + + def _findsource_file(self,f): + linecache.checkcache() + s = findsource(f.f_code) + lnum = f.f_lineno + wsource = s[0][f.f_lineno:] + return strip_whitespace(wsource) + + def _findsource_ipython(self,f): + from IPython import ipapi + self.ip = ipapi.get() + buf = self.ip.IP.input_hist_raw[-1].splitlines()[1:] + wsource = [l+'\n' for l in buf ] + + return strip_whitespace(wsource) + + def findsource(self,frame): + local_ns = frame.f_locals + global_ns = frame.f_globals + if frame.f_code.co_filename == '': + src = self._findsource_ipython(frame) + else: + src = self._findsource_file(frame) + return src + + def __enter__(self): + raise NotImplementedError + + def __exit__ (self, etype, value, tb): + if issubclass(etype,error.StopLocalExecution): + return True + +class RemoteMultiEngine(RemoteContextBase): + def __init__(self,mec): + self.mec = mec + 
RemoteContextBase.__init__(self) + + def __enter__(self): + src = self.findsource(sys._getframe(1)) + return self.mec.execute(src) + + +# XXX - Temporary hackish testing, we'll move this into proper tests right +# away + +if __name__ == '__main__': + + # XXX - for now, we need a running cluster to be started separately. The + # daemon work is almost finished, and will make much of this unnecessary. + from IPython.kernel import client + mec = client.MultiEngineClient(('127.0.0.1',10105)) + + try: + mec.get_ids() + except ConnectionRefusedError: + import os, time + os.system('ipcluster -n 2 &') + time.sleep(2) + mec = client.MultiEngineClient(('127.0.0.1',10105)) + + mec.block = False + + import itertools + c = itertools.count() + + parallel = RemoteMultiEngine(mec) + + with parallel as pr: + # A comment + remote() # this means the code below only runs remotely + print 'Hello remote world' + x = 3.14 + # Comments are OK + # Even misindented. + y = x+1 diff --git a/IPython/kernel/controllerservice.py b/IPython/kernel/controllerservice.py new file mode 100644 index 0000000..c341da6 --- /dev/null +++ b/IPython/kernel/controllerservice.py @@ -0,0 +1,376 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_controllerservice -*- + +"""A Twisted Service for the IPython Controller. + +The IPython Controller: + +* Listens for Engines to connect and then manages access to those engines. +* Listens for clients and passes commands from client to the Engines. +* Exposes an asynchronous interfaces to the Engines which themselves can block. +* Acts as a gateway to the Engines. + +The design of the controller is somewhat abstract to allow flexibility in how +the controller is presented to clients. This idea is that there is a basic +ControllerService class that allows engines to connect to it. But, this +basic class has no client interfaces. To expose client interfaces developers +provide an adapter that makes the ControllerService look like something. 
For +example, one client interface might support task farming and another might +support interactive usage. The important thing is that by using interfaces +and adapters, a single controller can be accessed from multiple interfaces. +Furthermore, by adapting various client interfaces to various network +protocols, each client interface can be exposed to multiple network protocols. +See multiengine.py for an example of how to adapt the ControllerService +to a client interface. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os, sys + +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.python import log, components +from zope.interface import Interface, implements, Attribute +import zope.interface as zi + +from IPython.kernel.engineservice import \ + IEngineCore, \ + IEngineSerialized, \ + IEngineQueued + +from IPython.config import cutils +from IPython.kernel import codeutil + +#------------------------------------------------------------------------------- +# Interfaces for the Controller +#------------------------------------------------------------------------------- + +class IControllerCore(Interface): + """Basic methods any controller must have. + + This is basically the aspect of the controller relevant to the + engines and does not assume anything about how the engines will + be presented to a client. 
+ """ + + engines = Attribute("A dict of engine ids and engine instances.") + + def register_engine(remoteEngine, id=None, ip=None, port=None, + pid=None): + """Register new remote engine. + + The controller can use the ip, port, pid of the engine to do useful things + like kill the engines. + + :Parameters: + remoteEngine + An implementer of IEngineCore, IEngineSerialized and IEngineQueued. + id : int + Requested id. + ip : str + IP address the engine is running on. + port : int + Port the engine is on. + pid : int + pid of the running engine. + + :Returns: A dict of {'id':id} and possibly other key, value pairs. + """ + + def unregister_engine(id): + """Handle a disconnecting engine. + + :Parameters: + id + The integer engine id of the engine to unregister. + """ + + def on_register_engine_do(f, includeID, *args, **kwargs): + """Call ``f(*args, **kwargs)`` when an engine is registered. + + :Parameters: + includeID : int + If True the first argument to f will be the id of the engine. + """ + + def on_unregister_engine_do(f, includeID, *args, **kwargs): + """Call ``f(*args, **kwargs)`` when an engine is unregistered. + + :Parameters: + includeID : int + If True the first argument to f will be the id of the engine. + """ + + def on_register_engine_do_not(f): + """Stop calling f on engine registration""" + + def on_unregister_engine_do_not(f): + """Stop calling f on engine unregistration""" + + def on_n_engines_registered_do(n, f, *arg, **kwargs): + """Call f(*args, **kwargs) the first time the nth engine registers.""" + +class IControllerBase(IControllerCore): + """The basic controller interface.""" + pass + + +#------------------------------------------------------------------------------- +# Implementation of the ControllerService +#------------------------------------------------------------------------------- + +class ControllerService(object, service.Service): + """A basic Controller represented as a Twisted Service. 
+ + This class doesn't implement any client notification mechanism. That + is up to adapted subclasses. + """ + + # I also pick up the IService interface by inheritance from service.Service + implements(IControllerBase) + name = 'ControllerService' + + def __init__(self, maxEngines=511, saveIDs=False): + self.saveIDs = saveIDs + self.engines = {} + self.availableIDs = range(maxEngines,-1,-1) # [511,...,0] + self._onRegister = [] + self._onUnregister = [] + self._onNRegistered = [] + + #--------------------------------------------------------------------------- + # Methods used to save the engine info to a log file + #--------------------------------------------------------------------------- + + def _buildEngineInfoString(self, id, ip, port, pid): + if id is None: + id = -99 + if ip is None: + ip = "-99" + if port is None: + port = -99 + if pid is None: + pid = -99 + return "Engine Info: %d %s %d %d" % (id, ip , port, pid) + + def _logEngineInfo(self, id, ip, port, pid): + log.msg(self._buildEngineInfoString(id,ip,port,pid)) + + def _getEngineInfoLogFile(self): + # Store all logs inside the ipython directory + ipdir = cutils.get_ipython_dir() + pjoin = os.path.join + logdir_base = pjoin(ipdir,'log') + if not os.path.isdir(logdir_base): + os.makedirs(logdir_base) + logfile = os.path.join(logdir_base,'ipcontroller-%s-engine-info.log' % os.getpid()) + return logfile + + def _logEngineInfoToFile(self, id, ip, port, pid): + """Log info about an engine to a log file. + + When an engine registers with a ControllerService, the ControllerService + saves information about the engine to a log file. That information + can be useful for various purposes, such as killing hung engines, etc. + + This method takes the assigned id, ip/port and pid of the engine + and saves it to a file of the form: + + ~/.ipython/log/ipcontroller-###-engine-info.log + + where ### is the pid of the controller. 
+ + Each line of this file has the form: + + Engine Info: ip ip port pid + + If any of the entries are not known, they are replaced by -99. + """ + + fname = self._getEngineInfoLogFile() + f = open(fname, 'a') + s = self._buildEngineInfoString(id,ip,port,pid) + f.write(s + '\n') + f.close() + + #--------------------------------------------------------------------------- + # IControllerCore methods + #--------------------------------------------------------------------------- + + def register_engine(self, remoteEngine, id=None, + ip=None, port=None, pid=None): + """Register new engine connection""" + + # What happens if these assertions fail? + assert IEngineCore.providedBy(remoteEngine), \ + "engine passed to register_engine doesn't provide IEngineCore" + assert IEngineSerialized.providedBy(remoteEngine), \ + "engine passed to register_engine doesn't provide IEngineSerialized" + assert IEngineQueued.providedBy(remoteEngine), \ + "engine passed to register_engine doesn't provide IEngineQueued" + assert isinstance(id, int) or id is None, \ + "id to register_engine must be an integer or None" + assert isinstance(ip, str) or ip is None, \ + "ip to register_engine must be a string or None" + assert isinstance(port, int) or port is None, \ + "port to register_engine must be an integer or None" + assert isinstance(pid, int) or pid is None, \ + "pid to register_engine must be an integer or None" + + desiredID = id + if desiredID in self.engines.keys(): + desiredID = None + + if desiredID in self.availableIDs: + getID = desiredID + self.availableIDs.remove(desiredID) + else: + getID = self.availableIDs.pop() + remoteEngine.id = getID + remoteEngine.service = self + self.engines[getID] = remoteEngine + + # Log the Engine Information for monitoring purposes + self._logEngineInfoToFile(getID, ip, port, pid) + + msg = "registered engine with id: %i" %getID + log.msg(msg) + + for i in range(len(self._onRegister)): + (f,args,kwargs,ifid) = self._onRegister[i] + try: + if ifid: + 
f(getID, *args, **kwargs) + else: + f(*args, **kwargs) + except: + self._onRegister.pop(i) + + # Call functions when the nth engine is registered and them remove them + for i, (n, f, args, kwargs) in enumerate(self._onNRegistered): + if len(self.engines.keys()) == n: + try: + try: + f(*args, **kwargs) + except: + log.msg("Function %r failed when the %ith engine registered" % (f, n)) + finally: + self._onNRegistered.pop(i) + + return {'id':getID} + + def unregister_engine(self, id): + """Unregister engine by id.""" + + assert isinstance(id, int) or id is None, \ + "id to unregister_engine must be an integer or None" + + msg = "unregistered engine with id: %i" %id + log.msg(msg) + try: + del self.engines[id] + except KeyError: + log.msg("engine with id %i was not registered" % id) + else: + if not self.saveIDs: + self.availableIDs.append(id) + # Sort to assign lower ids first + self.availableIDs.sort(reverse=True) + else: + log.msg("preserving id %i" %id) + + for i in range(len(self._onUnregister)): + (f,args,kwargs,ifid) = self._onUnregister[i] + try: + if ifid: + f(id, *args, **kwargs) + else: + f(*args, **kwargs) + except: + self._onUnregister.pop(i) + + def on_register_engine_do(self, f, includeID, *args, **kwargs): + assert callable(f), "f must be callable" + self._onRegister.append((f,args,kwargs,includeID)) + + def on_unregister_engine_do(self, f, includeID, *args, **kwargs): + assert callable(f), "f must be callable" + self._onUnregister.append((f,args,kwargs,includeID)) + + def on_register_engine_do_not(self, f): + for i in range(len(self._onRegister)): + g = self._onRegister[i][0] + if f == g: + self._onRegister.pop(i) + return + + def on_unregister_engine_do_not(self, f): + for i in range(len(self._onUnregister)): + g = self._onUnregister[i][0] + if f == g: + self._onUnregister.pop(i) + return + + def on_n_engines_registered_do(self, n, f, *args, **kwargs): + if len(self.engines.keys()) >= n: + f(*args, **kwargs) + else: + 
self._onNRegistered.append((n,f,args,kwargs)) + + +#------------------------------------------------------------------------------- +# Base class for adapting controller to different client APIs +#------------------------------------------------------------------------------- + +class ControllerAdapterBase(object): + """All Controller adapters should inherit from this class. + + This class provides a wrapped version of the IControllerBase interface that + can be used to easily create new custom controllers. Subclasses of this + will provide a full implementation of IControllerBase. + + This class doesn't implement any client notification mechanism. That + is up to subclasses. + """ + + implements(IControllerBase) + + def __init__(self, controller): + self.controller = controller + # Needed for IControllerCore + self.engines = self.controller.engines + + def register_engine(self, remoteEngine, id=None, + ip=None, port=None, pid=None): + return self.controller.register_engine(remoteEngine, + id, ip, port, pid) + + def unregister_engine(self, id): + return self.controller.unregister_engine(id) + + def on_register_engine_do(self, f, includeID, *args, **kwargs): + return self.controller.on_register_engine_do(f, includeID, *args, **kwargs) + + def on_unregister_engine_do(self, f, includeID, *args, **kwargs): + return self.controller.on_unregister_engine_do(f, includeID, *args, **kwargs) + + def on_register_engine_do_not(self, f): + return self.controller.on_register_engine_do_not(f) + + def on_unregister_engine_do_not(self, f): + return self.controller.on_unregister_engine_do_not(f) + + def on_n_engines_registered_do(self, n, f, *args, **kwargs): + return self.controller.on_n_engines_registered_do(n, f, *args, **kwargs) diff --git a/IPython/kernel/core/__init__.py b/IPython/kernel/core/__init__.py new file mode 100644 index 0000000..0aad75d --- /dev/null +++ b/IPython/kernel/core/__init__.py @@ -0,0 +1,16 @@ +# encoding: utf-8 + +"""The IPython Core.""" + +__docformat__ 
= "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- \ No newline at end of file diff --git a/IPython/kernel/core/config/__init__.py b/IPython/kernel/core/config/__init__.py new file mode 100644 index 0000000..6f60906 --- /dev/null +++ b/IPython/kernel/core/config/__init__.py @@ -0,0 +1,25 @@ +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython.external.configobj import ConfigObj +from IPython.config.api import ConfigObjManager + +default_core_config = ConfigObj() +default_core_config['shell'] = dict( + shell_class = 'IPython.kernel.core.interpreter.Interpreter', + import_statement = '' +) + +config_manager = ConfigObjManager(default_core_config, 'IPython.kernel.core.ini') \ No newline at end of file diff --git a/IPython/kernel/core/display_formatter.py b/IPython/kernel/core/display_formatter.py new file mode 100644 index 0000000..55a04a9 --- /dev/null +++ b/IPython/kernel/core/display_formatter.py @@ -0,0 +1,70 @@ +# encoding: utf-8 + +"""Objects for replacing sys.displayhook().""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +class IDisplayFormatter(object): + """ Objects conforming to this interface will be responsible for formatting + representations of objects that pass through sys.displayhook() during an + interactive interpreter session. + """ + + # The kind of formatter. + kind = 'display' + + # The unique identifier for this formatter. + identifier = None + + + def __call__(self, obj): + """ Return a formatted representation of an object. + + Return None if one cannot return a representation in this format. 
+ """ + + raise NotImplementedError + + +class ReprDisplayFormatter(IDisplayFormatter): + """ Return the repr() string representation of an object. + """ + + # The unique identifier for this formatter. + identifier = 'repr' + + + def __call__(self, obj): + """ Return a formatted representation of an object. + """ + + return repr(obj) + + +class PPrintDisplayFormatter(IDisplayFormatter): + """ Return a pretty-printed string representation of an object. + """ + + # The unique identifier for this formatter. + identifier = 'pprint' + + + def __call__(self, obj): + """ Return a formatted representation of an object. + """ + + import pprint + return pprint.pformat(obj) + + diff --git a/IPython/kernel/core/display_trap.py b/IPython/kernel/core/display_trap.py new file mode 100644 index 0000000..2d276c3 --- /dev/null +++ b/IPython/kernel/core/display_trap.py @@ -0,0 +1,100 @@ +# encoding: utf-8 + +"""Manager for replacing sys.displayhook().""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +# Standard library imports. +import sys + + + +class DisplayTrap(object): + """ Object to trap and format objects passing through sys.displayhook(). + + This trap maintains two lists of callables: formatters and callbacks. The + formatters take the *last* object that has gone through since the trap was + set and returns a string representation. Callbacks are executed on *every* + object that passes through the displayhook and does not return anything. 
+ """ + + def __init__(self, formatters=None, callbacks=None): + # A list of formatters to apply. Each should be an instance conforming + # to the IDisplayFormatter interface. + if formatters is None: + formatters = [] + self.formatters = formatters + + # A list of callables, each of which should be executed *every* time an + # object passes through sys.displayhook(). + if callbacks is None: + callbacks = [] + self.callbacks = callbacks + + # The last object to pass through the displayhook. + self.obj = None + + # The previous hook before we replace it. + self.old_hook = None + + def hook(self, obj): + """ This method actually implements the hook. + """ + + # Run through the list of callbacks and trigger all of them. + for callback in self.callbacks: + callback(obj) + + # Store the object for formatting. + self.obj = obj + + def set(self): + """ Set the hook. + """ + + if sys.displayhook is not self.hook: + self.old_hook = sys.displayhook + sys.displayhook = self.hook + + def unset(self): + """ Unset the hook. + """ + + sys.displayhook = self.old_hook + + def clear(self): + """ Reset the stored object. + """ + + self.obj = None + + def add_to_message(self, message): + """ Add the formatted display of the objects to the message dictionary + being returned from the interpreter to its listeners. + """ + + # If there was no displayed object (or simply None), then don't add + # anything. + if self.obj is None: + return + + # Go through the list of formatters and let them add their formatting. 
+ display = {} + for formatter in self.formatters: + representation = formatter(self.obj) + if representation is not None: + display[formatter.identifier] = representation + + message['display'] = display + diff --git a/IPython/kernel/core/error.py b/IPython/kernel/core/error.py new file mode 100644 index 0000000..50c2591 --- /dev/null +++ b/IPython/kernel/core/error.py @@ -0,0 +1,41 @@ +# encoding: utf-8 + +""" +error.py + +We declare here a class hierarchy for all exceptions produced by IPython, in +cases where we don't just raise one from the standard library. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + + +class IPythonError(Exception): + """Base exception that all of our exceptions inherit from. + + This can be raised by code that doesn't have any more specific + information.""" + + pass + +# Exceptions associated with the controller objects +class ControllerError(IPythonError): pass + +class ControllerCreationError(ControllerError): pass + + +# Exceptions associated with the Engines +class EngineError(IPythonError): pass + +class EngineCreationError(EngineError): pass diff --git a/IPython/kernel/core/history.py b/IPython/kernel/core/history.py new file mode 100644 index 0000000..1baa029 --- /dev/null +++ b/IPython/kernel/core/history.py @@ -0,0 +1,137 @@ +# encoding: utf-8 + +""" Manage the input and output history of the interpreter and the +frontend. 
+ +There are 2 different history objects, one that lives in the interpreter, +and one that lives in the frontend. They are synced with a diff at each +execution of a command, as the interpreter history is a real stack, its +existing entries are not mutable. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from copy import copy + +# Local imports. +from util import InputList + + +############################################################################## +class History(object): + """ An object managing the input and output history. + """ + + def __init__(self, input_cache=None, output_cache=None): + + # Stuff that IPython adds to the namespace. + self.namespace_additions = dict( + _ = None, + __ = None, + ___ = None, + ) + + # A list to store input commands. + if input_cache is None: + input_cache =InputList([]) + self.input_cache = input_cache + + # A dictionary to store trapped output. + if output_cache is None: + output_cache = {} + self.output_cache = output_cache + + def get_history_item(self, index): + """ Returns the history string at index, where index is the + distance from the end (positive). + """ + if index>0 and index0 and index<(len(self.input_cache)-1): + return self.input_cache[-index] + + def get_input_cache(self): + return copy(self.input_cache) + + def get_input_after(self, index): + """ Returns the list of the commands entered after index. 
+ """ + # We need to call directly list.__getslice__, because this object + # is not a real list. + return list.__getslice__(self.input_cache, index, + len(self.input_cache)) + + +############################################################################## +class FrontEndHistory(History): + """ An object managing the input and output history at the frontend. + It is used as a local cache to reduce network latency problems + and multiple users editing the same thing. + """ + + def add_items(self, item_list): + """ Adds the given command list to the stack of executed + commands. + """ + self.input_cache.extend(item_list) diff --git a/IPython/kernel/core/interpreter.py b/IPython/kernel/core/interpreter.py new file mode 100644 index 0000000..6e1e8cd --- /dev/null +++ b/IPython/kernel/core/interpreter.py @@ -0,0 +1,749 @@ +# encoding: utf-8 + +"""Central interpreter object for an IPython engine. + +The interpreter is the object whose job is to process lines of user input and +actually execute them in the user's namespace. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +# Standard library imports. +from compiler.ast import Discard +from types import FunctionType + +import __builtin__ +import codeop +import compiler +import pprint +import sys +import traceback + +# Local imports. 
+from IPython.kernel.core import ultraTB +from IPython.kernel.core.display_trap import DisplayTrap +from IPython.kernel.core.macro import Macro +from IPython.kernel.core.prompts import CachedOutput +from IPython.kernel.core.traceback_trap import TracebackTrap +from IPython.kernel.core.util import Bunch, system_shell +from IPython.external.Itpl import ItplNS + +# Global constants +COMPILER_ERROR = 'error' +INCOMPLETE_INPUT = 'incomplete' +COMPLETE_INPUT = 'complete' + +############################################################################## +# TEMPORARY!!! fake configuration, while we decide whether to use tconfig or +# not + +rc = Bunch() +rc.cache_size = 100 +rc.pprint = True +rc.separate_in = '\n' +rc.separate_out = '\n' +rc.separate_out2 = '' +rc.prompt_in1 = r'In [\#]: ' +rc.prompt_in2 = r' .\\D.: ' +rc.prompt_out = '' +rc.prompts_pad_left = False + +############################################################################## + +# Top-level utilities +def default_display_formatters(): + """ Return a list of default display formatters. + """ + + from display_formatter import PPrintDisplayFormatter, ReprDisplayFormatter + return [PPrintDisplayFormatter(), ReprDisplayFormatter()] + +def default_traceback_formatters(): + """ Return a list of default traceback formatters. + """ + + from traceback_formatter import PlainTracebackFormatter + return [PlainTracebackFormatter()] + +# Top-level classes +class NotDefined(object): pass + +class Interpreter(object): + """ An interpreter object. + + fixme: needs to negotiate available formatters with frontends. + + Important: the interpeter should be built so that it exposes a method + for each attribute/method of its sub-object. This way it can be + replaced by a network adapter. 
+ """ + + def __init__(self, user_ns=None, global_ns=None,translator=None, + magic=None, display_formatters=None, + traceback_formatters=None, output_trap=None, history=None, + message_cache=None, filename='', config=None): + + # The local/global namespaces for code execution + local_ns = user_ns # compatibility name + if local_ns is None: + local_ns = {} + self.user_ns = local_ns + # The local namespace + if global_ns is None: + global_ns = {} + self.user_global_ns = global_ns + + # An object that will translate commands into executable Python. + # The current translator does not work properly so for now we are going + # without! + # if translator is None: + # from IPython.kernel.core.translator import IPythonTranslator + # translator = IPythonTranslator() + self.translator = translator + + # An object that maintains magic commands. + if magic is None: + from IPython.kernel.core.magic import Magic + magic = Magic(self) + self.magic = magic + + # A list of formatters for the displayhook. + if display_formatters is None: + display_formatters = default_display_formatters() + self.display_formatters = display_formatters + + # A list of formatters for tracebacks. + if traceback_formatters is None: + traceback_formatters = default_traceback_formatters() + self.traceback_formatters = traceback_formatters + + # The object trapping stdout/stderr. + if output_trap is None: + from IPython.kernel.core.output_trap import OutputTrap + output_trap = OutputTrap() + self.output_trap = output_trap + + # An object that manages the history. + if history is None: + from IPython.kernel.core.history import InterpreterHistory + history = InterpreterHistory() + self.history = history + self.get_history_item = history.get_history_item + self.get_history_input_cache = history.get_input_cache + self.get_history_input_after = history.get_input_after + + # An object that caches all of the return messages. 
+ if message_cache is None: + from IPython.kernel.core.message_cache import SimpleMessageCache + message_cache = SimpleMessageCache() + self.message_cache = message_cache + + # The "filename" of the code that is executed in this interpreter. + self.filename = filename + + # An object that contains much configuration information. + if config is None: + # fixme: Move this constant elsewhere! + config = Bunch(ESC_MAGIC='%') + self.config = config + + # Hook managers. + # fixme: make the display callbacks configurable. In the meantime, + # enable macros. + self.display_trap = DisplayTrap( + formatters=self.display_formatters, + callbacks=[self._possible_macro], + ) + self.traceback_trap = TracebackTrap( + formatters=self.traceback_formatters) + + # This is used temporarily for reformating exceptions in certain + # cases. It will go away once the ultraTB stuff is ported + # to ipython1 + self.tbHandler = ultraTB.FormattedTB(color_scheme='NoColor', + mode='Context', + tb_offset=2) + + # An object that can compile commands and remember __future__ + # statements. + self.command_compiler = codeop.CommandCompiler() + + # A replacement for the raw_input() and input() builtins. Change these + # attributes later to configure them. + self.raw_input_builtin = raw_input + self.input_builtin = input + + # The number of the current cell. 
+ self.current_cell_number = 1 + + # Initialize cache, set in/out prompts and printing system + self.outputcache = CachedOutput(self, + rc.cache_size, + rc.pprint, + input_sep = rc.separate_in, + output_sep = rc.separate_out, + output_sep2 = rc.separate_out2, + ps1 = rc.prompt_in1, + ps2 = rc.prompt_in2, + ps_out = rc.prompt_out, + pad_left = rc.prompts_pad_left) + + # Need to decide later if this is the right approach, but clients + # commonly use sys.ps1/2, so it may be best to just set them here + sys.ps1 = self.outputcache.prompt1.p_str + sys.ps2 = self.outputcache.prompt2.p_str + + # This is the message dictionary assigned temporarily when running the + # code. + self.message = None + + self.setup_namespace() + + + #### Public 'Interpreter' interface ######################################## + + def formatTraceback(self, et, ev, tb, message=''): + """Put a formatted version of the traceback into value and reraise. + + When exceptions have to be sent over the network, the traceback + needs to be put into the value of the exception in a nicely + formatted way. The method takes the type, value and tb of an + exception and puts a string representation of the tb into the + value of the exception and reraises it. + + Currently this method uses the ultraTb formatter from IPython trunk. + Eventually it should simply use the traceback formatters in core + that are loaded into self.tracback_trap.formatters. + """ + tbinfo = self.tbHandler.text(et,ev,tb) + ev._ipython_traceback_text = tbinfo + return et, ev, tb + + def execute(self, commands, raiseException=True): + """ Execute some IPython commands. + + 1. Translate them into Python. + 2. Run them. + 3. Trap stdout/stderr. + 4. Trap sys.displayhook(). + 5. Trap exceptions. + 6. Return a message object. + + Parameters + ---------- + commands : str + The raw commands that the user typed into the prompt. + + Returns + ------- + message : dict + The dictionary of responses. 
See the README.txt in this directory + for an explanation of the format. + """ + + # Create a message dictionary with all of the information we will be + # returning to the frontend and other listeners. + message = self.setup_message() + + # Massage the input and store the raw and translated commands into + # a dict. + user_input = dict(raw=commands) + if self.translator is not None: + python = self.translator(commands, message) + if python is None: + # Something went wrong with the translation. The translator + # should have added an appropriate entry to the message object. + return message + else: + python = commands + user_input['translated'] = python + message['input'] = user_input + + # Set the message object so that any magics executed in the code have + # access. + self.message = message + + # Set all of the output/exception traps. + self.set_traps() + + # Actually execute the Python code. + status = self.execute_python(python) + + # Unset all of the traps. + self.unset_traps() + + # Unset the message object. + self.message = None + + # Update the history variables in the namespace. + # E.g. In, Out, _, __, ___ + if self.history is not None: + self.history.update_history(self, python) + + # Let all of the traps contribute to the message and then clear their + # stored information. + self.output_trap.add_to_message(message) + self.output_trap.clear() + self.display_trap.add_to_message(message) + self.display_trap.clear() + self.traceback_trap.add_to_message(message) + # Pull out the type, value and tb of the current exception + # before clearing it. + einfo = self.traceback_trap.args + self.traceback_trap.clear() + + # Cache the message. + self.message_cache.add_message(self.current_cell_number, message) + + # Bump the number. + self.current_cell_number += 1 + + # This conditional lets the execute method either raise any + # exception that has occured in user code OR return the message + # dict containing the traceback and other useful info. 
+ if raiseException and einfo: + raise einfo[0],einfo[1],einfo[2] + else: + return message + + def generate_prompt(self, is_continuation): + """Calculate and return a string with the prompt to display. + + :Parameters: + is_continuation : bool + Whether the input line is continuing multiline input or not, so + that a proper continuation prompt can be computed.""" + + if is_continuation: + return str(self.outputcache.prompt2) + else: + return str(self.outputcache.prompt1) + + def execute_python(self, python): + """ Actually run the Python code in the namespace. + + :Parameters: + + python : str + Pure, exec'able Python code. Special IPython commands should have + already been translated into pure Python. + """ + + # We use a CommandCompiler instance to compile the code so as to keep + # track of __future__ imports. + try: + commands = self.split_commands(python) + except (SyntaxError, IndentationError), e: + # Save the exc_info so compilation related exceptions can be + # reraised + self.traceback_trap.args = sys.exc_info() + self.pack_exception(self.message,e) + return None + + for cmd in commands: + try: + code = self.command_compiler(cmd, self.filename, 'single') + except (SyntaxError, OverflowError, ValueError), e: + self.traceback_trap.args = sys.exc_info() + self.pack_exception(self.message,e) + # No point in continuing if one block raised + return None + else: + self.execute_block(code) + + def execute_block(self,code): + """Execute a single block of code in the user namespace. + + Return value: a flag indicating whether the code to be run completed + successfully: + + - 0: successful execution. + - 1: an error occurred. 
+ """ + + outflag = 1 # start by assuming error, success will reset it + try: + exec code in self.user_ns + outflag = 0 + except SystemExit: + self.resetbuffer() + self.traceback_trap.args = sys.exc_info() + except: + self.traceback_trap.args = sys.exc_info() + + return outflag + + def execute_macro(self, macro): + """ Execute the value of a macro. + + Parameters + ---------- + macro : Macro + """ + + python = macro.value + if self.translator is not None: + python = self.translator(python) + self.execute_python(python) + + def getCommand(self, i=None): + """Gets the ith message in the message_cache. + + This is implemented here for compatibility with the old ipython1 shell + I am not sure we need this though. I even seem to remember that we + were going to get rid of it. + """ + return self.message_cache.get_message(i) + + def reset(self): + """Reset the interpreter. + + Currently this only resets the users variables in the namespace. + In the future we might want to also reset the other stateful + things like that the Interpreter has, like In, Out, etc. + """ + self.user_ns.clear() + self.setup_namespace() + + def complete(self,line,text=None, pos=None): + """Complete the given text. + + :Parameters: + + text : str + Text fragment to be completed on. Typically this is + """ + # fixme: implement + raise NotImplementedError + + def push(self, ns): + """ Put value into the namespace with name key. 
+ + Parameters + ---------- + **kwds + """ + + self.user_ns.update(ns) + + def push_function(self, ns): + # First set the func_globals for all functions to self.user_ns + new_kwds = {} + for k, v in ns.iteritems(): + if not isinstance(v, FunctionType): + raise TypeError("function object expected") + new_kwds[k] = FunctionType(v.func_code, self.user_ns) + self.user_ns.update(new_kwds) + + def pack_exception(self,message,exc): + message['exception'] = exc.__class__ + message['exception_value'] = \ + traceback.format_exception_only(exc.__class__, exc) + + def feed_block(self, source, filename='', symbol='single'): + """Compile some source in the interpreter. + + One several things can happen: + + 1) The input is incorrect; compile_command() raised an + exception (SyntaxError or OverflowError). + + 2) The input is incomplete, and more input is required; + compile_command() returned None. Nothing happens. + + 3) The input is complete; compile_command() returned a code + object. The code is executed by calling self.runcode() (which + also handles run-time exceptions, except for SystemExit). + + The return value is: + + - True in case 2 + + - False in the other cases, unless an exception is raised, where + None is returned instead. This can be used by external callers to + know whether to continue feeding input or not. + + The return value can be used to decide whether to use sys.ps1 or + sys.ps2 to prompt the next line.""" + + self.message = self.setup_message() + + try: + code = self.command_compiler(source,filename,symbol) + except (OverflowError, SyntaxError, IndentationError, ValueError ), e: + # Case 1 + self.traceback_trap.args = sys.exc_info() + self.pack_exception(self.message,e) + return COMPILER_ERROR,False + + if code is None: + # Case 2: incomplete input. This means that the input can span + # multiple lines. But we still need to decide when to actually + # stop taking user input. Later we'll add auto-indentation support + # somehow. 
In the meantime, we'll just stop if there are two lines + # of pure whitespace at the end. + last_two = source.rsplit('\n',2)[-2:] + print 'last two:',last_two # dbg + if len(last_two)==2 and all(s.isspace() for s in last_two): + return COMPLETE_INPUT,False + else: + return INCOMPLETE_INPUT, True + else: + # Case 3 + return COMPLETE_INPUT, False + + def pull(self, keys): + """ Get an item out of the namespace by key. + + Parameters + ---------- + key : str + + Returns + ------- + value : object + + Raises + ------ + TypeError if the key is not a string. + NameError if the object doesn't exist. + """ + + if isinstance(keys, str): + result = self.user_ns.get(keys, NotDefined()) + if isinstance(result, NotDefined): + raise NameError('name %s is not defined' % keys) + elif isinstance(keys, (list, tuple)): + result = [] + for key in keys: + if not isinstance(key, str): + raise TypeError("objects must be keyed by strings.") + else: + r = self.user_ns.get(key, NotDefined()) + if isinstance(r, NotDefined): + raise NameError('name %s is not defined' % key) + else: + result.append(r) + if len(keys)==1: + result = result[0] + else: + raise TypeError("keys must be a strong or a list/tuple of strings") + return result + + def pull_function(self, keys): + return self.pull(keys) + + #### Interactive user API ################################################## + + def ipsystem(self, command): + """ Execute a command in a system shell while expanding variables in the + current namespace. + + Parameters + ---------- + command : str + """ + + # Expand $variables. + command = self.var_expand(command) + + system_shell(command, + header='IPython system call: ', + verbose=self.rc.system_verbose, + ) + + def ipmagic(self, arg_string): + """ Call a magic function by name. + + ipmagic('name -opt foo bar') is equivalent to typing at the ipython + prompt: + + In[1]: %name -opt foo bar + + To call a magic without arguments, simply use ipmagic('name'). 
+ + This provides a proper Python function to call IPython's magics in any + valid Python code you can type at the interpreter, including loops and + compound statements. It is added by IPython to the Python builtin + namespace upon initialization. + + Parameters + ---------- + arg_string : str + A string containing the name of the magic function to call and any + additional arguments to be passed to the magic. + + Returns + ------- + something : object + The return value of the actual object. + """ + + # Taken from IPython. + raise NotImplementedError('Not ported yet') + + args = arg_string.split(' ', 1) + magic_name = args[0] + magic_name = magic_name.lstrip(self.config.ESC_MAGIC) + + try: + magic_args = args[1] + except IndexError: + magic_args = '' + fn = getattr(self.magic, 'magic_'+magic_name, None) + if fn is None: + self.error("Magic function `%s` not found." % magic_name) + else: + magic_args = self.var_expand(magic_args) + return fn(magic_args) + + + #### Private 'Interpreter' interface ####################################### + + def setup_message(self): + """Return a message object. + + This method prepares and returns a message dictionary. This dict + contains the various fields that are used to transfer information about + execution, results, tracebacks, etc, to clients (either in or out of + process ones). Because of the need to work with possibly out of + process clients, this dict MUST contain strictly pickle-safe values. + """ + + return dict(number=self.current_cell_number) + + def setup_namespace(self): + """ Add things to the namespace. 
+ """ + + self.user_ns.setdefault('__name__', '__main__') + self.user_ns.setdefault('__builtins__', __builtin__) + self.user_ns['__IP'] = self + if self.raw_input_builtin is not None: + self.user_ns['raw_input'] = self.raw_input_builtin + if self.input_builtin is not None: + self.user_ns['input'] = self.input_builtin + + builtin_additions = dict( + ipmagic=self.ipmagic, + ) + __builtin__.__dict__.update(builtin_additions) + + if self.history is not None: + self.history.setup_namespace(self.user_ns) + + def set_traps(self): + """ Set all of the output, display, and traceback traps. + """ + + self.output_trap.set() + self.display_trap.set() + self.traceback_trap.set() + + def unset_traps(self): + """ Unset all of the output, display, and traceback traps. + """ + + self.output_trap.unset() + self.display_trap.unset() + self.traceback_trap.unset() + + def split_commands(self, python): + """ Split multiple lines of code into discrete commands that can be + executed singly. + + Parameters + ---------- + python : str + Pure, exec'able Python code. + + Returns + ------- + commands : list of str + Separate commands that can be exec'ed independently. + """ + + # compiler.parse treats trailing spaces after a newline as a + # SyntaxError. This is different than codeop.CommandCompiler, which + # will compile the trailing spaces just fine. We simply strip any + # trailing whitespace off. Passing a string with trailing whitespace + # to exec will fail however. There seems to be some inconsistency in + # how trailing whitespace is handled, but this seems to work. + python = python.strip() + + # The compiler module will parse the code into an abstract syntax tree. + ast = compiler.parse(python) + + # Uncomment to help debug the ast tree + # for n in ast.node: + # print n.lineno,'->',n + + # Each separate command is available by iterating over ast.node. The + # lineno attribute is the line number (1-indexed) beginning the commands + # suite. 
+ # lines ending with ";" yield a Discard Node that doesn't have a lineno + # attribute. These nodes can and should be discarded. But there are + # other situations that cause Discard nodes that shouldn't be discarded. + # We might eventually discover other cases where lineno is None and have + # to put in a more sophisticated test. + linenos = [x.lineno-1 for x in ast.node if x.lineno is not None] + + # When we finally get the slices, we will need to slice all the way to + # the end even though we don't have a line number for it. Fortunately, + # None does the job nicely. + linenos.append(None) + lines = python.splitlines() + + # Create a list of atomic commands. + cmds = [] + for i, j in zip(linenos[:-1], linenos[1:]): + cmd = lines[i:j] + if cmd: + cmds.append('\n'.join(cmd)+'\n') + + return cmds + + def error(self, text): + """ Pass an error message back to the shell. + + Preconditions + ------------- + This should only be called when self.message is set. In other words, + when code is being executed. + + Parameters + ---------- + text : str + """ + + errors = self.message.get('IPYTHON_ERROR', []) + errors.append(text) + + def var_expand(self, template): + """ Expand $variables in the current namespace using Itpl. + + Parameters + ---------- + template : str + """ + + return str(ItplNS(template, self.user_ns)) + + def _possible_macro(self, obj): + """ If the object is a macro, execute it. + """ + + if isinstance(obj, Macro): + self.execute_macro(obj) + diff --git a/IPython/kernel/core/macro.py b/IPython/kernel/core/macro.py new file mode 100644 index 0000000..56d2b44 --- /dev/null +++ b/IPython/kernel/core/macro.py @@ -0,0 +1,34 @@ +# encoding: utf-8 + +"""Support for interactive macros in IPython""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +class Macro: + """Simple class to store the value of macros as strings. + + This allows us to later exec them by checking when something is an + instance of this class.""" + + def __init__(self,data): + + # store the macro value, as a single string which can be evaluated by + # runlines() + self.value = ''.join(data).rstrip()+'\n' + + def __str__(self): + return self.value + + def __repr__(self): + return 'IPython.macro.Macro(%s)' % repr(self.value) \ No newline at end of file diff --git a/IPython/kernel/core/magic.py b/IPython/kernel/core/magic.py new file mode 100644 index 0000000..967ce73 --- /dev/null +++ b/IPython/kernel/core/magic.py @@ -0,0 +1,147 @@ +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os +import __builtin__ + +# Local imports. +from util import Bunch + + +# fixme: RTK thinks magics should be implemented as separate classes rather than +# methods on a single class. This would give us the ability to plug new magics +# in and configure them separately. + +class Magic(object): + """ An object that maintains magic functions. 
+ """ + + def __init__(self, interpreter, config=None): + # A reference to the interpreter. + self.interpreter = interpreter + + # A reference to the configuration object. + if config is None: + # fixme: we need a better place to store this information. + config = Bunch(ESC_MAGIC='%') + self.config = config + + def has_magic(self, name): + """ Return True if this object provides a given magic. + + Parameters + ---------- + name : str + """ + + return hasattr(self, 'magic_' + name) + + def object_find(self, name): + """ Find an object in the available namespaces. + + fixme: this should probably be moved elsewhere. The interpreter? + """ + + name = name.strip() + + # Namespaces to search. + # fixme: implement internal and alias namespaces. + user_ns = self.interpreter.user_ns + internal_ns = {} + builtin_ns = __builtin__.__dict__ + alias_ns = {} + + # Order the namespaces. + namespaces = [ + ('Interactive', user_ns), + ('IPython internal', internal_ns), + ('Python builtin', builtin_ns), + ('Alias', alias_ns), + ] + + # Initialize all results. + found = False + obj = None + space = None + ds = None + ismagic = False + isalias = False + + # Look for the given name by splitting it in parts. If the head is + # found, then we look for all the remaining parts as members, and only + # declare success if we can find them all. + parts = name.split('.') + head, rest = parts[0], parts[1:] + for nsname, ns in namespaces: + try: + obj = ns[head] + except KeyError: + continue + else: + for part in rest: + try: + obj = getattr(obj, part) + except: + # Blanket except b/c some badly implemented objects + # allow __getattr__ to raise exceptions other than + # AttributeError, which then crashes us. + break + else: + # If we finish the for loop (no break), we got all members + found = True + space = nsname + isalias = (ns == alias_ns) + break # namespace loop + + # Try to see if it is a magic. 
+ if not found: + if name.startswith(self.config.ESC_MAGIC): + name = name[1:] + obj = getattr(self, 'magic_' + name, None) + if obj is not None: + found = True + space = 'IPython internal' + ismagic = True + + # Last try: special-case some literals like '', [], {}, etc: + if not found and head in ["''", '""', '[]', '{}', '()']: + obj = eval(head) + found = True + space = 'Interactive' + + return dict( + found=found, + obj=obj, + namespace=space, + ismagic=ismagic, + isalias=isalias, + ) + + + + + + def magic_pwd(self, parameter_s=''): + """ Return the current working directory path. + """ + return os.getcwd() + + def magic_env(self, parameter_s=''): + """ List environment variables. + """ + + return os.environ.data + + diff --git a/IPython/kernel/core/message_cache.py b/IPython/kernel/core/message_cache.py new file mode 100644 index 0000000..f0385f5 --- /dev/null +++ b/IPython/kernel/core/message_cache.py @@ -0,0 +1,98 @@ +# encoding: utf-8 + +"""Storage for the responses from the interpreter.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + + +class IMessageCache(object): + """ Storage for the response from the interpreter. + """ + + def add_message(self, i, message): + """ Add a message dictionary to the cache. + + Parameters + ---------- + i : int + message : dict + """ + + def get_message(self, i=None): + """ Get the message from the cache. + + Parameters + ---------- + i : int, optional + The number of the message. 
If not provided, return the + highest-numbered message. + + Returns + ------- + message : dict + + Raises + ------ + IndexError if the message does not exist in the cache. + """ + + +class SimpleMessageCache(object): + """ Simple dictionary-based, in-memory storage of the responses from the + interpreter. + """ + + def __init__(self): + self.cache = {} + + def add_message(self, i, message): + """ Add a message dictionary to the cache. + + Parameters + ---------- + i : int + message : dict + """ + + self.cache[i] = message + + def get_message(self, i=None): + """ Get the message from the cache. + + Parameters + ---------- + i : int, optional + The number of the message. If not provided, return the + highest-numbered message. + + Returns + ------- + message : dict + + Raises + ------ + IndexError if the message does not exist in the cache. + """ + if i is None: + keys = self.cache.keys() + if len(keys) == 0: + raise IndexError("index %r out of range" % i) + else: + i = max(self.cache.keys()) + try: + return self.cache[i] + except KeyError: + # IndexError is more appropriate, here. + raise IndexError("index %r out of range" % i) + diff --git a/IPython/kernel/core/output_trap.py b/IPython/kernel/core/output_trap.py new file mode 100644 index 0000000..8e65662 --- /dev/null +++ b/IPython/kernel/core/output_trap.py @@ -0,0 +1,99 @@ +# encoding: utf-8 + +""" Trap stdout/stderr.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import sys +from cStringIO import StringIO + + +class OutputTrap(object): + """ Object which can trap text sent to stdout and stderr. + """ + + def __init__(self): + # Filelike objects to store stdout/stderr text. + self.out = StringIO() + self.err = StringIO() + + # Boolean to check if the stdout/stderr hook is set. + self.out_set = False + self.err_set = False + + @property + def out_text(self): + """ Return the text currently in the stdout buffer. + """ + return self.out.getvalue() + + @property + def err_text(self): + """ Return the text currently in the stderr buffer. + """ + return self.err.getvalue() + + def set(self): + """ Set the hooks. + """ + + if sys.stdout is not self.out: + self._out_save = sys.stdout + sys.stdout = self.out + self.out_set = True + + if sys.stderr is not self.err: + self._err_save = sys.stderr + sys.stderr = self.err + self.err_set = True + + def unset(self): + """ Remove the hooks. + """ + + sys.stdout = self._out_save + self.out_set = False + + sys.stderr = self._err_save + self.err_set = False + + def clear(self): + """ Clear out the buffers. + """ + + self.out.close() + self.out = StringIO() + + self.err.close() + self.err = StringIO() + + def add_to_message(self, message): + """ Add the text from stdout and stderr to the message from the + interpreter to its listeners. 
+ + Parameters + ---------- + message : dict + """ + + out_text = self.out_text + if out_text: + message['stdout'] = out_text + + err_text = self.err_text + if err_text: + message['stderr'] = err_text + + + diff --git a/IPython/kernel/core/prompts.py b/IPython/kernel/core/prompts.py new file mode 100644 index 0000000..4c572dc --- /dev/null +++ b/IPython/kernel/core/prompts.py @@ -0,0 +1,591 @@ +# encoding: utf-8 + +"""Classes for handling input/output prompts.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython import Release +__author__ = '%s <%s>' % Release.authors['Fernando'] +__license__ = Release.license +__version__ = Release.version + +#**************************************************************************** +# Required modules +import __builtin__ +import os +import socket +import sys +import time + +# IPython's own +from IPython.external.Itpl import ItplNS +from macro import Macro + +# Temporarily use this until it is ported to ipython1 + +from IPython import ColorANSI +from IPython.ipstruct import Struct +from IPython.genutils import * +from IPython.ipapi import TryNext + +#**************************************************************************** +#Color schemes for Prompts. 
+ +PromptColors = ColorANSI.ColorSchemeTable() +InputColors = ColorANSI.InputTermColors # just a shorthand +Colors = ColorANSI.TermColors # just a shorthand + + +__PColNoColor = ColorANSI.ColorScheme( + 'NoColor', + in_prompt = InputColors.NoColor, # Input prompt + in_number = InputColors.NoColor, # Input prompt number + in_prompt2 = InputColors.NoColor, # Continuation prompt + in_normal = InputColors.NoColor, # color off (usu. Colors.Normal) + + out_prompt = Colors.NoColor, # Output prompt + out_number = Colors.NoColor, # Output prompt number + + normal = Colors.NoColor # color off (usu. Colors.Normal) + ) + +PromptColors.add_scheme(__PColNoColor) + +# make some schemes as instances so we can copy them for modification easily: +__PColLinux = __PColNoColor.copy('Linux') +# Don't forget to enter it into the table! +PromptColors.add_scheme(__PColLinux) +__PColLightBG = __PColLinux.copy('LightBG') +PromptColors.add_scheme(__PColLightBG) + +del Colors,InputColors + +#----------------------------------------------------------------------------- +def multiple_replace(dict, text): + """ Replace in 'text' all occurrences of any key in the given + dictionary by its corresponding value. Returns the new string.""" + + # Function by Xavier Defrang, originally found at: + # http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/81330 + + # Create a regular expression from the dictionary keys + regex = re.compile("(%s)" % "|".join(map(re.escape, dict.keys()))) + # For each match, look-up corresponding value in dictionary + return regex.sub(lambda mo: dict[mo.string[mo.start():mo.end()]], text) + +#----------------------------------------------------------------------------- +# Special characters that can be used in prompt templates, mainly bash-like + +# If $HOME isn't defined (Windows), make it an absurd string so that it can +# never be expanded out into '~'. 
Basically anything which can never be a +# reasonable directory name will do, we just want the $HOME -> '~' operation +# to become a no-op. We pre-compute $HOME here so it's not done on every +# prompt call. + +# FIXME: + +# - This should be turned into a class which does proper namespace management, +# since the prompt specials need to be evaluated in a certain namespace. +# Currently it's just globals, which need to be managed manually by code +# below. + +# - I also need to split up the color schemes from the prompt specials +# somehow. I don't have a clean design for that quite yet. + +HOME = os.environ.get("HOME","//////:::::ZZZZZ,,,~~~") + +# We precompute a few more strings here for the prompt_specials, which are +# fixed once ipython starts. This reduces the runtime overhead of computing +# prompt strings. +USER = os.environ.get("USER") +HOSTNAME = socket.gethostname() +HOSTNAME_SHORT = HOSTNAME.split(".")[0] +ROOT_SYMBOL = "$#"[os.name=='nt' or os.getuid()==0] + +prompt_specials_color = { + # Prompt/history count + '%n' : '${self.col_num}' '${self.cache.prompt_count}' '${self.col_p}', + r'\#': '${self.col_num}' '${self.cache.prompt_count}' '${self.col_p}', + # Just the prompt counter number, WITHOUT any coloring wrappers, so users + # can get numbers displayed in whatever color they want. + r'\N': '${self.cache.prompt_count}', + # Prompt/history count, with the actual digits replaced by dots. Used + # mainly in continuation prompts (prompt_in2) + r'\D': '${"."*len(str(self.cache.prompt_count))}', + # Current working directory + r'\w': '${os.getcwd()}', + # Current time + r'\t' : '${time.strftime("%H:%M:%S")}', + # Basename of current working directory. + # (use os.sep to make this portable across OSes) + r'\W' : '${os.getcwd().split("%s")[-1]}' % os.sep, + # These X are an extension to the normal bash prompts. 
They return + # N terms of the path, after replacing $HOME with '~' + r'\X0': '${os.getcwd().replace("%s","~")}' % HOME, + r'\X1': '${self.cwd_filt(1)}', + r'\X2': '${self.cwd_filt(2)}', + r'\X3': '${self.cwd_filt(3)}', + r'\X4': '${self.cwd_filt(4)}', + r'\X5': '${self.cwd_filt(5)}', + # Y are similar to X, but they show '~' if it's the directory + # N+1 in the list. Somewhat like %cN in tcsh. + r'\Y0': '${self.cwd_filt2(0)}', + r'\Y1': '${self.cwd_filt2(1)}', + r'\Y2': '${self.cwd_filt2(2)}', + r'\Y3': '${self.cwd_filt2(3)}', + r'\Y4': '${self.cwd_filt2(4)}', + r'\Y5': '${self.cwd_filt2(5)}', + # Hostname up to first . + r'\h': HOSTNAME_SHORT, + # Full hostname + r'\H': HOSTNAME, + # Username of current user + r'\u': USER, + # Escaped '\' + '\\\\': '\\', + # Newline + r'\n': '\n', + # Carriage return + r'\r': '\r', + # Release version + r'\v': __version__, + # Root symbol ($ or #) + r'\$': ROOT_SYMBOL, + } + +# A copy of the prompt_specials dictionary but with all color escapes removed, +# so we can correctly compute the prompt length for the auto_rewrite method. +prompt_specials_nocolor = prompt_specials_color.copy() +prompt_specials_nocolor['%n'] = '${self.cache.prompt_count}' +prompt_specials_nocolor[r'\#'] = '${self.cache.prompt_count}' + +# Add in all the InputTermColors color escapes as valid prompt characters. +# They all get added as \\C_COLORNAME, so that we don't have any conflicts +# with a color name which may begin with a letter used by any other of the +# allowed specials. This of course means that \\C will never be allowed for +# anything else. +input_colors = ColorANSI.InputTermColors +for _color in dir(input_colors): + if _color[0] != '_': + c_name = r'\C_'+_color + prompt_specials_color[c_name] = getattr(input_colors,_color) + prompt_specials_nocolor[c_name] = '' + +# we default to no color for safety. Note that prompt_specials is a global +# variable used by all prompt objects. 
+prompt_specials = prompt_specials_nocolor + +#----------------------------------------------------------------------------- +def str_safe(arg): + """Convert to a string, without ever raising an exception. + + If str(arg) fails, is returned, where ... is the exception + error message.""" + + try: + out = str(arg) + except UnicodeError: + try: + out = arg.encode('utf_8','replace') + except Exception,msg: + # let's keep this little duplication here, so that the most common + # case doesn't suffer from a double try wrapping. + out = '' % msg + except Exception,msg: + out = '' % msg + return out + +class BasePrompt(object): + """Interactive prompt similar to Mathematica's.""" + + def _get_p_template(self): + return self._p_template + + def _set_p_template(self,val): + self._p_template = val + self.set_p_str() + + p_template = property(_get_p_template,_set_p_template, + doc='Template for prompt string creation') + + def __init__(self,cache,sep,prompt,pad_left=False): + + # Hack: we access information about the primary prompt through the + # cache argument. We need this, because we want the secondary prompt + # to be aligned with the primary one. Color table info is also shared + # by all prompt classes through the cache. Nice OO spaghetti code! + self.cache = cache + self.sep = sep + + # regexp to count the number of spaces at the end of a prompt + # expression, useful for prompt auto-rewriting + self.rspace = re.compile(r'(\s*)$') + # Flag to left-pad prompt strings to match the length of the primary + # prompt + self.pad_left = pad_left + + # Set template to create each actual prompt (where numbers change). + # Use a property + self.p_template = prompt + self.set_p_str() + + def set_p_str(self): + """ Set the interpolating prompt strings. 
+ + This must be called every time the color settings change, because the + prompt_specials global may have changed.""" + + import os,time # needed in locals for prompt string handling + loc = locals() + self.p_str = ItplNS('%s%s%s' % + ('${self.sep}${self.col_p}', + multiple_replace(prompt_specials, self.p_template), + '${self.col_norm}'),self.cache.user_ns,loc) + + self.p_str_nocolor = ItplNS(multiple_replace(prompt_specials_nocolor, + self.p_template), + self.cache.user_ns,loc) + + def write(self,msg): # dbg + sys.stdout.write(msg) + return '' + + def __str__(self): + """Return a string form of the prompt. + + This for is useful for continuation and output prompts, since it is + left-padded to match lengths with the primary one (if the + self.pad_left attribute is set).""" + + out_str = str_safe(self.p_str) + if self.pad_left: + # We must find the amount of padding required to match lengths, + # taking the color escapes (which are invisible on-screen) into + # account. + esc_pad = len(out_str) - len(str_safe(self.p_str_nocolor)) + format = '%%%ss' % (len(str(self.cache.last_prompt))+esc_pad) + return format % out_str + else: + return out_str + + # these path filters are put in as methods so that we can control the + # namespace where the prompt strings get evaluated + def cwd_filt(self,depth): + """Return the last depth elements of the current working directory. + + $HOME is always replaced with '~'. + If depth==0, the full path is returned.""" + + cwd = os.getcwd().replace(HOME,"~") + out = os.sep.join(cwd.split(os.sep)[-depth:]) + if out: + return out + else: + return os.sep + + def cwd_filt2(self,depth): + """Return the last depth elements of the current working directory. + + $HOME is always replaced with '~'. 
+ If depth==0, the full path is returned.""" + + cwd = os.getcwd().replace(HOME,"~").split(os.sep) + if '~' in cwd and len(cwd) == depth+1: + depth += 1 + out = os.sep.join(cwd[-depth:]) + if out: + return out + else: + return os.sep + +class Prompt1(BasePrompt): + """Input interactive prompt similar to Mathematica's.""" + + def __init__(self,cache,sep='\n',prompt='In [\\#]: ',pad_left=True): + BasePrompt.__init__(self,cache,sep,prompt,pad_left) + + def set_colors(self): + self.set_p_str() + Colors = self.cache.color_table.active_colors # shorthand + self.col_p = Colors.in_prompt + self.col_num = Colors.in_number + self.col_norm = Colors.in_normal + # We need a non-input version of these escapes for the '--->' + # auto-call prompts used in the auto_rewrite() method. + self.col_p_ni = self.col_p.replace('\001','').replace('\002','') + self.col_norm_ni = Colors.normal + + def __str__(self): + self.cache.prompt_count += 1 + self.cache.last_prompt = str_safe(self.p_str_nocolor).split('\n')[-1] + return str_safe(self.p_str) + + def auto_rewrite(self): + """Print a string of the form '--->' which lines up with the previous + input string. 
Useful for systems which re-write the user input when + handling automatically special syntaxes.""" + + curr = str(self.cache.last_prompt) + nrspaces = len(self.rspace.search(curr).group()) + return '%s%s>%s%s' % (self.col_p_ni,'-'*(len(curr)-nrspaces-1), + ' '*nrspaces,self.col_norm_ni) + +class PromptOut(BasePrompt): + """Output interactive prompt similar to Mathematica's.""" + + def __init__(self,cache,sep='',prompt='Out[\\#]: ',pad_left=True): + BasePrompt.__init__(self,cache,sep,prompt,pad_left) + if not self.p_template: + self.__str__ = lambda: '' + + def set_colors(self): + self.set_p_str() + Colors = self.cache.color_table.active_colors # shorthand + self.col_p = Colors.out_prompt + self.col_num = Colors.out_number + self.col_norm = Colors.normal + +class Prompt2(BasePrompt): + """Interactive continuation prompt.""" + + def __init__(self,cache,prompt=' .\\D.: ',pad_left=True): + self.cache = cache + self.p_template = prompt + self.pad_left = pad_left + self.set_p_str() + + def set_p_str(self): + import os,time # needed in locals for prompt string handling + loc = locals() + self.p_str = ItplNS('%s%s%s' % + ('${self.col_p2}', + multiple_replace(prompt_specials, self.p_template), + '$self.col_norm'), + self.cache.user_ns,loc) + self.p_str_nocolor = ItplNS(multiple_replace(prompt_specials_nocolor, + self.p_template), + self.cache.user_ns,loc) + + def set_colors(self): + self.set_p_str() + Colors = self.cache.color_table.active_colors + self.col_p2 = Colors.in_prompt2 + self.col_norm = Colors.in_normal + # FIXME (2004-06-16) HACK: prevent crashes for users who haven't + # updated their prompt_in2 definitions. Remove eventually. + self.col_p = Colors.out_prompt + self.col_num = Colors.out_number + + +#----------------------------------------------------------------------------- +class CachedOutput: + """Class for printing output from calculations while keeping a cache of + reults. 
It dynamically creates global variables prefixed with _ which + contain these results. + + Meant to be used as a sys.displayhook replacement, providing numbered + prompts and cache services. + + Initialize with initial and final values for cache counter (this defines + the maximum size of the cache.""" + + def __init__(self,shell,cache_size,Pprint, + colors='NoColor',input_sep='\n', + output_sep='\n',output_sep2='', + ps1 = None, ps2 = None,ps_out = None,pad_left=True): + + cache_size_min = 3 + if cache_size <= 0: + self.do_full_cache = 0 + cache_size = 0 + elif cache_size < cache_size_min: + self.do_full_cache = 0 + cache_size = 0 + warn('caching was disabled (min value for cache size is %s).' % + cache_size_min,level=3) + else: + self.do_full_cache = 1 + + self.cache_size = cache_size + self.input_sep = input_sep + + # we need a reference to the user-level namespace + self.shell = shell + self.user_ns = shell.user_ns + # and to the user's input + self.input_hist = shell.history.input_cache + + # Set input prompt strings and colors + if cache_size == 0: + if ps1.find('%n') > -1 or ps1.find(r'\#') > -1 \ + or ps1.find(r'\N') > -1: + ps1 = '>>> ' + if ps2.find('%n') > -1 or ps2.find(r'\#') > -1 \ + or ps2.find(r'\N') > -1: + ps2 = '... ' + self.ps1_str = self._set_prompt_str(ps1,'In [\\#]: ','>>> ') + self.ps2_str = self._set_prompt_str(ps2,' .\\D.: ','... ') + self.ps_out_str = self._set_prompt_str(ps_out,'Out[\\#]: ','') + + self.color_table = PromptColors + self.prompt1 = Prompt1(self,sep=input_sep,prompt=self.ps1_str, + pad_left=pad_left) + self.prompt2 = Prompt2(self,prompt=self.ps2_str,pad_left=pad_left) + self.prompt_out = PromptOut(self,sep='',prompt=self.ps_out_str, + pad_left=pad_left) + self.set_colors(colors) + + # other more normal stuff + # b/c each call to the In[] prompt raises it by 1, even the first. 
+ self.prompt_count = 0 + # Store the last prompt string each time, we need it for aligning + # continuation and auto-rewrite prompts + self.last_prompt = '' + self.Pprint = Pprint + self.output_sep = output_sep + self.output_sep2 = output_sep2 + self._,self.__,self.___ = '','','' + self.pprint_types = map(type,[(),[],{}]) + + # these are deliberately global: + to_user_ns = {'_':self._,'__':self.__,'___':self.___} + self.user_ns.update(to_user_ns) + + def _set_prompt_str(self,p_str,cache_def,no_cache_def): + if p_str is None: + if self.do_full_cache: + return cache_def + else: + return no_cache_def + else: + return p_str + + def set_colors(self,colors): + """Set the active color scheme and configure colors for the three + prompt subsystems.""" + + # FIXME: the prompt_specials global should be gobbled inside this + # class instead. Do it when cleaning up the whole 3-prompt system. + global prompt_specials + if colors.lower()=='nocolor': + prompt_specials = prompt_specials_nocolor + else: + prompt_specials = prompt_specials_color + + self.color_table.set_active_scheme(colors) + self.prompt1.set_colors() + self.prompt2.set_colors() + self.prompt_out.set_colors() + + def __call__(self,arg=None): + """Printing with history cache management. + + This is invoked everytime the interpreter needs to print, and is + activated by setting the variable sys.displayhook to it.""" + + # If something injected a '_' variable in __builtin__, delete + # ipython's automatic one so we don't clobber that. gettext() in + # particular uses _, so we need to stay away from it. 
+ if '_' in __builtin__.__dict__: + try: + del self.user_ns['_'] + except KeyError: + pass + if arg is not None: + cout_write = Term.cout.write # fast lookup + # first handle the cache and counters + + # do not print output if input ends in ';' + if self.input_hist[self.prompt_count].endswith(';\n'): + return + # don't use print, puts an extra space + cout_write(self.output_sep) + outprompt = self.shell.hooks.generate_output_prompt() + if self.do_full_cache: + cout_write(outprompt) + + # and now call a possibly user-defined print mechanism + manipulated_val = self.display(arg) + + # user display hooks can change the variable to be stored in + # output history + + if manipulated_val is not None: + arg = manipulated_val + + # avoid recursive reference when displaying _oh/Out + if arg is not self.user_ns['_oh']: + self.update(arg) + + cout_write(self.output_sep2) + Term.cout.flush() + + def _display(self,arg): + """Default printer method, uses pprint. + + Do ip.set_hook("result_display", my_displayhook) for custom result + display, e.g. when your own objects need special formatting. + """ + try: + return IPython.generics.result_display(arg) + except TryNext: + return self.shell.hooks.result_display(arg) + + # Assign the default display method: + display = _display + + def update(self,arg): + #print '***cache_count', self.cache_count # dbg + if len(self.user_ns['_oh']) >= self.cache_size and self.do_full_cache: + warn('Output cache limit (currently '+ + `self.cache_size`+' entries) hit.\n' + 'Flushing cache and resetting history counter...\n' + 'The only history variables available will be _,__,___ and _1\n' + 'with the current result.') + + self.flush() + # Don't overwrite '_' and friends if '_' is in __builtin__ (otherwise + # we cause buggy behavior for things like gettext). 
+ if '_' not in __builtin__.__dict__: + self.___ = self.__ + self.__ = self._ + self._ = arg + self.user_ns.update({'_':self._,'__':self.__,'___':self.___}) + + # hackish access to top-level namespace to create _1,_2... dynamically + to_main = {} + if self.do_full_cache: + new_result = '_'+`self.prompt_count` + to_main[new_result] = arg + self.user_ns.update(to_main) + self.user_ns['_oh'][self.prompt_count] = arg + + def flush(self): + if not self.do_full_cache: + raise ValueError,"You shouldn't have reached the cache flush "\ + "if full caching is not enabled!" + # delete auto-generated vars from global namespace + + for n in range(1,self.prompt_count + 1): + key = '_'+`n` + try: + del self.user_ns[key] + except: pass + self.user_ns['_oh'].clear() + + if '_' not in __builtin__.__dict__: + self.user_ns.update({'_':None,'__':None, '___':None}) + import gc + gc.collect() # xxx needed? + diff --git a/IPython/kernel/core/shell.py b/IPython/kernel/core/shell.py new file mode 100644 index 0000000..eb266ae --- /dev/null +++ b/IPython/kernel/core/shell.py @@ -0,0 +1,357 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.test.test_shell -*- + +"""The core IPython Shell""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import pprint +import signal +import sys +import threading +import time + +from code import InteractiveConsole, softspace +from StringIO import StringIO + +from IPython.OutputTrap import OutputTrap +from IPython import ultraTB + +from IPython.kernel.error import NotDefined + + +class InteractiveShell(InteractiveConsole): + """The Basic IPython Shell class. + + This class provides the basic capabilities of IPython. Currently + this class does not do anything IPython specific. That is, it is + just a python shell. + + It is modelled on code.InteractiveConsole, but adds additional + capabilities. These additional capabilities are what give IPython + its power. + + The current version of this class is meant to be a prototype that guides + the future design of the IPython core. This class must not use Twisted + in any way, but it must be designed in a way that makes it easy to + incorporate into Twisted and hook network protocols up to. + + Some of the methods of this class comprise the official IPython core + interface. These methods must be tread safe and they must return types + that can be easily serialized by protocols such as PB, XML-RPC and SOAP. + Locks have been provided for making the methods thread safe, but additional + locks can be added as needed. + + Any method that is meant to be a part of the official interface must also + be declared in the kernel.coreservice.ICoreService interface. Eventually + all other methods should have single leading underscores to note that they + are not designed to be 'public.' Currently, because this class inherits + from code.InteractiveConsole there are many private methods w/o leading + underscores. 
The interface should be as simple as possible and methods + should not be added to the interface unless they really need to be there. + + Note: + + - For now I am using methods named put/get to move objects in/out of the + users namespace. Originally, I was calling these methods push/pull, but + because code.InteractiveConsole already has a push method, I had to use + something different. Eventually, we probably won't subclass this class + so we can call these methods whatever we want. So, what do we want to + call them? + - We need a way of running the trapping of stdout/stderr in different ways. + We should be able to i) trap, ii) not trap at all or iii) trap and echo + things to stdout and stderr. + - How should errors be handled? Should exceptions be raised? + - What should methods that don't compute anything return? The default of + None? + """ + + def __init__(self, locals=None, filename=""): + """Creates a new TrappingInteractiveConsole object.""" + InteractiveConsole.__init__(self,locals,filename) + self._trap = OutputTrap(debug=0) + self._stdin = [] + self._stdout = [] + self._stderr = [] + self._last_type = self._last_traceback = self._last_value = None + #self._namespace_lock = threading.Lock() + #self._command_lock = threading.Lock() + self.lastCommandIndex = -1 + # I am using this user defined signal to interrupt the currently + # running command. I am not sure if this is the best way, but + # it is working! + # This doesn't work on Windows as it doesn't have this signal. + #signal.signal(signal.SIGUSR1, self._handleSIGUSR1) + + # An exception handler. Experimental: later we need to make the + # modes/colors available to user configuration, etc. + self.tbHandler = ultraTB.FormattedTB(color_scheme='NoColor', + mode='Context', + tb_offset=2) + + def _handleSIGUSR1(self, signum, frame): + """Handle the SIGUSR1 signal by printing to stderr.""" + print>>sys.stderr, "Command stopped." 
+ + def _prefilter(self, line, more): + return line + + def _trapRunlines(self, lines): + """ + This executes the python source code, source, in the + self.locals namespace and traps stdout and stderr. Upon + exiting, self.out and self.err contain the values of + stdout and stderr for the last executed command only. + """ + + # Execute the code + #self._namespace_lock.acquire() + self._trap.flush() + self._trap.trap() + self._runlines(lines) + self.lastCommandIndex += 1 + self._trap.release() + #self._namespace_lock.release() + + # Save stdin, stdout and stderr to lists + #self._command_lock.acquire() + self._stdin.append(lines) + self._stdout.append(self.prune_output(self._trap.out.getvalue())) + self._stderr.append(self.prune_output(self._trap.err.getvalue())) + #self._command_lock.release() + + def prune_output(self, s): + """Only return the first and last 1600 chars of stdout and stderr. + + Something like this is required to make sure that the engine and + controller don't become overwhelmed by the size of stdout/stderr. + """ + if len(s) > 3200: + return s[:1600] + '\n............\n' + s[-1600:] + else: + return s + + # Lifted from iplib.InteractiveShell + def _runlines(self,lines): + """Run a string of one or more lines of source. + + This method is capable of running a string containing multiple source + lines, as if they had been entered at the IPython prompt. Since it + exposes IPython's processing machinery, the given strings can contain + magic calls (%magic), special shell access (!cmd), etc.""" + + # We must start with a clean buffer, in case this is run from an + # interactive IPython session (via a magic, for example). 
+ self.resetbuffer() + lines = lines.split('\n') + more = 0 + for line in lines: + # skip blank lines so we don't mess up the prompt counter, but do + # NOT skip even a blank line if we are in a code block (more is + # true) + if line or more: + more = self.push((self._prefilter(line,more))) + # IPython's runsource returns None if there was an error + # compiling the code. This allows us to stop processing right + # away, so the user gets the error message at the right place. + if more is None: + break + # final newline in case the input didn't have it, so that the code + # actually does get executed + if more: + self.push('\n') + + def runcode(self, code): + """Execute a code object. + + When an exception occurs, self.showtraceback() is called to + display a traceback. All exceptions are caught except + SystemExit, which is reraised. + + A note about KeyboardInterrupt: this exception may occur + elsewhere in this code, and may not always be caught. The + caller should be prepared to deal with it. + + """ + + self._last_type = self._last_traceback = self._last_value = None + try: + exec code in self.locals + except: + # Since the exception info may need to travel across the wire, we + # pack it in right away. Note that we are abusing the exception + # value to store a fully formatted traceback, since the stack can + # not be serialized for network transmission. + et,ev,tb = sys.exc_info() + self._last_type = et + self._last_traceback = tb + tbinfo = self.tbHandler.text(et,ev,tb) + # Construct a meaningful traceback message for shipping over the + # wire. + buf = pprint.pformat(self.buffer) + try: + ename = et.__name__ + except: + ename = et + msg = """\ +%(ev)s +*************************************************************************** +An exception occurred in an IPython engine while executing user code. 
+ +Current execution buffer (lines being run): +%(buf)s + +A full traceback from the actual engine: +%(tbinfo)s +*************************************************************************** + """ % locals() + self._last_value = msg + else: + if softspace(sys.stdout, 0): + print + + ################################################################## + # Methods that are a part of the official interface + # + # These methods should also be put in the + # kernel.coreservice.ICoreService interface. + # + # These methods must conform to certain restrictions that allow + # them to be exposed to various network protocols: + # + # - As much as possible, these methods must return types that can be + # serialized by PB, XML-RPC and SOAP. None is OK. + # - Every method must be thread safe. There are some locks provided + # for this purpose, but new, specialized locks can be added to the + # class. + ################################################################## + + # Methods for running code + + def exc_info(self): + """Return exception information much like sys.exc_info(). + + This method returns the same (etype,evalue,tb) tuple as sys.exc_info, + but from the last time that the engine had an exception fire.""" + + return self._last_type,self._last_value,self._last_traceback + + def execute(self, lines): + self._trapRunlines(lines) + if self._last_type is None: + return self.getCommand() + else: + raise self._last_type(self._last_value) + + # Methods for working with the namespace + + def put(self, key, value): + """Put value into locals namespace with name key. + + I have often called this push(), but in this case the + InteractiveConsole class already defines a push() method that + is different. + """ + + if not isinstance(key, str): + raise TypeError, "Objects must be keyed by strings." + self.update({key:value}) + + def get(self, key): + """Gets an item out of the self.locals dict by key. + + Raise NameError if the object doesn't exist. 
+ + I have often called this pull(). I still like that better. + """ + + class NotDefined(object): + """A class to signify an objects that is not in the users ns.""" + pass + + if not isinstance(key, str): + raise TypeError, "Objects must be keyed by strings." + result = self.locals.get(key, NotDefined()) + if isinstance(result, NotDefined): + raise NameError('name %s is not defined' % key) + else: + return result + + + def update(self, dictOfData): + """Loads a dict of key value pairs into the self.locals namespace.""" + if not isinstance(dictOfData, dict): + raise TypeError, "update() takes a dict object." + #self._namespace_lock.acquire() + self.locals.update(dictOfData) + #self._namespace_lock.release() + + # Methods for getting stdout/stderr/stdin + + def reset(self): + """Reset the InteractiveShell.""" + + #self._command_lock.acquire() + self._stdin = [] + self._stdout = [] + self._stderr = [] + self.lastCommandIndex = -1 + #self._command_lock.release() + + #self._namespace_lock.acquire() + # preserve id, mpi objects + mpi = self.locals.get('mpi', None) + id = self.locals.get('id', None) + del self.locals + self.locals = {'mpi': mpi, 'id': id} + #self._namespace_lock.release() + + def getCommand(self,i=None): + """Get the stdin/stdout/stderr of command i.""" + + #self._command_lock.acquire() + + + if i is not None and not isinstance(i, int): + raise TypeError("Command index not an int: " + str(i)) + + if i in range(self.lastCommandIndex + 1): + inResult = self._stdin[i] + outResult = self._stdout[i] + errResult = self._stderr[i] + cmdNum = i + elif i is None and self.lastCommandIndex >= 0: + inResult = self._stdin[self.lastCommandIndex] + outResult = self._stdout[self.lastCommandIndex] + errResult = self._stderr[self.lastCommandIndex] + cmdNum = self.lastCommandIndex + else: + inResult = None + outResult = None + errResult = None + + #self._command_lock.release() + + if inResult is not None: + return dict(commandIndex=cmdNum, stdin=inResult, stdout=outResult, 
stderr=errResult) + else: + raise IndexError("Command with index %s does not exist" % str(i)) + + def getLastCommandIndex(self): + """Get the index of the last command.""" + #self._command_lock.acquire() + ind = self.lastCommandIndex + #self._command_lock.release() + return ind + diff --git a/IPython/kernel/core/tests/__init__.py b/IPython/kernel/core/tests/__init__.py new file mode 100644 index 0000000..9a4495f --- /dev/null +++ b/IPython/kernel/core/tests/__init__.py @@ -0,0 +1,10 @@ +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- \ No newline at end of file diff --git a/IPython/kernel/core/tests/test_shell.py b/IPython/kernel/core/tests/test_shell.py new file mode 100644 index 0000000..87d7ee2 --- /dev/null +++ b/IPython/kernel/core/tests/test_shell.py @@ -0,0 +1,67 @@ +# encoding: utf-8 + +"""This file contains unittests for the shell.py module.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import unittest +from IPython.kernel.core import shell + +resultKeys = ('commandIndex', 'stdin', 'stdout', 'stderr') + +class BasicShellTest(unittest.TestCase): + + def setUp(self): + self.s = shell.InteractiveShell() + + def testExecute(self): + commands = [(0,"a = 5","",""), + (1,"b = 10","",""), + (2,"c = a + b","",""), + (3,"print c","15\n",""), + (4,"import math","",""), + (5,"2.0*math.pi","6.2831853071795862\n","")] + for c in commands: + result = self.s.execute(c[1]) + self.assertEquals(result, dict(zip(resultKeys,c))) + + def testPutGet(self): + objs = [10,"hi there",1.2342354,{"p":(1,2)}] + for o in objs: + self.s.put("key",o) + value = self.s.get("key") + self.assertEquals(value,o) + self.assertRaises(TypeError, self.s.put,10) + self.assertRaises(TypeError, self.s.get,10) + self.s.reset() + self.assertRaises(NameError, self.s.get, 'a') + + def testUpdate(self): + d = {"a": 10, "b": 34.3434, "c": "hi there"} + self.s.update(d) + for k in d.keys(): + value = self.s.get(k) + self.assertEquals(value, d[k]) + self.assertRaises(TypeError, self.s.update, [1,2,2]) + + def testCommand(self): + self.assertRaises(IndexError,self.s.getCommand) + self.s.execute("a = 5") + self.assertEquals(self.s.getCommand(), dict(zip(resultKeys, (0,"a = 5","","")))) + self.assertEquals(self.s.getCommand(0), dict(zip(resultKeys, (0,"a = 5","","")))) + self.s.reset() + self.assertEquals(self.s.getLastCommandIndex(),-1) + self.assertRaises(IndexError,self.s.getCommand) + + \ No newline at end of file diff --git a/IPython/kernel/core/traceback_formatter.py b/IPython/kernel/core/traceback_formatter.py new file mode 100644 index 0000000..05628e7 --- /dev/null +++ b/IPython/kernel/core/traceback_formatter.py @@ -0,0 +1,62 @@ +# 
encoding: utf-8 + +"""Some formatter objects to extract traceback information by replacing +sys.excepthook().""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import traceback + + +class ITracebackFormatter(object): + """ Objects conforming to this interface will format tracebacks into other + objects. + """ + + # The kind of formatter. + kind = 'traceback' + + # The unique identifier for this formatter. + identifier = None + + + def __call__(self, exc_type, exc_value, exc_traceback): + """ Return a formatted representation of a traceback. + """ + + raise NotImplementedError + + +class PlainTracebackFormatter(ITracebackFormatter): + """ Return a string with the regular traceback information. + """ + + # The unique identifier for this formatter. + identifier = 'plain' + + + def __init__(self, limit=None): + # The maximum number of stack levels to go back. + # None implies all stack levels are returned. + self.limit = limit + + def __call__(self, exc_type, exc_value, exc_traceback): + """ Return a string with the regular traceback information. 
+ """ + + lines = traceback.format_tb(exc_traceback, self.limit) + lines.append('%s: %s' % (exc_type.__name__, exc_value)) + return '\n'.join(lines) + + diff --git a/IPython/kernel/core/traceback_trap.py b/IPython/kernel/core/traceback_trap.py new file mode 100644 index 0000000..7fd1d17 --- /dev/null +++ b/IPython/kernel/core/traceback_trap.py @@ -0,0 +1,83 @@ +# encoding: utf-8 + +"""Object to manage sys.excepthook().""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import sys + + +class TracebackTrap(object): + """ Object to trap and format tracebacks. + """ + + def __init__(self, formatters=None): + # A list of formatters to apply. + if formatters is None: + formatters = [] + self.formatters = formatters + + # All of the traceback information provided to sys.excepthook(). + self.args = None + + # The previous hook before we replace it. + self.old_hook = None + + + def hook(self, *args): + """ This method actually implements the hook. + """ + + self.args = args + + def set(self): + """ Set the hook. + """ + + if sys.excepthook is not self.hook: + self.old_hook = sys.excepthook + sys.excepthook = self.hook + + def unset(self): + """ Unset the hook. + """ + + sys.excepthook = self.old_hook + + def clear(self): + """ Remove the stored traceback. + """ + + self.args = None + + def add_to_message(self, message): + """ Add the formatted display of the traceback to the message dictionary + being returned from the interpreter to its listeners. 
+ + Parameters + ---------- + message : dict + """ + + # If there was no traceback, then don't add anything. + if self.args is None: + return + + # Go through the list of formatters and let them add their formatting. + traceback = {} + for formatter in self.formatters: + traceback[formatter.identifier] = formatter(*self.args) + + message['traceback'] = traceback + diff --git a/IPython/kernel/core/ultraTB.py b/IPython/kernel/core/ultraTB.py new file mode 100644 index 0000000..123d6ae --- /dev/null +++ b/IPython/kernel/core/ultraTB.py @@ -0,0 +1,1018 @@ +# encoding: utf-8 + +""" +ultraTB.py -- Spice up your tracebacks! + +* ColorTB +I've always found it a bit hard to visually parse tracebacks in Python. The +ColorTB class is a solution to that problem. It colors the different parts of a +traceback in a manner similar to what you would expect from a syntax-highlighting +text editor. + +Installation instructions for ColorTB: + import sys,ultraTB + sys.excepthook = ultraTB.ColorTB() + +* VerboseTB +I've also included a port of Ka-Ping Yee's "cgitb.py" that produces all kinds +of useful info when a traceback occurs. Ping originally had it spit out HTML +and intended it for CGI programmers, but why should they have all the fun? I +altered it to spit out colored text to the terminal. It's a bit overwhelming, +but kind of neat, and maybe useful for long-running programs that you believe +are bug-free. If a crash *does* occur in that type of program you want details. +Give it a shot--you'll love it or you'll hate it. + +Note: + + The Verbose mode prints the variables currently visible where the exception + happened (shortening their strings if too long). This can potentially be + very slow, if you happen to have a huge data structure whose string + representation is complex to compute. Your computer may appear to freeze for + a while with cpu usage at 100%. If this occurs, you can cancel the traceback + with Ctrl-C (maybe hitting it more than once). 
+ + If you encounter this kind of situation often, you may want to use the + Verbose_novars mode instead of the regular Verbose, which avoids formatting + variables (but otherwise includes the information and context given by + Verbose). + + +Installation instructions for ColorTB: + import sys,ultraTB + sys.excepthook = ultraTB.VerboseTB() + +Note: Much of the code in this module was lifted verbatim from the standard +library module 'traceback.py' and Ka-Ping Yee's 'cgitb.py'. + +* Color schemes +The colors are defined in the class TBTools through the use of the +ColorSchemeTable class. Currently the following exist: + + - NoColor: allows all of this module to be used in any terminal (the color + escapes are just dummy blank strings). + + - Linux: is meant to look good in a terminal like the Linux console (black + or very dark background). + + - LightBG: similar to Linux but swaps dark/light colors to be more readable + in light background terminals. + +You can implement other color schemes easily, the syntax is fairly +self-explanatory. Please send back new schemes you develop to the author for +possible inclusion in future releases. + +$Id: ultraTB.py 2480 2007-07-06 19:33:43Z fperez $""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython import Release +__author__ = '%s <%s>\n%s <%s>' % (Release.authors['Nathan']+ + Release.authors['Fernando']) +__license__ = Release.license + +# Required modules +import inspect +import keyword +import linecache +import os +import pydoc +import re +import string +import sys +import time +import tokenize +import traceback +import types + +# For purposes of monkeypatching inspect to fix a bug in it. +from inspect import getsourcefile, getfile, getmodule,\ + ismodule, isclass, ismethod, isfunction, istraceback, isframe, iscode + + +# IPython's own modules +# Modified pdb which doesn't damage IPython's readline handling +from IPython import Debugger, PyColorize +from IPython.ipstruct import Struct +from IPython.excolors import ExceptionColors +from IPython.genutils import Term,uniq_stable,error,info + +# Globals +# amount of space to put line numbers before verbose tracebacks +INDENT_SIZE = 8 + +# Default color scheme. This is used, for example, by the traceback +# formatter. When running in an actual IPython instance, the user's rc.colors +# value is used, but havinga module global makes this functionality available +# to users of ultraTB who are NOT running inside ipython. +DEFAULT_SCHEME = 'NoColor' + +#--------------------------------------------------------------------------- +# Code begins + +# Utility functions +def inspect_error(): + """Print a message about internal inspect errors. + + These are unfortunately quite common.""" + + error('Internal Python error in the inspect module.\n' + 'Below is the traceback from this internal error.\n') + + +def findsource(object): + """Return the entire source file and starting line number for an object. 
+ + The argument may be a module, class, method, function, traceback, frame, + or code object. The source code is returned as a list of all the lines + in the file and the line number indexes a line in that list. An IOError + is raised if the source code cannot be retrieved. + + FIXED version with which we monkeypatch the stdlib to work around a bug.""" + + file = getsourcefile(object) or getfile(object) + module = getmodule(object, file) + if module: + lines = linecache.getlines(file, module.__dict__) + else: + lines = linecache.getlines(file) + if not lines: + raise IOError('could not get source code') + + if ismodule(object): + return lines, 0 + + if isclass(object): + name = object.__name__ + pat = re.compile(r'^(\s*)class\s*' + name + r'\b') + # make some effort to find the best matching class definition: + # use the one with the least indentation, which is the one + # that's most probably not inside a function definition. + candidates = [] + for i in range(len(lines)): + match = pat.match(lines[i]) + if match: + # if it's at toplevel, it's already the best one + if lines[i][0] == 'c': + return lines, i + # else add whitespace to candidate list + candidates.append((match.group(1), i)) + if candidates: + # this will sort by whitespace, and by line number, + # less whitespace first + candidates.sort() + return lines, candidates[0][1] + else: + raise IOError('could not find class definition') + + if ismethod(object): + object = object.im_func + if isfunction(object): + object = object.func_code + if istraceback(object): + object = object.tb_frame + if isframe(object): + object = object.f_code + if iscode(object): + if not hasattr(object, 'co_firstlineno'): + raise IOError('could not find function definition') + pat = re.compile(r'^(\s*def\s)|(.*(? 0: + if pmatch(lines[lnum]): break + lnum -= 1 + + return lines, lnum + raise IOError('could not find code object') + +# Monkeypatch inspect to apply our bugfix. 
This code only works with py25 +if sys.version_info[:2] >= (2,5): + inspect.findsource = findsource + +def _fixed_getinnerframes(etb, context=1,tb_offset=0): + import linecache + LNUM_POS, LINES_POS, INDEX_POS = 2, 4, 5 + + records = inspect.getinnerframes(etb, context) + + # If the error is at the console, don't build any context, since it would + # otherwise produce 5 blank lines printed out (there is no file at the + # console) + rec_check = records[tb_offset:] + try: + rname = rec_check[0][1] + if rname == '' or rname.endswith(''): + return rec_check + except IndexError: + pass + + aux = traceback.extract_tb(etb) + assert len(records) == len(aux) + for i, (file, lnum, _, _) in zip(range(len(records)), aux): + maybeStart = lnum-1 - context//2 + start = max(maybeStart, 0) + end = start + context + lines = linecache.getlines(file)[start:end] + # pad with empty lines if necessary + if maybeStart < 0: + lines = (['\n'] * -maybeStart) + lines + if len(lines) < context: + lines += ['\n'] * (context - len(lines)) + buf = list(records[i]) + buf[LNUM_POS] = lnum + buf[INDEX_POS] = lnum - 1 - start + buf[LINES_POS] = lines + records[i] = tuple(buf) + return records[tb_offset:] + +# Helper function -- largely belongs to VerboseTB, but we need the same +# functionality to produce a pseudo verbose TB for SyntaxErrors, so that they +# can be recognized properly by ipython.el's py-traceback-line-re +# (SyntaxErrors have to be treated specially because they have no traceback) + +_parser = PyColorize.Parser() + +def _formatTracebackLines(lnum, index, lines, Colors, lvals=None,scheme=None): + numbers_width = INDENT_SIZE - 1 + res = [] + i = lnum - index + + # This lets us get fully syntax-highlighted tracebacks. 
+ if scheme is None: + try: + scheme = __IPYTHON__.rc.colors + except: + scheme = DEFAULT_SCHEME + _line_format = _parser.format2 + + for line in lines: + new_line, err = _line_format(line,'str',scheme) + if not err: line = new_line + + if i == lnum: + # This is the line with the error + pad = numbers_width - len(str(i)) + if pad >= 3: + marker = '-'*(pad-3) + '-> ' + elif pad == 2: + marker = '> ' + elif pad == 1: + marker = '>' + else: + marker = '' + num = marker + str(i) + line = '%s%s%s %s%s' %(Colors.linenoEm, num, + Colors.line, line, Colors.Normal) + else: + num = '%*s' % (numbers_width,i) + line = '%s%s%s %s' %(Colors.lineno, num, + Colors.Normal, line) + + res.append(line) + if lvals and i == lnum: + res.append(lvals + '\n') + i = i + 1 + return res + + +#--------------------------------------------------------------------------- +# Module classes +class TBTools: + """Basic tools used by all traceback printer classes.""" + + def __init__(self,color_scheme = 'NoColor',call_pdb=False): + # Whether to call the interactive pdb debugger after printing + # tracebacks or not + self.call_pdb = call_pdb + + # Create color table + self.color_scheme_table = ExceptionColors + + self.set_colors(color_scheme) + self.old_scheme = color_scheme # save initial value for toggles + + if call_pdb: + self.pdb = Debugger.Pdb(self.color_scheme_table.active_scheme_name) + else: + self.pdb = None + + def set_colors(self,*args,**kw): + """Shorthand access to the color table scheme selector method.""" + + # Set own color table + self.color_scheme_table.set_active_scheme(*args,**kw) + # for convenience, set Colors to the active scheme + self.Colors = self.color_scheme_table.active_colors + # Also set colors of debugger + if hasattr(self,'pdb') and self.pdb is not None: + self.pdb.set_colors(*args,**kw) + + def color_toggle(self): + """Toggle between the currently active color scheme and NoColor.""" + + if self.color_scheme_table.active_scheme_name == 'NoColor': + 
self.color_scheme_table.set_active_scheme(self.old_scheme) + self.Colors = self.color_scheme_table.active_colors + else: + self.old_scheme = self.color_scheme_table.active_scheme_name + self.color_scheme_table.set_active_scheme('NoColor') + self.Colors = self.color_scheme_table.active_colors + +#--------------------------------------------------------------------------- +class ListTB(TBTools): + """Print traceback information from a traceback list, with optional color. + + Calling: requires 3 arguments: + (etype, evalue, elist) + as would be obtained by: + etype, evalue, tb = sys.exc_info() + if tb: + elist = traceback.extract_tb(tb) + else: + elist = None + + It can thus be used by programs which need to process the traceback before + printing (such as console replacements based on the code module from the + standard library). + + Because they are meant to be called without a full traceback (only a + list), instances of this class can't call the interactive pdb debugger.""" + + def __init__(self,color_scheme = 'NoColor'): + TBTools.__init__(self,color_scheme = color_scheme,call_pdb=0) + + def __call__(self, etype, value, elist): + Term.cout.flush() + Term.cerr.flush() + print >> Term.cerr, self.text(etype,value,elist) + + def text(self,etype, value, elist,context=5): + """Return a color formatted string with the traceback info.""" + + Colors = self.Colors + out_string = ['%s%s%s\n' % (Colors.topline,'-'*60,Colors.Normal)] + if elist: + out_string.append('Traceback %s(most recent call last)%s:' % \ + (Colors.normalEm, Colors.Normal) + '\n') + out_string.extend(self._format_list(elist)) + lines = self._format_exception_only(etype, value) + for line in lines[:-1]: + out_string.append(" "+line) + out_string.append(lines[-1]) + return ''.join(out_string) + + def _format_list(self, extracted_list): + """Format a list of traceback entry tuples for printing. 
+ + Given a list of tuples as returned by extract_tb() or + extract_stack(), return a list of strings ready for printing. + Each string in the resulting list corresponds to the item with the + same index in the argument list. Each string ends in a newline; + the strings may contain internal newlines as well, for those items + whose source text line is not None. + + Lifted almost verbatim from traceback.py + """ + + Colors = self.Colors + list = [] + for filename, lineno, name, line in extracted_list[:-1]: + item = ' File %s"%s"%s, line %s%d%s, in %s%s%s\n' % \ + (Colors.filename, filename, Colors.Normal, + Colors.lineno, lineno, Colors.Normal, + Colors.name, name, Colors.Normal) + if line: + item = item + ' %s\n' % line.strip() + list.append(item) + # Emphasize the last entry + filename, lineno, name, line = extracted_list[-1] + item = '%s File %s"%s"%s, line %s%d%s, in %s%s%s%s\n' % \ + (Colors.normalEm, + Colors.filenameEm, filename, Colors.normalEm, + Colors.linenoEm, lineno, Colors.normalEm, + Colors.nameEm, name, Colors.normalEm, + Colors.Normal) + if line: + item = item + '%s %s%s\n' % (Colors.line, line.strip(), + Colors.Normal) + list.append(item) + return list + + def _format_exception_only(self, etype, value): + """Format the exception part of a traceback. + + The arguments are the exception type and value such as given by + sys.exc_info()[:2]. The return value is a list of strings, each ending + in a newline. Normally, the list contains a single string; however, + for SyntaxError exceptions, it contains several lines that (when + printed) display detailed information about where the syntax error + occurred. The message indicating which exception occurred is the + always last string in the list. 
+ + Also lifted nearly verbatim from traceback.py + """ + + Colors = self.Colors + list = [] + try: + stype = Colors.excName + etype.__name__ + Colors.Normal + except AttributeError: + stype = etype # String exceptions don't get special coloring + if value is None: + list.append( str(stype) + '\n') + else: + if etype is SyntaxError: + try: + msg, (filename, lineno, offset, line) = value + except: + pass + else: + #print 'filename is',filename # dbg + if not filename: filename = "" + list.append('%s File %s"%s"%s, line %s%d%s\n' % \ + (Colors.normalEm, + Colors.filenameEm, filename, Colors.normalEm, + Colors.linenoEm, lineno, Colors.Normal )) + if line is not None: + i = 0 + while i < len(line) and line[i].isspace(): + i = i+1 + list.append('%s %s%s\n' % (Colors.line, + line.strip(), + Colors.Normal)) + if offset is not None: + s = ' ' + for c in line[i:offset-1]: + if c.isspace(): + s = s + c + else: + s = s + ' ' + list.append('%s%s^%s\n' % (Colors.caret, s, + Colors.Normal) ) + value = msg + s = self._some_str(value) + if s: + list.append('%s%s:%s %s\n' % (str(stype), Colors.excName, + Colors.Normal, s)) + else: + list.append('%s\n' % str(stype)) + return list + + def _some_str(self, value): + # Lifted from traceback.py + try: + return str(value) + except: + return '' % type(value).__name__ + +#---------------------------------------------------------------------------- +class VerboseTB(TBTools): + """A port of Ka-Ping Yee's cgitb.py module that outputs color text instead + of HTML. Requires inspect and pydoc. Crazy, man. + + Modified version which optionally strips the topmost entries from the + traceback, to be used with alternate interpreters (because their own code + would appear in the traceback).""" + + def __init__(self,color_scheme = 'Linux',tb_offset=0,long_header=0, + call_pdb = 0, include_vars=1): + """Specify traceback offset, headers and color scheme. + + Define how many frames to drop from the tracebacks. 
Calling it with + tb_offset=1 allows use of this handler in interpreters which will have + their own code at the top of the traceback (VerboseTB will first + remove that frame before printing the traceback info).""" + TBTools.__init__(self,color_scheme=color_scheme,call_pdb=call_pdb) + self.tb_offset = tb_offset + self.long_header = long_header + self.include_vars = include_vars + + def text(self, etype, evalue, etb, context=5): + """Return a nice text document describing the traceback.""" + + # some locals + try: + etype = etype.__name__ + except AttributeError: + pass + Colors = self.Colors # just a shorthand + quicker name lookup + ColorsNormal = Colors.Normal # used a lot + col_scheme = self.color_scheme_table.active_scheme_name + indent = ' '*INDENT_SIZE + em_normal = '%s\n%s%s' % (Colors.valEm, indent,ColorsNormal) + undefined = '%sundefined%s' % (Colors.em, ColorsNormal) + exc = '%s%s%s' % (Colors.excName,etype,ColorsNormal) + + # some internal-use functions + def text_repr(value): + """Hopefully pretty robust repr equivalent.""" + # this is pretty horrible but should always return *something* + try: + return pydoc.text.repr(value) + except KeyboardInterrupt: + raise + except: + try: + return repr(value) + except KeyboardInterrupt: + raise + except: + try: + # all still in an except block so we catch + # getattr raising + name = getattr(value, '__name__', None) + if name: + # ick, recursion + return text_repr(name) + klass = getattr(value, '__class__', None) + if klass: + return '%s instance' % text_repr(klass) + except KeyboardInterrupt: + raise + except: + return 'UNRECOVERABLE REPR FAILURE' + def eqrepr(value, repr=text_repr): return '=%s' % repr(value) + def nullrepr(value, repr=text_repr): return '' + + # meat of the code begins + try: + etype = etype.__name__ + except AttributeError: + pass + + if self.long_header: + # Header with the exception type, python version, and date + pyver = 'Python ' + string.split(sys.version)[0] + ': ' + sys.executable + 
date = time.ctime(time.time()) + + head = '%s%s%s\n%s%s%s\n%s' % (Colors.topline, '-'*75, ColorsNormal, + exc, ' '*(75-len(str(etype))-len(pyver)), + pyver, string.rjust(date, 75) ) + head += "\nA problem occured executing Python code. Here is the sequence of function"\ + "\ncalls leading up to the error, with the most recent (innermost) call last." + else: + # Simplified header + head = '%s%s%s\n%s%s' % (Colors.topline, '-'*75, ColorsNormal,exc, + string.rjust('Traceback (most recent call last)', + 75 - len(str(etype)) ) ) + frames = [] + # Flush cache before calling inspect. This helps alleviate some of the + # problems with python 2.3's inspect.py. + linecache.checkcache() + # Drop topmost frames if requested + try: + # Try the default getinnerframes and Alex's: Alex's fixes some + # problems, but it generates empty tracebacks for console errors + # (5 blanks lines) where none should be returned. + #records = inspect.getinnerframes(etb, context)[self.tb_offset:] + #print 'python records:', records # dbg + records = _fixed_getinnerframes(etb, context,self.tb_offset) + #print 'alex records:', records # dbg + except: + + # FIXME: I've been getting many crash reports from python 2.3 + # users, traceable to inspect.py. If I can find a small test-case + # to reproduce this, I should either write a better workaround or + # file a bug report against inspect (if that's the real problem). + # So far, I haven't been able to find an isolated example to + # reproduce the problem. 
+ inspect_error() + traceback.print_exc(file=Term.cerr) + info('\nUnfortunately, your original traceback can not be constructed.\n') + return '' + + # build some color string templates outside these nested loops + tpl_link = '%s%%s%s' % (Colors.filenameEm,ColorsNormal) + tpl_call = 'in %s%%s%s%%s%s' % (Colors.vName, Colors.valEm, + ColorsNormal) + tpl_call_fail = 'in %s%%s%s(***failed resolving arguments***)%s' % \ + (Colors.vName, Colors.valEm, ColorsNormal) + tpl_local_var = '%s%%s%s' % (Colors.vName, ColorsNormal) + tpl_global_var = '%sglobal%s %s%%s%s' % (Colors.em, ColorsNormal, + Colors.vName, ColorsNormal) + tpl_name_val = '%%s %s= %%s%s' % (Colors.valEm, ColorsNormal) + tpl_line = '%s%%s%s %%s' % (Colors.lineno, ColorsNormal) + tpl_line_em = '%s%%s%s %%s%s' % (Colors.linenoEm,Colors.line, + ColorsNormal) + + # now, loop over all records printing context and info + abspath = os.path.abspath + for frame, file, lnum, func, lines, index in records: + #print '*** record:',file,lnum,func,lines,index # dbg + try: + file = file and abspath(file) or '?' + except OSError: + # if file is '' or something not in the filesystem, + # the abspath call will throw an OSError. Just ignore it and + # keep the original file string. + pass + link = tpl_link % file + try: + args, varargs, varkw, locals = inspect.getargvalues(frame) + except: + # This can happen due to a bug in python2.3. We should be + # able to remove this try/except when 2.4 becomes a + # requirement. Bug details at http://python.org/sf/1005466 + inspect_error() + traceback.print_exc(file=Term.cerr) + info("\nIPython's exception reporting continues...\n") + + if func == '?': + call = '' + else: + # Decide whether to include variable details or not + var_repr = self.include_vars and eqrepr or nullrepr + try: + call = tpl_call % (func,inspect.formatargvalues(args, + varargs, varkw, + locals,formatvalue=var_repr)) + except KeyError: + # Very odd crash from inspect.formatargvalues(). 
The + # scenario under which it appeared was a call to + # view(array,scale) in NumTut.view.view(), where scale had + # been defined as a scalar (it should be a tuple). Somehow + # inspect messes up resolving the argument list of view() + # and barfs out. At some point I should dig into this one + # and file a bug report about it. + inspect_error() + traceback.print_exc(file=Term.cerr) + info("\nIPython's exception reporting continues...\n") + call = tpl_call_fail % func + + # Initialize a list of names on the current line, which the + # tokenizer below will populate. + names = [] + + def tokeneater(token_type, token, start, end, line): + """Stateful tokeneater which builds dotted names. + + The list of names it appends to (from the enclosing scope) can + contain repeated composite names. This is unavoidable, since + there is no way to disambguate partial dotted structures until + the full list is known. The caller is responsible for pruning + the final list of duplicates before using it.""" + + # build composite names + if token == '.': + try: + names[-1] += '.' + # store state so the next token is added for x.y.z names + tokeneater.name_cont = True + return + except IndexError: + pass + if token_type == tokenize.NAME and token not in keyword.kwlist: + if tokeneater.name_cont: + # Dotted names + names[-1] += token + tokeneater.name_cont = False + else: + # Regular new names. We append everything, the caller + # will be responsible for pruning the list later. It's + # very tricky to try to prune as we go, b/c composite + # names can fool us. The pruning at the end is easy + # to do (or the caller can print a list with repeated + # names if so desired. 
+ names.append(token) + elif token_type == tokenize.NEWLINE: + raise IndexError + # we need to store a bit of state in the tokenizer to build + # dotted names + tokeneater.name_cont = False + + def linereader(file=file, lnum=[lnum], getline=linecache.getline): + line = getline(file, lnum[0]) + lnum[0] += 1 + return line + + # Build the list of names on this line of code where the exception + # occurred. + try: + # This builds the names list in-place by capturing it from the + # enclosing scope. + tokenize.tokenize(linereader, tokeneater) + except IndexError: + # signals exit of tokenizer + pass + except tokenize.TokenError,msg: + _m = ("An unexpected error occurred while tokenizing input\n" + "The following traceback may be corrupted or invalid\n" + "The error message is: %s\n" % msg) + error(_m) + + # prune names list of duplicates, but keep the right order + unique_names = uniq_stable(names) + + # Start loop over vars + lvals = [] + if self.include_vars: + for name_full in unique_names: + name_base = name_full.split('.',1)[0] + if name_base in frame.f_code.co_varnames: + if locals.has_key(name_base): + try: + value = repr(eval(name_full,locals)) + except: + value = undefined + else: + value = undefined + name = tpl_local_var % name_full + else: + if frame.f_globals.has_key(name_base): + try: + value = repr(eval(name_full,frame.f_globals)) + except: + value = undefined + else: + value = undefined + name = tpl_global_var % name_full + lvals.append(tpl_name_val % (name,value)) + if lvals: + lvals = '%s%s' % (indent,em_normal.join(lvals)) + else: + lvals = '' + + level = '%s %s\n' % (link,call) + + if index is None: + frames.append(level) + else: + frames.append('%s%s' % (level,''.join( + _formatTracebackLines(lnum,index,lines,Colors,lvals, + col_scheme)))) + + # Get (safely) a string form of the exception info + try: + etype_str,evalue_str = map(str,(etype,evalue)) + except: + # User exception is improperly defined. 
+ etype,evalue = str,sys.exc_info()[:2] + etype_str,evalue_str = map(str,(etype,evalue)) + # ... and format it + exception = ['%s%s%s: %s' % (Colors.excName, etype_str, + ColorsNormal, evalue_str)] + if type(evalue) is types.InstanceType: + try: + names = [w for w in dir(evalue) if isinstance(w, basestring)] + except: + # Every now and then, an object with funny inernals blows up + # when dir() is called on it. We do the best we can to report + # the problem and continue + _m = '%sException reporting error (object with broken dir())%s:' + exception.append(_m % (Colors.excName,ColorsNormal)) + etype_str,evalue_str = map(str,sys.exc_info()[:2]) + exception.append('%s%s%s: %s' % (Colors.excName,etype_str, + ColorsNormal, evalue_str)) + names = [] + for name in names: + value = text_repr(getattr(evalue, name)) + exception.append('\n%s%s = %s' % (indent, name, value)) + # return all our info assembled as a single string + return '%s\n\n%s\n%s' % (head,'\n'.join(frames),''.join(exception[0]) ) + + def debugger(self,force=False): + """Call up the pdb debugger if desired, always clean up the tb + reference. + + Keywords: + + - force(False): by default, this routine checks the instance call_pdb + flag and does not actually invoke the debugger if the flag is false. + The 'force' option forces the debugger to activate even if the flag + is false. + + If the call_pdb flag is set, the pdb interactive debugger is + invoked. In all cases, the self.tb reference to the current traceback + is deleted to prevent lingering references which hamper memory + management. 
+ + Note that each call to pdb() does an 'import readline', so if your app + requires a special setup for the readline completers, you'll have to + fix that by hand after invoking the exception handler.""" + + if force or self.call_pdb: + if self.pdb is None: + self.pdb = Debugger.Pdb( + self.color_scheme_table.active_scheme_name) + # the system displayhook may have changed, restore the original + # for pdb + dhook = sys.displayhook + sys.displayhook = sys.__displayhook__ + self.pdb.reset() + # Find the right frame so we don't pop up inside ipython itself + if hasattr(self,'tb'): + etb = self.tb + else: + etb = self.tb = sys.last_traceback + while self.tb.tb_next is not None: + self.tb = self.tb.tb_next + try: + if etb and etb.tb_next: + etb = etb.tb_next + self.pdb.botframe = etb.tb_frame + self.pdb.interaction(self.tb.tb_frame, self.tb) + finally: + sys.displayhook = dhook + + if hasattr(self,'tb'): + del self.tb + + def handler(self, info=None): + (etype, evalue, etb) = info or sys.exc_info() + self.tb = etb + Term.cout.flush() + Term.cerr.flush() + print >> Term.cerr, self.text(etype, evalue, etb) + + # Changed so an instance can just be called as VerboseTB_inst() and print + # out the right info on its own. + def __call__(self, etype=None, evalue=None, etb=None): + """This hook can replace sys.excepthook (for Python 2.1 or higher).""" + if etb is None: + self.handler() + else: + self.handler((etype, evalue, etb)) + self.debugger() + +#---------------------------------------------------------------------------- +class FormattedTB(VerboseTB,ListTB): + """Subclass ListTB but allow calling with a traceback. + + It can thus be used as a sys.excepthook for Python > 2.1. + + Also adds 'Context' and 'Verbose' modes, not available in ListTB. + + Allows a tb_offset to be specified. 
This is useful for situations where + one needs to remove a number of topmost frames from the traceback (such as + occurs with python programs that themselves execute other python code, + like Python shells). """ + + def __init__(self, mode = 'Plain', color_scheme='Linux', + tb_offset = 0,long_header=0,call_pdb=0,include_vars=0): + + # NEVER change the order of this list. Put new modes at the end: + self.valid_modes = ['Plain','Context','Verbose'] + self.verbose_modes = self.valid_modes[1:3] + + VerboseTB.__init__(self,color_scheme,tb_offset,long_header, + call_pdb=call_pdb,include_vars=include_vars) + self.set_mode(mode) + + def _extract_tb(self,tb): + if tb: + return traceback.extract_tb(tb) + else: + return None + + def text(self, etype, value, tb,context=5,mode=None): + """Return formatted traceback. + + If the optional mode parameter is given, it overrides the current + mode.""" + + if mode is None: + mode = self.mode + if mode in self.verbose_modes: + # verbose modes need a full traceback + return VerboseTB.text(self,etype, value, tb,context=5) + else: + # We must check the source cache because otherwise we can print + # out-of-date source code. + linecache.checkcache() + # Now we can extract and format the exception + elist = self._extract_tb(tb) + if len(elist) > self.tb_offset: + del elist[:self.tb_offset] + return ListTB.text(self,etype,value,elist) + + def set_mode(self,mode=None): + """Switch to the desired mode. 
+ + If mode is not specified, cycles through the available modes.""" + + if not mode: + new_idx = ( self.valid_modes.index(self.mode) + 1 ) % \ + len(self.valid_modes) + self.mode = self.valid_modes[new_idx] + elif mode not in self.valid_modes: + raise ValueError, 'Unrecognized mode in FormattedTB: <'+mode+'>\n'\ + 'Valid modes: '+str(self.valid_modes) + else: + self.mode = mode + # include variable details only in 'Verbose' mode + self.include_vars = (self.mode == self.valid_modes[2]) + + # some convenient shorcuts + def plain(self): + self.set_mode(self.valid_modes[0]) + + def context(self): + self.set_mode(self.valid_modes[1]) + + def verbose(self): + self.set_mode(self.valid_modes[2]) + +#---------------------------------------------------------------------------- +class AutoFormattedTB(FormattedTB): + """A traceback printer which can be called on the fly. + + It will find out about exceptions by itself. + + A brief example: + + AutoTB = AutoFormattedTB(mode = 'Verbose',color_scheme='Linux') + try: + ... + except: + AutoTB() # or AutoTB(out=logfile) where logfile is an open file object + """ + def __call__(self,etype=None,evalue=None,etb=None, + out=None,tb_offset=None): + """Print out a formatted exception traceback. + + Optional arguments: + - out: an open file-like object to direct output to. + + - tb_offset: the number of frames to skip over in the stack, on a + per-call basis (this overrides temporarily the instance's tb_offset + given at initialization time. 
""" + + if out is None: + out = Term.cerr + Term.cout.flush() + out.flush() + if tb_offset is not None: + tb_offset, self.tb_offset = self.tb_offset, tb_offset + print >> out, self.text(etype, evalue, etb) + self.tb_offset = tb_offset + else: + print >> out, self.text(etype, evalue, etb) + self.debugger() + + def text(self,etype=None,value=None,tb=None,context=5,mode=None): + if etype is None: + etype,value,tb = sys.exc_info() + self.tb = tb + return FormattedTB.text(self,etype,value,tb,context=5,mode=mode) + +#--------------------------------------------------------------------------- +# A simple class to preserve Nathan's original functionality. +class ColorTB(FormattedTB): + """Shorthand to initialize a FormattedTB in Linux colors mode.""" + def __init__(self,color_scheme='Linux',call_pdb=0): + FormattedTB.__init__(self,color_scheme=color_scheme, + call_pdb=call_pdb) + +#---------------------------------------------------------------------------- +# module testing (minimal) +if __name__ == "__main__": + def spam(c, (d, e)): + x = c + d + y = c * d + foo(x, y) + + def foo(a, b, bar=1): + eggs(a, b + bar) + + def eggs(f, g, z=globals()): + h = f + g + i = f - g + return h / i + + print '' + print '*** Before ***' + try: + print spam(1, (2, 3)) + except: + traceback.print_exc() + print '' + + handler = ColorTB() + print '*** ColorTB ***' + try: + print spam(1, (2, 3)) + except: + apply(handler, sys.exc_info() ) + print '' + + handler = VerboseTB() + print '*** VerboseTB ***' + try: + print spam(1, (2, 3)) + except: + apply(handler, sys.exc_info() ) + print '' + diff --git a/IPython/kernel/core/util.py b/IPython/kernel/core/util.py new file mode 100644 index 0000000..7465aff --- /dev/null +++ b/IPython/kernel/core/util.py @@ -0,0 +1,197 @@ +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms 
of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os +import sys + + +# This class is mostly taken from IPython. +class InputList(list): + """ Class to store user input. + + It's basically a list, but slices return a string instead of a list, thus + allowing things like (assuming 'In' is an instance): + + exec In[4:7] + + or + + exec In[5:9] + In[14] + In[21:25] + """ + + def __getslice__(self, i, j): + return ''.join(list.__getslice__(self, i, j)) + + def add(self, index, command): + """ Add a command to the list with the appropriate index. + + If the index is greater than the current length of the list, empty + strings are added in between. + """ + + length = len(self) + if length == index: + self.append(command) + elif length > index: + self[index] = command + else: + extras = index - length + self.extend([''] * extras) + self.append(command) + + +class Bunch(dict): + """ A dictionary that exposes its keys as attributes. + """ + + def __init__(self, *args, **kwds): + dict.__init__(self, *args, **kwds) + self.__dict__ = self + + +def esc_quotes(strng): + """ Return the input string with single and double quotes escaped out. + """ + + return strng.replace('"', '\\"').replace("'", "\\'") + +def make_quoted_expr(s): + """Return string s in appropriate quotes, using raw string if possible. + + Effectively this turns string: cd \ao\ao\ + to: r"cd \ao\ao\_"[:-1] + + Note the use of raw string and padding at the end to allow trailing + backslash. 
+ """ + + tail = '' + tailpadding = '' + raw = '' + if "\\" in s: + raw = 'r' + if s.endswith('\\'): + tail = '[:-1]' + tailpadding = '_' + if '"' not in s: + quote = '"' + elif "'" not in s: + quote = "'" + elif '"""' not in s and not s.endswith('"'): + quote = '"""' + elif "'''" not in s and not s.endswith("'"): + quote = "'''" + else: + # Give up, backslash-escaped string will do + return '"%s"' % esc_quotes(s) + res = ''.join([raw, quote, s, tailpadding, quote, tail]) + return res + +# This function is used by ipython in a lot of places to make system calls. +# We need it to be slightly different under win32, due to the vagaries of +# 'network shares'. A win32 override is below. + +def system_shell(cmd, verbose=False, debug=False, header=''): + """ Execute a command in the system shell; always return None. + + Parameters + ---------- + cmd : str + The command to execute. + verbose : bool + If True, print the command to be executed. + debug : bool + Only print, do not actually execute. + header : str + Header to print to screen prior to the executed command. No extra + newlines are added. + + Description + ----------- + This returns None so it can be conveniently used in interactive loops + without getting the return value (typically 0) printed many times. + """ + + if verbose or debug: + print header + cmd + + # Flush stdout so we don't mangle python's buffering. + sys.stdout.flush() + if not debug: + os.system(cmd) + +# Override shell() for win32 to deal with network shares. +if os.name in ('nt', 'dos'): + + system_shell_ori = system_shell + + def system_shell(cmd, verbose=False, debug=False, header=''): + if os.getcwd().startswith(r"\\"): + path = os.getcwd() + # Change to c drive (cannot be on UNC-share when issuing os.system, + # as cmd.exe cannot handle UNC addresses). + os.chdir("c:") + # Issue pushd to the UNC-share and then run the command. 
+ try: + system_shell_ori('"pushd %s&&"'%path+cmd,verbose,debug,header) + finally: + os.chdir(path) + else: + system_shell_ori(cmd,verbose,debug,header) + + system_shell.__doc__ = system_shell_ori.__doc__ + +def getoutputerror(cmd, verbose=False, debug=False, header='', split=False): + """ Executes a command and returns the output. + + Parameters + ---------- + cmd : str + The command to execute. + verbose : bool + If True, print the command to be executed. + debug : bool + Only print, do not actually execute. + header : str + Header to print to screen prior to the executed command. No extra + newlines are added. + split : bool + If True, return the output as a list split on newlines. + + """ + + if verbose or debug: + print header+cmd + + if not cmd: + # Return empty lists or strings. + if split: + return [], [] + else: + return '', '' + + if not debug: + # fixme: use subprocess. + pin,pout,perr = os.popen3(cmd) + tout = pout.read().rstrip() + terr = perr.read().rstrip() + pin.close() + pout.close() + perr.close() + if split: + return tout.split('\n'), terr.split('\n') + else: + return tout, terr + diff --git a/IPython/kernel/engineconnector.py b/IPython/kernel/engineconnector.py new file mode 100644 index 0000000..93626e8 --- /dev/null +++ b/IPython/kernel/engineconnector.py @@ -0,0 +1,87 @@ +# encoding: utf-8 + +"""A class that manages the engines connection to the controller.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os +import cPickle as pickle + +from twisted.python import log + +from IPython.kernel.fcutil import find_furl +from IPython.kernel.enginefc import IFCEngine + +#------------------------------------------------------------------------------- +# The ClientConnector class +#------------------------------------------------------------------------------- + +class EngineConnector(object): + """Manage an engines connection to a controller. + + This class takes a foolscap `Tub` and provides a `connect_to_controller` + method that will use the `Tub` to connect to a controller and register + the engine with the controller. + """ + + def __init__(self, tub): + self.tub = tub + + def connect_to_controller(self, engine_service, furl_or_file): + """ + Make a connection to a controller specified by a furl. + + This method takes an `IEngineBase` instance and a foolcap URL and uses + the `tub` attribute to make a connection to the controller. The + foolscap URL contains all the information needed to connect to the + controller, including the ip and port as well as any encryption and + authentication information needed for the connection. + + After getting a reference to the controller, this method calls the + `register_engine` method of the controller to actually register the + engine. 
+ + :Parameters: + engine_service : IEngineBase + An instance of an `IEngineBase` implementer + furl_or_file : str + A furl or a filename containing a furl + """ + if not self.tub.running: + self.tub.startService() + self.engine_service = engine_service + self.engine_reference = IFCEngine(self.engine_service) + self.furl = find_furl(furl_or_file) + d = self.tub.getReference(self.furl) + d.addCallbacks(self._register, self._log_failure) + return d + + def _log_failure(self, reason): + log.err('engine registration failed:') + log.err(reason) + return reason + + def _register(self, rr): + self.remote_ref = rr + # Now register myself with the controller + desired_id = self.engine_service.id + d = self.remote_ref.callRemote('register_engine', self.engine_reference, + desired_id, os.getpid(), pickle.dumps(self.engine_service.properties,2)) + return d.addCallbacks(self._reference_sent, self._log_failure) + + def _reference_sent(self, registration_dict): + self.engine_service.id = registration_dict['id'] + log.msg("engine registration succeeded, got id: %r" % self.engine_service.id) + return self.engine_service.id + diff --git a/IPython/kernel/enginefc.py b/IPython/kernel/enginefc.py new file mode 100644 index 0000000..ebeff5c --- /dev/null +++ b/IPython/kernel/enginefc.py @@ -0,0 +1,548 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_enginepb -*- + +""" +Expose the IPython EngineService using the Foolscap network protocol. + +Foolscap is a high-performance and secure network protocol. +""" +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os, time +import cPickle as pickle + +from twisted.python import components, log, failure +from twisted.python.failure import Failure +from twisted.internet import defer, reactor, threads +from twisted.internet.interfaces import IProtocolFactory +from zope.interface import Interface, implements, Attribute + +from twisted.internet.base import DelayedCall +DelayedCall.debug = True + +from foolscap import Referenceable, DeadReferenceError +from foolscap.referenceable import RemoteReference + +from IPython.kernel.pbutil import packageFailure, unpackageFailure +from IPython.kernel.util import printer +from IPython.kernel.twistedutil import gatherBoth +from IPython.kernel import newserialized +from IPython.kernel.error import ProtocolError +from IPython.kernel import controllerservice +from IPython.kernel.controllerservice import IControllerBase +from IPython.kernel.engineservice import \ + IEngineBase, \ + IEngineQueued, \ + EngineService, \ + StrictDict +from IPython.kernel.pickleutil import \ + can, \ + canDict, \ + canSequence, \ + uncan, \ + uncanDict, \ + uncanSequence + + +#------------------------------------------------------------------------------- +# The client (Engine) side of things +#------------------------------------------------------------------------------- + +# Expose a FC interface to the EngineService + +class IFCEngine(Interface): + """An interface that exposes an EngineService over Foolscap. + + The methods in this interface are similar to those from IEngine, + but their arguments and return values slightly different to reflect + that FC cannot send arbitrary objects. We handle this by pickling/ + unpickling that the two endpoints. 
+ + If a remote or local exception is raised, the appropriate Failure + will be returned instead. + """ + pass + + +class FCEngineReferenceFromService(Referenceable, object): + """Adapt an `IEngineBase` to an `IFCEngine` implementer. + + This exposes an `IEngineBase` to foolscap by adapting it to a + `foolscap.Referenceable`. + + See the documentation of the `IEngineBase` methods for more details. + """ + + implements(IFCEngine) + + def __init__(self, service): + assert IEngineBase.providedBy(service), \ + "IEngineBase is not provided by" + repr(service) + self.service = service + self.collectors = {} + + def remote_get_id(self): + return self.service.id + + def remote_set_id(self, id): + self.service.id = id + + def _checkProperties(self, result): + dosync = self.service.properties.modified + self.service.properties.modified = False + return (dosync and pickle.dumps(self.service.properties, 2)), result + + def remote_execute(self, lines): + d = self.service.execute(lines) + d.addErrback(packageFailure) + d.addCallback(self._checkProperties) + d.addErrback(packageFailure) + #d.addCallback(lambda r: log.msg("Got result: " + str(r))) + return d + + #--------------------------------------------------------------------------- + # Old version of push + #--------------------------------------------------------------------------- + + def remote_push(self, pNamespace): + try: + namespace = pickle.loads(pNamespace) + except: + return defer.fail(failure.Failure()).addErrback(packageFailure) + else: + return self.service.push(namespace).addErrback(packageFailure) + + #--------------------------------------------------------------------------- + # pull + #--------------------------------------------------------------------------- + + def remote_pull(self, keys): + d = self.service.pull(keys) + d.addCallback(pickle.dumps, 2) + d.addErrback(packageFailure) + return d + + #--------------------------------------------------------------------------- + # push/pullFuction + 
#--------------------------------------------------------------------------- + + def remote_push_function(self, pNamespace): + try: + namespace = pickle.loads(pNamespace) + except: + return defer.fail(failure.Failure()).addErrback(packageFailure) + else: + # The usage of globals() here is an attempt to bind any pickled functions + # to the globals of this module. What we really want is to have it bound + # to the globals of the callers module. This will require walking the + # stack. BG 10/3/07. + namespace = uncanDict(namespace, globals()) + return self.service.push_function(namespace).addErrback(packageFailure) + + def remote_pull_function(self, keys): + d = self.service.pull_function(keys) + if len(keys)>1: + d.addCallback(canSequence) + elif len(keys)==1: + d.addCallback(can) + d.addCallback(pickle.dumps, 2) + d.addErrback(packageFailure) + return d + + #--------------------------------------------------------------------------- + # Other methods + #--------------------------------------------------------------------------- + + def remote_get_result(self, i=None): + return self.service.get_result(i).addErrback(packageFailure) + + def remote_reset(self): + return self.service.reset().addErrback(packageFailure) + + def remote_kill(self): + return self.service.kill().addErrback(packageFailure) + + def remote_keys(self): + return self.service.keys().addErrback(packageFailure) + + #--------------------------------------------------------------------------- + # push/pull_serialized + #--------------------------------------------------------------------------- + + def remote_push_serialized(self, pNamespace): + try: + namespace = pickle.loads(pNamespace) + except: + return defer.fail(failure.Failure()).addErrback(packageFailure) + else: + d = self.service.push_serialized(namespace) + return d.addErrback(packageFailure) + + def remote_pull_serialized(self, keys): + d = self.service.pull_serialized(keys) + d.addCallback(pickle.dumps, 2) + d.addErrback(packageFailure) + 
return d + + #--------------------------------------------------------------------------- + # Properties interface + #--------------------------------------------------------------------------- + + def remote_set_properties(self, pNamespace): + try: + namespace = pickle.loads(pNamespace) + except: + return defer.fail(failure.Failure()).addErrback(packageFailure) + else: + return self.service.set_properties(namespace).addErrback(packageFailure) + + def remote_get_properties(self, keys=None): + d = self.service.get_properties(keys) + d.addCallback(pickle.dumps, 2) + d.addErrback(packageFailure) + return d + + def remote_has_properties(self, keys): + d = self.service.has_properties(keys) + d.addCallback(pickle.dumps, 2) + d.addErrback(packageFailure) + return d + + def remote_del_properties(self, keys): + d = self.service.del_properties(keys) + d.addErrback(packageFailure) + return d + + def remote_clear_properties(self): + d = self.service.clear_properties() + d.addErrback(packageFailure) + return d + + +components.registerAdapter(FCEngineReferenceFromService, + IEngineBase, + IFCEngine) + + +#------------------------------------------------------------------------------- +# Now the server (Controller) side of things +#------------------------------------------------------------------------------- + +class EngineFromReference(object): + """Adapt a `RemoteReference` to an `IEngineBase` implementing object. + + When an engine connects to a controller, it calls the `register_engine` + method of the controller and passes the controller a `RemoteReference` to + itself. This class is used to adapt this `RemoteReference` to an object + that implements the full `IEngineBase` interface. + + See the documentation of `IEngineBase` for details on the methods. 
+ """ + + implements(IEngineBase) + + def __init__(self, reference): + self.reference = reference + self._id = None + self._properties = StrictDict() + self.currentCommand = None + + def callRemote(self, *args, **kwargs): + try: + return self.reference.callRemote(*args, **kwargs) + except DeadReferenceError: + self.notifier() + self.stopNotifying(self.notifier) + return defer.fail() + + def get_id(self): + """Return the Engines id.""" + return self._id + + def set_id(self, id): + """Set the Engines id.""" + self._id = id + return self.callRemote('set_id', id) + + id = property(get_id, set_id) + + def syncProperties(self, r): + try: + psync, result = r + except (ValueError, TypeError): + return r + else: + if psync: + log.msg("sync properties") + pick = self.checkReturnForFailure(psync) + if isinstance(pick, failure.Failure): + self.properties = pick + return pick + else: + self.properties = pickle.loads(pick) + return result + + def _set_properties(self, dikt): + self._properties.clear() + self._properties.update(dikt) + + def _get_properties(self): + if isinstance(self._properties, failure.Failure): + self._properties.raiseException() + return self._properties + + properties = property(_get_properties, _set_properties) + + #--------------------------------------------------------------------------- + # Methods from IEngine + #--------------------------------------------------------------------------- + + #--------------------------------------------------------------------------- + # execute + #--------------------------------------------------------------------------- + + def execute(self, lines): + # self._needProperties = True + d = self.callRemote('execute', lines) + d.addCallback(self.syncProperties) + return d.addCallback(self.checkReturnForFailure) + + #--------------------------------------------------------------------------- + # push + #--------------------------------------------------------------------------- + + def push(self, namespace): + try: + 
package = pickle.dumps(namespace, 2) + except: + return defer.fail(failure.Failure()) + else: + if isinstance(package, failure.Failure): + return defer.fail(package) + else: + d = self.callRemote('push', package) + return d.addCallback(self.checkReturnForFailure) + + #--------------------------------------------------------------------------- + # pull + #--------------------------------------------------------------------------- + + def pull(self, keys): + d = self.callRemote('pull', keys) + d.addCallback(self.checkReturnForFailure) + d.addCallback(pickle.loads) + return d + + #--------------------------------------------------------------------------- + # push/pull_function + #--------------------------------------------------------------------------- + + def push_function(self, namespace): + try: + package = pickle.dumps(canDict(namespace), 2) + except: + return defer.fail(failure.Failure()) + else: + if isinstance(package, failure.Failure): + return defer.fail(package) + else: + d = self.callRemote('push_function', package) + return d.addCallback(self.checkReturnForFailure) + + def pull_function(self, keys): + d = self.callRemote('pull_function', keys) + d.addCallback(self.checkReturnForFailure) + d.addCallback(pickle.loads) + # The usage of globals() here is an attempt to bind any pickled functions + # to the globals of this module. What we really want is to have it bound + # to the globals of the callers module. This will require walking the + # stack. BG 10/3/07. 
+ if len(keys)==1: + d.addCallback(uncan, globals()) + elif len(keys)>1: + d.addCallback(uncanSequence, globals()) + return d + + #--------------------------------------------------------------------------- + # Other methods + #--------------------------------------------------------------------------- + + def get_result(self, i=None): + return self.callRemote('get_result', i).addCallback(self.checkReturnForFailure) + + def reset(self): + self._refreshProperties = True + d = self.callRemote('reset') + d.addCallback(self.syncProperties) + return d.addCallback(self.checkReturnForFailure) + + def kill(self): + #this will raise pb.PBConnectionLost on success + d = self.callRemote('kill') + d.addCallback(self.syncProperties) + d.addCallback(self.checkReturnForFailure) + d.addErrback(self.killBack) + return d + + def killBack(self, f): + log.msg('filling engine: %s' % f) + return None + + def keys(self): + return self.callRemote('keys').addCallback(self.checkReturnForFailure) + + #--------------------------------------------------------------------------- + # Properties methods + #--------------------------------------------------------------------------- + + def set_properties(self, properties): + try: + package = pickle.dumps(properties, 2) + except: + return defer.fail(failure.Failure()) + else: + if isinstance(package, failure.Failure): + return defer.fail(package) + else: + d = self.callRemote('set_properties', package) + return d.addCallback(self.checkReturnForFailure) + return d + + def get_properties(self, keys=None): + d = self.callRemote('get_properties', keys) + d.addCallback(self.checkReturnForFailure) + d.addCallback(pickle.loads) + return d + + def has_properties(self, keys): + d = self.callRemote('has_properties', keys) + d.addCallback(self.checkReturnForFailure) + d.addCallback(pickle.loads) + return d + + def del_properties(self, keys): + d = self.callRemote('del_properties', keys) + d.addCallback(self.checkReturnForFailure) + # 
d.addCallback(pickle.loads) + return d + + def clear_properties(self): + d = self.callRemote('clear_properties') + d.addCallback(self.checkReturnForFailure) + return d + + #--------------------------------------------------------------------------- + # push/pull_serialized + #--------------------------------------------------------------------------- + + def push_serialized(self, namespace): + """Older version of pushSerialize.""" + try: + package = pickle.dumps(namespace, 2) + except: + return defer.fail(failure.Failure()) + else: + if isinstance(package, failure.Failure): + return defer.fail(package) + else: + d = self.callRemote('push_serialized', package) + return d.addCallback(self.checkReturnForFailure) + + def pull_serialized(self, keys): + d = self.callRemote('pull_serialized', keys) + d.addCallback(self.checkReturnForFailure) + d.addCallback(pickle.loads) + return d + + #--------------------------------------------------------------------------- + # Misc + #--------------------------------------------------------------------------- + + def checkReturnForFailure(self, r): + """See if a returned value is a pickled Failure object. + + To distinguish between general pickled objects and pickled Failures, the + other side should prepend the string FAILURE: to any pickled Failure. + """ + return unpackageFailure(r) + + +components.registerAdapter(EngineFromReference, + RemoteReference, + IEngineBase) + + +#------------------------------------------------------------------------------- +# Now adapt an IControllerBase to incoming FC connections +#------------------------------------------------------------------------------- + + +class IFCControllerBase(Interface): + """ + Interface that tells how an Engine sees a Controller. + + In our architecture, the Controller listens for Engines to connect + and register. 
This interface defines that registration method as it is + exposed over the Foolscap network protocol + """ + + def remote_register_engine(self, engineReference, id=None, pid=None, pproperties=None): + """ + Register new engine on the controller. + + Engines must call this upon connecting to the controller if they + want to do work for the controller. + + See the documentation of `IControllerCore` for more details. + """ + + +class FCRemoteEngineRefFromService(Referenceable): + """ + Adapt an `IControllerBase` to an `IFCControllerBase`. + """ + + implements(IFCControllerBase) + + def __init__(self, service): + assert IControllerBase.providedBy(service), \ + "IControllerBase is not provided by " + repr(service) + self.service = service + + def remote_register_engine(self, engine_reference, id=None, pid=None, pproperties=None): + # First adapt the engine_reference to a basic non-queued engine + engine = IEngineBase(engine_reference) + if pproperties: + engine.properties = pickle.loads(pproperties) + # Make it an IQueuedEngine before registration + remote_engine = IEngineQueued(engine) + # Get the ip/port of the remote side + peer_address = engine_reference.tracker.broker.transport.getPeer() + ip = peer_address.host + port = peer_address.port + reg_dict = self.service.register_engine(remote_engine, id, ip, port, pid) + # Now setup callback for disconnect and unregistering the engine + def notify(*args): + return self.service.unregister_engine(reg_dict['id']) + engine_reference.tracker.broker.notifyOnDisconnect(notify) + + engine.notifier = notify + engine.stopNotifying = engine_reference.tracker.broker.dontNotifyOnDisconnect + + return reg_dict + + +components.registerAdapter(FCRemoteEngineRefFromService, + IControllerBase, + IFCControllerBase) diff --git a/IPython/kernel/engineservice.py b/IPython/kernel/engineservice.py new file mode 100644 index 0000000..88ce0d3 --- /dev/null +++ b/IPython/kernel/engineservice.py @@ -0,0 +1,864 @@ +# encoding: utf-8 +# -*- 
test-case-name: IPython.kernel.tests.test_engineservice -*- + +"""A Twisted Service Representation of the IPython core. + +The IPython Core exposed to the network is called the Engine. Its +representation in Twisted in the EngineService. Interfaces and adapters +are used to abstract out the details of the actual network protocol used. +The EngineService is an Engine that knows nothing about the actual protocol +used. + +The EngineService is exposed with various network protocols in modules like: + +enginepb.py +enginevanilla.py + +As of 12/12/06 the classes in this module have been simplified greatly. It was +felt that we had over-engineered things. To improve the maintainability of the +code we have taken out the ICompleteEngine interface and the completeEngine +method that automatically added methods to engines. + +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os, sys, copy +import cPickle as pickle +from new import instancemethod + +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.python import log, failure, components +import zope.interface as zi + +from IPython.kernel.core.interpreter import Interpreter +from IPython.kernel import newserialized, error, util +from IPython.kernel.util import printer +from IPython.kernel.twistedutil import gatherBoth, DeferredList +from IPython.kernel import codeutil + + +#------------------------------------------------------------------------------- +# Interface specification for the Engine +#------------------------------------------------------------------------------- + +class IEngineCore(zi.Interface): + """The minimal required interface for the IPython Engine. + + This interface provides a formal specification of the IPython core. + All these methods should return deferreds regardless of what side of a + network connection they are on. + + In general, this class simply wraps a shell class and wraps its return + values as Deferred objects. If the underlying shell class method raises + an exception, this class should convert it to a twisted.failure.Failure + that will be propagated along the Deferred's errback chain. + + In addition, Failures are aggressive. By this, we mean that if a method + is performing multiple actions (like pulling multiple object) if any + single one fails, the entire method will fail with that Failure. It is + all or nothing. + """ + + id = zi.interface.Attribute("the id of the Engine object") + properties = zi.interface.Attribute("A dict of properties of the Engine") + + def execute(lines): + """Execute lines of Python code. 
+ + Returns a dictionary with keys (id, number, stdin, stdout, stderr) + upon success. + + Returns a failure object if the execution of lines raises an exception. + """ + + def push(namespace): + """Push dict namespace into the user's namespace. + + Returns a deferred to None or a failure. + """ + + def pull(keys): + """Pulls values out of the user's namespace by keys. + + Returns a deferred to a tuple objects or a single object. + + Raises NameError if any one of objects doess not exist. + """ + + def push_function(namespace): + """Push a dict of key, function pairs into the user's namespace. + + Returns a deferred to None or a failure.""" + + def pull_function(keys): + """Pulls functions out of the user's namespace by keys. + + Returns a deferred to a tuple of functions or a single function. + + Raises NameError if any one of the functions does not exist. + """ + + def get_result(i=None): + """Get the stdin/stdout/stderr of command i. + + Returns a deferred to a dict with keys + (id, number, stdin, stdout, stderr). + + Raises IndexError if command i does not exist. + Raises TypeError if i in not an int. + """ + + def reset(): + """Reset the shell. + + This clears the users namespace. Won't cause modules to be + reloaded. Should also re-initialize certain variables like id. + """ + + def kill(): + """Kill the engine by stopping the reactor.""" + + def keys(): + """Return the top level variables in the users namspace. + + Returns a deferred to a dict.""" + + +class IEngineSerialized(zi.Interface): + """Push/Pull methods that take Serialized objects. + + All methods should return deferreds. + """ + + def push_serialized(namespace): + """Push a dict of keys and Serialized objects into the user's namespace.""" + + def pull_serialized(keys): + """Pull objects by key from the user's namespace as Serialized. + + Returns a list of or one Serialized. + + Raises NameError is any one of the objects does not exist. 
+ """ + + +class IEngineProperties(zi.Interface): + """Methods for access to the properties object of an Engine""" + + properties = zi.Attribute("A StrictDict object, containing the properties") + + def set_properties(properties): + """set properties by key and value""" + + def get_properties(keys=None): + """get a list of properties by `keys`, if no keys specified, get all""" + + def del_properties(keys): + """delete properties by `keys`""" + + def has_properties(keys): + """get a list of bool values for whether `properties` has `keys`""" + + def clear_properties(): + """clear the properties dict""" + +class IEngineBase(IEngineCore, IEngineSerialized, IEngineProperties): + """The basic engine interface that EngineService will implement. + + This exists so it is easy to specify adapters that adapt to and from the + API that the basic EngineService implements. + """ + pass + +class IEngineQueued(IEngineBase): + """Interface for adding a queue to an IEngineBase. + + This interface extends the IEngineBase interface to add methods for managing + the engine's queue. The implicit details of this interface are that the + execution of all methods declared in IEngineBase should appropriately be + put through a queue before execution. + + All methods should return deferreds. + """ + + def clear_queue(): + """Clear the queue.""" + + def queue_status(): + """Get the queued and pending commands in the queue.""" + + def register_failure_observer(obs): + """Register an observer of pending Failures. + + The observer must implement IFailureObserver. + """ + + def unregister_failure_observer(obs): + """Unregister an observer of pending Failures.""" + + +class IEngineThreaded(zi.Interface): + """A place holder for threaded commands. + + All methods should return deferreds. 
+ """ + pass + + +#------------------------------------------------------------------------------- +# Functions and classes to implement the EngineService +#------------------------------------------------------------------------------- + + +class StrictDict(dict): + """This is a strict copying dictionary for use as the interface to the + properties of an Engine. + :IMPORTANT: + This object copies the values you set to it, and returns copies to you + when you request them. The only way to change properties os explicitly + through the setitem and getitem of the dictionary interface. + Example: + >>> e = kernel.get_engine(id) + >>> L = someList + >>> e.properties['L'] = L + >>> L == e.properties['L'] + ... True + >>> L.append(something Else) + >>> L == e.properties['L'] + ... False + + getitem copies, so calls to methods of objects do not affect the + properties, as in the following example: + >>> e.properties[1] = range(2) + >>> print e.properties[1] + ... [0, 1] + >>> e.properties[1].append(2) + >>> print e.properties[1] + ... 
[0, 1] + + """ + def __init__(self, *args, **kwargs): + dict.__init__(self, *args, **kwargs) + self.modified = True + + def __getitem__(self, key): + return copy.deepcopy(dict.__getitem__(self, key)) + + def __setitem__(self, key, value): + # check if this entry is valid for transport around the network + # and copying + try: + pickle.dumps(key, 2) + pickle.dumps(value, 2) + newvalue = copy.deepcopy(value) + except: + raise error.InvalidProperty(value) + dict.__setitem__(self, key, newvalue) + self.modified = True + + def __delitem__(self, key): + dict.__delitem__(self, key) + self.modified = True + + def update(self, dikt): + for k,v in dikt.iteritems(): + self[k] = v + + def pop(self, key): + self.modified = True + return dict.pop(self, key) + + def popitem(self): + self.modified = True + return dict.popitem(self) + + def clear(self): + self.modified = True + dict.clear(self) + + def subDict(self, *keys): + d = {} + for key in keys: + d[key] = self[key] + return d + + + +class EngineAPI(object): + """This is the object through which the user can edit the `properties` + attribute of an Engine. + The Engine Properties object copies all object in and out of itself. + See the EngineProperties object for details. 
+ """ + _fix=False + def __init__(self, id): + self.id = id + self.properties = StrictDict() + self._fix=True + + def __setattr__(self, k,v): + if self._fix: + raise error.KernelError("I am protected!") + else: + object.__setattr__(self, k, v) + + def __delattr__(self, key): + raise error.KernelError("I am protected!") + + +_apiDict = {} + +def get_engine(id): + """Get the Engine API object, whcih currently just provides the properties + object, by ID""" + global _apiDict + if not _apiDict.get(id): + _apiDict[id] = EngineAPI(id) + return _apiDict[id] + +def drop_engine(id): + """remove an engine""" + global _apiDict + if _apiDict.has_key(id): + del _apiDict[id] + +class EngineService(object, service.Service): + """Adapt a IPython shell into a IEngine implementing Twisted Service.""" + + zi.implements(IEngineBase) + name = 'EngineService' + + def __init__(self, shellClass=Interpreter, mpi=None): + """Create an EngineService. + + shellClass: something that implements IInterpreter or core1 + mpi: an mpi module that has rank and size attributes + """ + self.shellClass = shellClass + self.shell = self.shellClass() + self.mpi = mpi + self.id = None + self.properties = get_engine(self.id).properties + if self.mpi is not None: + log.msg("MPI started with rank = %i and size = %i" % + (self.mpi.rank, self.mpi.size)) + self.id = self.mpi.rank + self._seedNamespace() + + # Make id a property so that the shell can get the updated id + + def _setID(self, id): + self._id = id + self.properties = get_engine(id).properties + self.shell.push({'id': id}) + + def _getID(self): + return self._id + + id = property(_getID, _setID) + + def _seedNamespace(self): + self.shell.push({'mpi': self.mpi, 'id' : self.id}) + + def executeAndRaise(self, msg, callable, *args, **kwargs): + """Call a method of self.shell and wrap any exception.""" + d = defer.Deferred() + try: + result = callable(*args, **kwargs) + except: + # This gives the following: + # et=exception class + # ev=exception class 
instance + # tb=traceback object + et,ev,tb = sys.exc_info() + # This call adds attributes to the exception value + et,ev,tb = self.shell.formatTraceback(et,ev,tb,msg) + # Add another attribute + ev._ipython_engine_info = msg + f = failure.Failure(ev,et,None) + d.errback(f) + else: + d.callback(result) + + return d + + # The IEngine methods. See the interface for documentation. + + def execute(self, lines): + msg = {'engineid':self.id, + 'method':'execute', + 'args':[lines]} + d = self.executeAndRaise(msg, self.shell.execute, lines) + d.addCallback(self.addIDToResult) + return d + + def addIDToResult(self, result): + result['id'] = self.id + return result + + def push(self, namespace): + msg = {'engineid':self.id, + 'method':'push', + 'args':[repr(namespace.keys())]} + d = self.executeAndRaise(msg, self.shell.push, namespace) + return d + + def pull(self, keys): + msg = {'engineid':self.id, + 'method':'pull', + 'args':[repr(keys)]} + d = self.executeAndRaise(msg, self.shell.pull, keys) + return d + + def push_function(self, namespace): + msg = {'engineid':self.id, + 'method':'push_function', + 'args':[repr(namespace.keys())]} + d = self.executeAndRaise(msg, self.shell.push_function, namespace) + return d + + def pull_function(self, keys): + msg = {'engineid':self.id, + 'method':'pull_function', + 'args':[repr(keys)]} + d = self.executeAndRaise(msg, self.shell.pull_function, keys) + return d + + def get_result(self, i=None): + msg = {'engineid':self.id, + 'method':'get_result', + 'args':[repr(i)]} + d = self.executeAndRaise(msg, self.shell.getCommand, i) + d.addCallback(self.addIDToResult) + return d + + def reset(self): + msg = {'engineid':self.id, + 'method':'reset', + 'args':[]} + del self.shell + self.shell = self.shellClass() + self.properties.clear() + d = self.executeAndRaise(msg, self._seedNamespace) + return d + + def kill(self): + drop_engine(self.id) + try: + reactor.stop() + except RuntimeError: + log.msg('The reactor was not running apparently.') + 
return defer.fail() + else: + return defer.succeed(None) + + def keys(self): + """Return a list of variables names in the users top level namespace. + + This used to return a dict of all the keys/repr(values) in the + user's namespace. This was too much info for the ControllerService + to handle so it is now just a list of keys. + """ + + remotes = [] + for k in self.shell.user_ns.iterkeys(): + if k not in ['__name__', '_ih', '_oh', '__builtins__', + 'In', 'Out', '_', '__', '___', '__IP', 'input', 'raw_input']: + remotes.append(k) + return defer.succeed(remotes) + + def set_properties(self, properties): + msg = {'engineid':self.id, + 'method':'set_properties', + 'args':[repr(properties.keys())]} + return self.executeAndRaise(msg, self.properties.update, properties) + + def get_properties(self, keys=None): + msg = {'engineid':self.id, + 'method':'get_properties', + 'args':[repr(keys)]} + if keys is None: + keys = self.properties.keys() + return self.executeAndRaise(msg, self.properties.subDict, *keys) + + def _doDel(self, keys): + for key in keys: + del self.properties[key] + + def del_properties(self, keys): + msg = {'engineid':self.id, + 'method':'del_properties', + 'args':[repr(keys)]} + return self.executeAndRaise(msg, self._doDel, keys) + + def _doHas(self, keys): + return [self.properties.has_key(key) for key in keys] + + def has_properties(self, keys): + msg = {'engineid':self.id, + 'method':'has_properties', + 'args':[repr(keys)]} + return self.executeAndRaise(msg, self._doHas, keys) + + def clear_properties(self): + msg = {'engineid':self.id, + 'method':'clear_properties', + 'args':[]} + return self.executeAndRaise(msg, self.properties.clear) + + def push_serialized(self, sNamespace): + msg = {'engineid':self.id, + 'method':'push_serialized', + 'args':[repr(sNamespace.keys())]} + ns = {} + for k,v in sNamespace.iteritems(): + try: + unserialized = newserialized.IUnSerialized(v) + ns[k] = unserialized.getObject() + except: + return defer.fail() + return 
self.executeAndRaise(msg, self.shell.push, ns)

    def pull_serialized(self, keys):
        # Pull objects from the user's namespace and hand them back in
        # serialized form.  `keys` may be a single name or a sequence.
        msg = {'engineid':self.id,
               'method':'pull_serialized',
               'args':[repr(keys)]}
        if isinstance(keys, str):
            keys = [keys]
        if len(keys)==1:
            d = self.executeAndRaise(msg, self.shell.pull, keys)
            d.addCallback(newserialized.serialize)
            return d
        elif len(keys)>1:
            d = self.executeAndRaise(msg, self.shell.pull, keys)
            # addCallback used as a decorator: Deferred.addCallback returns
            # the Deferred itself, so the name `packThemUp` below is rebound
            # to `d`, and `return packThemUp` actually returns the Deferred.
            @d.addCallback
            def packThemUp(values):
                serials = []
                for v in values:
                    try:
                        serials.append(newserialized.serialize(v))
                    except:
                        # NOTE(review): bare except; returning defer.fail()
                        # from *inside a callback* nests a Deferred -- confirm
                        # the intent was `return failure.Failure()`.
                        return defer.fail(failure.Failure())
                return serials
            return packThemUp
            # NOTE(review): if `keys` is an empty sequence neither branch is
            # taken and the method returns None rather than a Deferred.


def queue(methodToQueue):
    # Decorator that replaces an IEngine method with one that wraps the
    # method's *name* and arguments in a Command and submits it to the
    # QueuedEngine's queue.  The body of the decorated method is never run.
    def queuedMethod(this, *args, **kwargs):
        name = methodToQueue.__name__
        return this.submitCommand(Command(name, *args, **kwargs))
    return queuedMethod

class QueuedEngine(object):
    """Adapt an IEngineBase to an IEngineQueued by wrapping it.

    The resulting object will implement IEngineQueued which extends
    IEngineCore which extends (IEngineBase, IEngineSerialized).

    This seems like the best way of handling it, but I am not sure. The
    other option is to have the various base interfaces be used like
    mix-in intefaces. The problem I have with this is adpatation is
    more difficult and complicated because there can be can multiple
    original and final Interfaces.
    """

    zi.implements(IEngineQueued)

    def __init__(self, engine):
        """Create a QueuedEngine object from an engine

        engine: An implementor of IEngineCore and IEngineSerialized
        keepUpToDate: whether to update the remote status when the
            queue is empty. Defaults to False.
        """

        # This is the right way to do these tests rather than
        # IEngineCore in list(zi.providedBy(engine)) which will only
        # picks of the interfaces that are directly declared by engine.
+ assert IEngineBase.providedBy(engine), \ + "engine passed to QueuedEngine doesn't provide IEngineBase" + + self.engine = engine + self.id = engine.id + self.queued = [] + self.history = {} + self.engineStatus = {} + self.currentCommand = None + self.failureObservers = [] + + def _get_properties(self): + return self.engine.properties + + properties = property(_get_properties, lambda self, _: None) + # Queue management methods. You should not call these directly + + def submitCommand(self, cmd): + """Submit command to queue.""" + + d = defer.Deferred() + cmd.setDeferred(d) + if self.currentCommand is not None: + if self.currentCommand.finished: + # log.msg("Running command immediately: %r" % cmd) + self.currentCommand = cmd + self.runCurrentCommand() + else: # command is still running + # log.msg("Command is running: %r" % self.currentCommand) + # log.msg("Queueing: %r" % cmd) + self.queued.append(cmd) + else: + # log.msg("No current commands, running: %r" % cmd) + self.currentCommand = cmd + self.runCurrentCommand() + return d + + def runCurrentCommand(self): + """Run current command.""" + + cmd = self.currentCommand + f = getattr(self.engine, cmd.remoteMethod, None) + if f: + d = f(*cmd.args, **cmd.kwargs) + if cmd.remoteMethod is 'execute': + d.addCallback(self.saveResult) + d.addCallback(self.finishCommand) + d.addErrback(self.abortCommand) + else: + return defer.fail(AttributeError(cmd.remoteMethod)) + + def _flushQueue(self): + """Pop next command in queue and run it.""" + + if len(self.queued) > 0: + self.currentCommand = self.queued.pop(0) + self.runCurrentCommand() + + def saveResult(self, result): + """Put the result in the history.""" + self.history[result['number']] = result + return result + + def finishCommand(self, result): + """Finish currrent command.""" + + # The order of these commands is absolutely critical. 
        self.currentCommand.handleResult(result)
        self.currentCommand.finished = True
        self._flushQueue()
        return result

    def abortCommand(self, reason):
        """Abort current command.

        This eats the Failure but first passes it onto the Deferred that the
        user has.

        It also clear out the queue so subsequence commands don't run.
        """

        # The order of these 3 commands is absolutely critical. The currentCommand
        # must first be marked as finished BEFORE the queue is cleared and before
        # the current command is sent the failure.
        # Also, the queue must be cleared BEFORE the current command is sent the Failure
        # otherwise the errback chain could trigger new commands to be added to the
        # queue before we clear it. We should clear ONLY the commands that were in
        # the queue when the error occured.
        self.currentCommand.finished = True
        s = "%r %r %r" % (self.currentCommand.remoteMethod, self.currentCommand.args, self.currentCommand.kwargs)
        self.clear_queue(msg=s)
        self.currentCommand.handleError(reason)

        return None

    #---------------------------------------------------------------------------
    # IEngineCore methods
    #---------------------------------------------------------------------------

    # The bodies below are empty on purpose: the @queue decorator replaces
    # each method with one that wraps the call in a Command and submits it,
    # so only the method name and signature matter here.

    @queue
    def execute(self, lines):
        pass

    @queue
    def push(self, namespace):
        pass

    @queue
    def pull(self, keys):
        pass

    @queue
    def push_function(self, namespace):
        pass

    @queue
    def pull_function(self, keys):
        pass

    def get_result(self, i=None):
        # i=None means "most recent command number".
        # NOTE(review): max(keys()+[None]) relies on Python 2 ordering None
        # below ints (an empty history yields i=None); this is a TypeError
        # on Python 3.
        if i is None:
            i = max(self.history.keys()+[None])

        cmd = self.history.get(i, None)
        # Uncomment this line to disable chaching of results
        #cmd = None
        if cmd is None:
            return self.submitCommand(Command('get_result', i))
        else:
            return defer.succeed(cmd)

    def reset(self):
        self.clear_queue()
        self.history = {}  # reset the cache - I am not sure we should do this
        return self.submitCommand(Command('reset'))

    def kill(self):
        self.clear_queue()
        return self.submitCommand(Command('kill'))

    @queue
    def keys(self):
        pass

    #---------------------------------------------------------------------------
    # IEngineSerialized methods
    #---------------------------------------------------------------------------

    @queue
    def push_serialized(self, namespace):
        pass

    @queue
    def pull_serialized(self, keys):
        pass

    #---------------------------------------------------------------------------
    # IEngineProperties methods
    #---------------------------------------------------------------------------

    @queue
    def set_properties(self, namespace):
        pass

    @queue
    def get_properties(self, keys=None):
        pass

    @queue
    def del_properties(self, keys):
        pass

    @queue
    def has_properties(self, keys):
        pass

    @queue
    def clear_properties(self):
        pass

    #---------------------------------------------------------------------------
    # IQueuedEngine methods
    #---------------------------------------------------------------------------

    def clear_queue(self, msg=''):
        """Clear the queue, but doesn't cancel the currently running commmand."""

        # Every queued (not-yet-started) command is errbacked with
        # QueueCleared; the running command is left alone.
        for cmd in self.queued:
            cmd.deferred.errback(failure.Failure(error.QueueCleared(msg)))
        self.queued = []
        return defer.succeed(None)

    def queue_status(self):
        # Everything is repr()'d so the status dict is cheap to send over
        # the wire.
        if self.currentCommand is not None:
            if self.currentCommand.finished:
                pending = repr(None)
            else:
                pending = repr(self.currentCommand)
        else:
            pending = repr(None)
        dikt = {'queue':map(repr,self.queued), 'pending':pending}
        return defer.succeed(dikt)

    def register_failure_observer(self, obs):
        self.failureObservers.append(obs)

    def unregister_failure_observer(self, obs):
        self.failureObservers.remove(obs)


# Now register QueuedEngine as an adpater class that makes an IEngineBase into a
# IEngineQueued.
components.registerAdapter(QueuedEngine, IEngineBase, IEngineQueued)


class Command(object):
    """A command object that encapslates queued commands.
+ + This class basically keeps track of a command that has been queued + in a QueuedEngine. It manages the deferreds and hold the method to be called + and the arguments to that method. + """ + + + def __init__(self, remoteMethod, *args, **kwargs): + """Build a new Command object.""" + + self.remoteMethod = remoteMethod + self.args = args + self.kwargs = kwargs + self.finished = False + + def setDeferred(self, d): + """Sets the deferred attribute of the Command.""" + + self.deferred = d + + def __repr__(self): + if not self.args: + args = '' + else: + args = str(self.args)[1:-2] #cut off (...,) + for k,v in self.kwargs.iteritems(): + if args: + args += ', ' + args += '%s=%r' %(k,v) + return "%s(%s)" %(self.remoteMethod, args) + + def handleResult(self, result): + """When the result is ready, relay it to self.deferred.""" + + self.deferred.callback(result) + + def handleError(self, reason): + """When an error has occured, relay it to self.deferred.""" + + self.deferred.errback(reason) + +class ThreadedEngineService(EngineService): + + zi.implements(IEngineBase) + + def __init__(self, shellClass=Interpreter, mpi=None): + EngineService.__init__(self, shellClass, mpi) + # Only import this if we are going to use this class + from twisted.internet import threads + + def execute(self, lines): + msg = """engine: %r +method: execute(lines) +lines = %s""" % (self.id, lines) + d = threads.deferToThread(self.executeAndRaise, msg, self.shell.execute, lines) + d.addCallback(self.addIDToResult) + return d diff --git a/IPython/kernel/error.py b/IPython/kernel/error.py new file mode 100644 index 0000000..3aaa78c --- /dev/null +++ b/IPython/kernel/error.py @@ -0,0 +1,185 @@ +# encoding: utf-8 + +"""Classes and functions for kernel related errors and exceptions.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD 
License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython.kernel.core import error +from twisted.python import failure + +#------------------------------------------------------------------------------- +# Error classes +#------------------------------------------------------------------------------- + +class KernelError(error.IPythonError): + pass + +class NotDefined(KernelError): + def __init__(self, name): + self.name = name + self.args = (name,) + + def __repr__(self): + return '' % self.name + + __str__ = __repr__ + +class QueueCleared(KernelError): + pass + +class IdInUse(KernelError): + pass + +class ProtocolError(KernelError): + pass + +class ConnectionError(KernelError): + pass + +class InvalidEngineID(KernelError): + pass + +class NoEnginesRegistered(KernelError): + pass + +class InvalidClientID(KernelError): + pass + +class InvalidDeferredID(KernelError): + pass + +class SerializationError(KernelError): + pass + +class MessageSizeError(KernelError): + pass + +class PBMessageSizeError(MessageSizeError): + pass + +class ResultNotCompleted(KernelError): + pass + +class ResultAlreadyRetrieved(KernelError): + pass + +class ClientError(KernelError): + pass + +class TaskAborted(KernelError): + pass + +class TaskTimeout(KernelError): + pass + +class NotAPendingResult(KernelError): + pass + +class UnpickleableException(KernelError): + pass + +class AbortedPendingDeferredError(KernelError): + pass + +class InvalidProperty(KernelError): + pass + +class MissingBlockArgument(KernelError): + pass + +class StopLocalExecution(KernelError): + pass + +class SecurityError(KernelError): + pass + +class CompositeError(KernelError): + def __init__(self, message, elist): + 
        Exception.__init__(self, *(message, elist))
        self.message = message
        self.elist = elist

    def _get_engine_str(self, ev):
        # Engine-raised exceptions carry _ipython_engine_info (set by the
        # engine's executeAndRaise); anything else gets a generic tag.
        try:
            ei = ev._ipython_engine_info
        except AttributeError:
            return '[Engine Exception]'
        else:
            return '[%i:%s]: ' % (ei['engineid'], ei['method'])

    def _get_traceback(self, ev):
        try:
            tb = ev._ipython_traceback_text
        except AttributeError:
            return 'No traceback available'
        else:
            return tb

    def __str__(self):
        # One line per collected exception, prefixed with its engine tag.
        s = str(self.message)
        for et, ev, etb in self.elist:
            engine_str = self._get_engine_str(ev)
            s = s + '\n' + engine_str + str(et.__name__) + ': ' + str(ev)
        return s

    def print_tracebacks(self, excid=None):
        # Print the tracebacks of all collected exceptions, or only the one
        # at index excid.  (Python 2 print statements.)
        if excid is None:
            for (et,ev,etb) in self.elist:
                print self._get_engine_str(ev)
                print self._get_traceback(ev)
                print
        else:
            try:
                et,ev,etb = self.elist[excid]
            except:
                # NOTE(review): bare except -- IndexError/TypeError from the
                # lookup are re-raised as IndexError with a clearer message.
                raise IndexError("an exception with index %i does not exist"%excid)
            else:
                print self._get_engine_str(ev)
                print self._get_traceback(ev)

    def raise_exception(self, excid=0):
        # Re-raise the excid-th collected exception with its original
        # traceback (Python 2 three-argument raise syntax).
        try:
            et,ev,etb = self.elist[excid]
        except:
            raise IndexError("an exception with index %i does not exist"%excid)
        else:
            raise et, ev, etb

def collect_exceptions(rlist, method):
    """Scan a result list for twisted Failures; raise one CompositeError.

    rlist: mixed list of plain results and twisted Failures gathered from
    calls to `method`.  If no Failures are present, rlist is returned
    unchanged; otherwise a CompositeError aggregating every (type, value,
    traceback) triple is raised.
    """
    elist = []
    for r in rlist:
        if isinstance(r, failure.Failure):
            r.cleanFailure()
            et, ev, etb = r.type, r.value, r.tb
            # Sometimes we could have CompositeError in our list. Just take
            # the errors out of them and put them in our new list. This
            # has the effect of flattening lists of CompositeErrors into one
            # CompositeError
            if et==CompositeError:
                for e in ev.elist:
                    elist.append(e)
            else:
                elist.append((et, ev, etb))
    if len(elist)==0:
        return rlist
    else:
        msg = "one or more exceptions from call to method: %s" % (method)
        # This silliness is needed so the debugger has access to the exception
        # instance (e in this case)
        try:
            raise CompositeError(msg, elist)
        except CompositeError, e:
            raise e


diff --git a/IPython/kernel/fcutil.py b/IPython/kernel/fcutil.py
new file mode 100644
index 0000000..9f7c730
--- /dev/null
+++ b/IPython/kernel/fcutil.py
@@ -0,0 +1,69 @@
# encoding: utf-8

"""Foolscap related utilities."""

__docformat__ = "restructuredtext en"

#-------------------------------------------------------------------------------
# Copyright (C) 2008 The IPython Development Team
#
# Distributed under the terms of the BSD License. The full license is in
# the file COPYING, distributed as part of this software.
#-------------------------------------------------------------------------------

#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

import os

from foolscap import Tub, UnauthenticatedTub

def check_furl_file_security(furl_file, secure):
    """Remove the old furl_file if changing security modes.

    A furl starting with pb:// is secure (authenticated); pbu:// is
    unauthenticated.  If the furl on disk disagrees with the requested
    `secure` mode, the file is deleted so a fresh one can be written.
    """

    if os.path.isfile(furl_file):
        f = open(furl_file, 'r')
        try:
            oldfurl = f.read().strip()
        finally:
            f.close()
        if (oldfurl.startswith('pb://') and not secure) or (oldfurl.startswith('pbu://') and secure):
            os.remove(furl_file)

def is_secure(furl):
    """Return True for a pb:// furl, False for a pbu:// furl.

    BUG FIX: an invalid furl used to fall off the end of the function and
    silently return None; the ValueError branch below was unreachable.
    Now invalid input raises ValueError, as that branch intended.
    """
    if not is_valid(furl):
        raise ValueError("invalid furl: %s" % furl)
    return furl.startswith("pb://")

def is_valid(furl):
    """Return True if furl is a string with a pb:// or pbu:// scheme.

    BUG FIX: non-string input used to return None implicitly; an explicit
    bool is always returned now (backward compatible: both are falsy).
    """
    return bool(isinstance(furl, str) and
                (furl.startswith("pb://") or furl.startswith("pbu://")))

def find_furl(furl_or_file):
    """Return a furl, given either a furl or a path to a file holding one.

    Raises ValueError if neither interpretation yields a valid furl.
    """
    if isinstance(furl_or_file, str):
        if is_valid(furl_or_file):
            return furl_or_file
        if os.path.isfile(furl_or_file):
            # BUG FIX: close the file handle instead of leaking it
            # (was open(...).read().strip()).
            f = open(furl_or_file, 'r')
            try:
                furl = f.read().strip()
            finally:
                f.close()
            if is_valid(furl):
                return furl
    raise ValueError("not a furl or a file containing a furl: %s" % furl_or_file)

# We do this so if a user doesn't have OpenSSL installed, it will try to use
# an UnauthenticatedTub. But, they will still run into problems if they
# try to use encrypted furls.
+try: + import OpenSSL +except: + Tub = UnauthenticatedTub + have_crypto = False +else: + have_crypto = True + + diff --git a/IPython/kernel/magic.py b/IPython/kernel/magic.py new file mode 100644 index 0000000..cefca8b --- /dev/null +++ b/IPython/kernel/magic.py @@ -0,0 +1,171 @@ +# encoding: utf-8 + +"""Magic command interface for interactive parallel work.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import new + +from IPython.iplib import InteractiveShell +from IPython.Shell import MTInteractiveShell + +from twisted.internet.defer import Deferred + + +#------------------------------------------------------------------------------- +# Definitions of magic functions for use with IPython +#------------------------------------------------------------------------------- + +NO_ACTIVE_CONTROLLER = """ +Error: No Controller is activated +Use activate() on a RemoteController object to activate it for magics. +""" + +def magic_result(self,parameter_s=''): + """Print the result of command i on all engines of the active controller. + + To activate a controller in IPython, first create it and then call + the activate() method. + + Then you can do the following: + + >>> result # Print the latest result + Printing result... + [127.0.0.1:0] In [1]: b = 10 + [127.0.0.1:1] In [1]: b = 10 + + >>> result 0 # Print result 0 + In [14]: result 0 + Printing result... 
+ [127.0.0.1:0] In [0]: a = 5 + [127.0.0.1:1] In [0]: a = 5 + """ + try: + activeController = __IPYTHON__.activeController + except AttributeError: + print NO_ACTIVE_CONTROLLER + else: + try: + index = int(parameter_s) + except: + index = None + result = activeController.get_result(index) + return result + +def magic_px(self,parameter_s=''): + """Executes the given python command on the active IPython Controller. + + To activate a Controller in IPython, first create it and then call + the activate() method. + + Then you can do the following: + + >>> %px a = 5 # Runs a = 5 on all nodes + """ + + try: + activeController = __IPYTHON__.activeController + except AttributeError: + print NO_ACTIVE_CONTROLLER + else: + print "Executing command on Controller" + result = activeController.execute(parameter_s) + return result + +def pxrunsource(self, source, filename="", symbol="single"): + + try: + code = self.compile(source, filename, symbol) + except (OverflowError, SyntaxError, ValueError): + # Case 1 + self.showsyntaxerror(filename) + return None + + if code is None: + # Case 2 + return True + + # Case 3 + # Because autopx is enabled, we now call executeAll or disable autopx if + # %autopx or autopx has been called + if '_ip.magic("%autopx' in source or '_ip.magic("autopx' in source: + _disable_autopx(self) + return False + else: + try: + result = self.activeController.execute(source) + except: + self.showtraceback() + else: + print result.__repr__() + return False + +def magic_autopx(self, parameter_s=''): + """Toggles auto parallel mode for the active IPython Controller. + + To activate a Controller in IPython, first create it and then call + the activate() method. + + Then you can do the following: + + >>> %autopx # Now all commands are executed in parallel + Auto Parallel Enabled + Type %autopx to disable + ... 
+ >>> %autopx # Now all commands are locally executed + Auto Parallel Disabled + """ + + if hasattr(self, 'autopx'): + if self.autopx == True: + _disable_autopx(self) + else: + _enable_autopx(self) + else: + _enable_autopx(self) + +def _enable_autopx(self): + """Enable %autopx mode by saving the original runsource and installing + pxrunsource. + """ + try: + activeController = __IPYTHON__.activeController + except AttributeError: + print "No active RemoteController found, use RemoteController.activate()." + else: + self._original_runsource = self.runsource + self.runsource = new.instancemethod(pxrunsource, self, self.__class__) + self.autopx = True + print "Auto Parallel Enabled\nType %autopx to disable" + +def _disable_autopx(self): + """Disable %autopx by restoring the original runsource.""" + if hasattr(self, 'autopx'): + if self.autopx == True: + self.runsource = self._original_runsource + self.autopx = False + print "Auto Parallel Disabled" + +# Add the new magic function to the class dict: + +InteractiveShell.magic_result = magic_result +InteractiveShell.magic_px = magic_px +InteractiveShell.magic_autopx = magic_autopx + +# And remove the global name to keep global namespace clean. Don't worry, the +# copy bound to IPython stays, we're just removing the global name. +del magic_result +del magic_px +del magic_autopx + diff --git a/IPython/kernel/map.py b/IPython/kernel/map.py new file mode 100644 index 0000000..2e0b932 --- /dev/null +++ b/IPython/kernel/map.py @@ -0,0 +1,121 @@ +# encoding: utf-8 + +"""Classes used in scattering and gathering sequences. + +Scattering consists of partitioning a sequence and sending the various +pieces to individual nodes in a cluster. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. 
The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import types + +from IPython.genutils import flatten as genutil_flatten + +#------------------------------------------------------------------------------- +# Figure out which array packages are present and their array types +#------------------------------------------------------------------------------- + +arrayModules = [] +try: + import Numeric +except ImportError: + pass +else: + arrayModules.append({'module':Numeric, 'type':Numeric.arraytype}) +try: + import numpy +except ImportError: + pass +else: + arrayModules.append({'module':numpy, 'type':numpy.ndarray}) +try: + import numarray +except ImportError: + pass +else: + arrayModules.append({'module':numarray, + 'type':numarray.numarraycore.NumArray}) + +class Map: + """A class for partitioning a sequence using a map.""" + + def getPartition(self, seq, p, q): + """Returns the pth partition of q partitions of seq.""" + + # Test for error conditions here + if p<0 or p>=q: + print "No partition exists." 
+ return + + remainder = len(seq)%q + basesize = len(seq)/q + hi = [] + lo = [] + for n in range(q): + if n < remainder: + lo.append(n * (basesize + 1)) + hi.append(lo[-1] + basesize + 1) + else: + lo.append(n*basesize + remainder) + hi.append(lo[-1] + basesize) + + + result = seq[lo[p]:hi[p]] + return result + + def joinPartitions(self, listOfPartitions): + return self.concatenate(listOfPartitions) + + def concatenate(self, listOfPartitions): + testObject = listOfPartitions[0] + # First see if we have a known array type + for m in arrayModules: + #print m + if isinstance(testObject, m['type']): + return m['module'].concatenate(listOfPartitions) + # Next try for Python sequence types + if isinstance(testObject, (types.ListType, types.TupleType)): + return genutil_flatten(listOfPartitions) + # If we have scalars, just return listOfPartitions + return listOfPartitions + +class RoundRobinMap(Map): + """Partitions a sequence in a roun robin fashion. + + This currently does not work! + """ + + def getPartition(self, seq, p, q): + return seq[p:len(seq):q] + #result = [] + #for i in range(p,len(seq),q): + # result.append(seq[i]) + #return result + + def joinPartitions(self, listOfPartitions): + #lengths = [len(x) for x in listOfPartitions] + #maxPartitionLength = len(listOfPartitions[0]) + #numberOfPartitions = len(listOfPartitions) + #concat = self.concatenate(listOfPartitions) + #totalLength = len(concat) + #result = [] + #for i in range(maxPartitionLength): + # result.append(concat[i:totalLength:maxPartitionLength]) + return self.concatenate(listOfPartitions) + +styles = {'basic':Map} + + + diff --git a/IPython/kernel/multiengine.py b/IPython/kernel/multiengine.py new file mode 100644 index 0000000..930f05f --- /dev/null +++ b/IPython/kernel/multiengine.py @@ -0,0 +1,780 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_multiengine -*- + +"""Adapt the IPython ControllerServer to IMultiEngine. 
+ +This module provides classes that adapt a ControllerService to the +IMultiEngine interface. This interface is a basic interactive interface +for working with a set of engines where it is desired to have explicit +access to each registered engine. + +The classes here are exposed to the network in files like: + +* multienginevanilla.py +* multienginepb.py +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from new import instancemethod +from types import FunctionType + +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.python import log, components, failure +from zope.interface import Interface, implements, Attribute + +from IPython.tools import growl +from IPython.kernel.util import printer +from IPython.kernel.twistedutil import gatherBoth +from IPython.kernel import map as Map +from IPython.kernel import error +from IPython.kernel.pendingdeferred import PendingDeferredManager, two_phase +from IPython.kernel.controllerservice import \ + ControllerAdapterBase, \ + ControllerService, \ + IControllerBase + + +#------------------------------------------------------------------------------- +# Interfaces for the MultiEngine representation of a controller +#------------------------------------------------------------------------------- + +class IEngineMultiplexer(Interface): + """Interface to multiple engines implementing IEngineCore/Serialized/Queued. 
+ + This class simply acts as a multiplexer of methods that are in the + various IEngines* interfaces. Thus the methods here are jut like those + in the IEngine* interfaces, but with an extra first argument, targets. + The targets argument can have the following forms: + + * targets = 10 # Engines are indexed by ints + * targets = [0,1,2,3] # A list of ints + * targets = 'all' # A string to indicate all targets + + If targets is bad in any way, an InvalidEngineID will be raised. This + includes engines not being registered. + + All IEngineMultiplexer multiplexer methods must return a Deferred to a list + with length equal to the number of targets. The elements of the list will + correspond to the return of the corresponding IEngine method. + + Failures are aggressive, meaning that if an action fails for any target, + the overall action will fail immediately with that Failure. + + :Parameters: + targets : int, list of ints, or 'all' + Engine ids the action will apply to. + + :Returns: Deferred to a list of results for each engine. + + :Exception: + InvalidEngineID + If the targets argument is bad or engines aren't registered. + NoEnginesRegistered + If there are no engines registered and targets='all' + """ + + #--------------------------------------------------------------------------- + # Mutiplexed methods + #--------------------------------------------------------------------------- + + def execute(lines, targets='all'): + """Execute lines of Python code on targets. + + See the class docstring for information about targets and possible + exceptions this method can raise. + + :Parameters: + lines : str + String of python code to be executed on targets. + """ + + def push(namespace, targets='all'): + """Push dict namespace into the user's namespace on targets. + + See the class docstring for information about targets and possible + exceptions this method can raise. + + :Parameters: + namspace : dict + Dict of key value pairs to be put into the users namspace. 
+ """ + + def pull(keys, targets='all'): + """Pull values out of the user's namespace on targets by keys. + + See the class docstring for information about targets and possible + exceptions this method can raise. + + :Parameters: + keys : tuple of strings + Sequence of keys to be pulled from user's namespace. + """ + + def push_function(namespace, targets='all'): + """""" + + def pull_function(keys, targets='all'): + """""" + + def get_result(i=None, targets='all'): + """Get the result for command i from targets. + + See the class docstring for information about targets and possible + exceptions this method can raise. + + :Parameters: + i : int or None + Command index or None to indicate most recent command. + """ + + def reset(targets='all'): + """Reset targets. + + This clears the users namespace of the Engines, but won't cause + modules to be reloaded. + """ + + def keys(targets='all'): + """Get variable names defined in user's namespace on targets.""" + + def kill(controller=False, targets='all'): + """Kill the targets Engines and possibly the controller. + + :Parameters: + controller : boolean + Should the controller be killed as well. If so all the + engines will be killed first no matter what targets is. + """ + + def push_serialized(namespace, targets='all'): + """Push a namespace of Serialized objects to targets. + + :Parameters: + namespace : dict + A dict whose keys are the variable names and whose values + are serialized version of the objects. + """ + + def pull_serialized(keys, targets='all'): + """Pull Serialized objects by keys from targets. + + :Parameters: + keys : tuple of strings + Sequence of variable names to pull as serialized objects. 
+ """ + + def clear_queue(targets='all'): + """Clear the queue of pending command for targets.""" + + def queue_status(targets='all'): + """Get the status of the queue on the targets.""" + + def set_properties(properties, targets='all'): + """set properties by key and value""" + + def get_properties(keys=None, targets='all'): + """get a list of properties by `keys`, if no keys specified, get all""" + + def del_properties(keys, targets='all'): + """delete properties by `keys`""" + + def has_properties(keys, targets='all'): + """get a list of bool values for whether `properties` has `keys`""" + + def clear_properties(targets='all'): + """clear the properties dict""" + + +class IMultiEngine(IEngineMultiplexer): + """A controller that exposes an explicit interface to all of its engines. + + This is the primary inteface for interactive usage. + """ + + def get_ids(): + """Return list of currently registered ids. + + :Returns: A Deferred to a list of registered engine ids. + """ + + + +#------------------------------------------------------------------------------- +# Implementation of the core MultiEngine classes +#------------------------------------------------------------------------------- + +class MultiEngine(ControllerAdapterBase): + """The representation of a ControllerService as a IMultiEngine. + + Although it is not implemented currently, this class would be where a + client/notification API is implemented. It could inherit from something + like results.NotifierParent and then use the notify method to send + notifications. + """ + + implements(IMultiEngine) + + def __init(self, controller): + ControllerAdapterBase.__init__(self, controller) + + #--------------------------------------------------------------------------- + # Helper methods + #--------------------------------------------------------------------------- + + def engineList(self, targets): + """Parse the targets argument into a list of valid engine objects. 
+ + :Parameters: + targets : int, list of ints or 'all' + The targets argument to be parsed. + + :Returns: List of engine objects. + + :Exception: + InvalidEngineID + If targets is not valid or if an engine is not registered. + """ + if isinstance(targets, int): + if targets not in self.engines.keys(): + log.msg("Engine with id %i is not registered" % targets) + raise error.InvalidEngineID("Engine with id %i is not registered" % targets) + else: + return [self.engines[targets]] + elif isinstance(targets, (list, tuple)): + for id in targets: + if id not in self.engines.keys(): + log.msg("Engine with id %r is not registered" % id) + raise error.InvalidEngineID("Engine with id %r is not registered" % id) + return map(self.engines.get, targets) + elif targets == 'all': + eList = self.engines.values() + if len(eList) == 0: + msg = """There are no engines registered. + Check the logs in ~/.ipython/log if you think there should have been.""" + raise error.NoEnginesRegistered(msg) + else: + return eList + else: + raise error.InvalidEngineID("targets argument is not an int, list of ints or 'all': %r"%targets) + + def _performOnEngines(self, methodName, *args, **kwargs): + """Calls a method on engines and returns deferred to list of results. + + :Parameters: + methodName : str + Name of the method to be called. + targets : int, list of ints, 'all' + The targets argument to be parsed into a list of engine objects. + args + The positional keyword arguments to be passed to the engines. + kwargs + The keyword arguments passed to the method + + :Returns: List of deferreds to the results on each engine + + :Exception: + InvalidEngineID + If the targets argument is bad in any way. + AttributeError + If the method doesn't exist on one of the engines. + """ + targets = kwargs.pop('targets') + log.msg("Performing %s on %r" % (methodName, targets)) + # log.msg("Performing %s(%r, %r) on %r" % (methodName, args, kwargs, targets)) + # This will and should raise if targets is not valid! 
+ engines = self.engineList(targets) + dList = [] + for e in engines: + meth = getattr(e, methodName, None) + if meth is not None: + dList.append(meth(*args, **kwargs)) + else: + raise AttributeError("Engine %i does not have method %s" % (e.id, methodName)) + return dList + + def _performOnEnginesAndGatherBoth(self, methodName, *args, **kwargs): + """Called _performOnEngines and wraps result/exception into deferred.""" + try: + dList = self._performOnEngines(methodName, *args, **kwargs) + except (error.InvalidEngineID, AttributeError, KeyError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + # Having fireOnOneErrback is causing problems with the determinacy + # of the system. Basically, once a single engine has errbacked, this + # method returns. In some cases, this will cause client to submit + # another command. Because the previous command is still running + # on some engines, this command will be queued. When those commands + # then errback, the second command will raise QueueCleared. Ahhh! 
+ d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, methodName) + return d + + #--------------------------------------------------------------------------- + # General IMultiEngine methods + #--------------------------------------------------------------------------- + + def get_ids(self): + return defer.succeed(self.engines.keys()) + + #--------------------------------------------------------------------------- + # IEngineMultiplexer methods + #--------------------------------------------------------------------------- + + def execute(self, lines, targets='all'): + return self._performOnEnginesAndGatherBoth('execute', lines, targets=targets) + + def push(self, ns, targets='all'): + return self._performOnEnginesAndGatherBoth('push', ns, targets=targets) + + def pull(self, keys, targets='all'): + return self._performOnEnginesAndGatherBoth('pull', keys, targets=targets) + + def push_function(self, ns, targets='all'): + return self._performOnEnginesAndGatherBoth('push_function', ns, targets=targets) + + def pull_function(self, keys, targets='all'): + return self._performOnEnginesAndGatherBoth('pull_function', keys, targets=targets) + + def get_result(self, i=None, targets='all'): + return self._performOnEnginesAndGatherBoth('get_result', i, targets=targets) + + def reset(self, targets='all'): + return self._performOnEnginesAndGatherBoth('reset', targets=targets) + + def keys(self, targets='all'): + return self._performOnEnginesAndGatherBoth('keys', targets=targets) + + def kill(self, controller=False, targets='all'): + if controller: + targets = 'all' + d = self._performOnEnginesAndGatherBoth('kill', targets=targets) + if controller: + log.msg("Killing controller") + d.addCallback(lambda _: reactor.callLater(2.0, reactor.stop)) + # Consume any weird stuff coming back + d.addBoth(lambda _: None) + return d + + def push_serialized(self, namespace, targets='all'): + for k, v in namespace.iteritems(): + 
log.msg("Pushed object %s is %f MB" % (k, v.getDataSize())) + d = self._performOnEnginesAndGatherBoth('push_serialized', namespace, targets=targets) + return d + + def pull_serialized(self, keys, targets='all'): + try: + dList = self._performOnEngines('pull_serialized', keys, targets=targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + for d in dList: + d.addCallback(self._logSizes) + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'pull_serialized') + return d + + def _logSizes(self, listOfSerialized): + if isinstance(listOfSerialized, (list, tuple)): + for s in listOfSerialized: + log.msg("Pulled object is %f MB" % s.getDataSize()) + else: + log.msg("Pulled object is %f MB" % listOfSerialized.getDataSize()) + return listOfSerialized + + def clear_queue(self, targets='all'): + return self._performOnEnginesAndGatherBoth('clear_queue', targets=targets) + + def queue_status(self, targets='all'): + log.msg("Getting queue status on %r" % targets) + try: + engines = self.engineList(targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + dList = [] + for e in engines: + dList.append(e.queue_status().addCallback(lambda s:(e.id, s))) + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'queue_status') + return d + + def get_properties(self, keys=None, targets='all'): + log.msg("Getting properties on %r" % targets) + try: + engines = self.engineList(targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + dList = [e.get_properties(keys) for e in engines] + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'get_properties') + return 
d + + def set_properties(self, properties, targets='all'): + log.msg("Setting properties on %r" % targets) + try: + engines = self.engineList(targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + dList = [e.set_properties(properties) for e in engines] + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'set_properties') + return d + + def has_properties(self, keys, targets='all'): + log.msg("Checking properties on %r" % targets) + try: + engines = self.engineList(targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + dList = [e.has_properties(keys) for e in engines] + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'has_properties') + return d + + def del_properties(self, keys, targets='all'): + log.msg("Deleting properties on %r" % targets) + try: + engines = self.engineList(targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + dList = [e.del_properties(keys) for e in engines] + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'del_properties') + return d + + def clear_properties(self, targets='all'): + log.msg("Clearing properties on %r" % targets) + try: + engines = self.engineList(targets) + except (error.InvalidEngineID, AttributeError, error.NoEnginesRegistered): + return defer.fail(failure.Failure()) + else: + dList = [e.clear_properties() for e in engines] + d = gatherBoth(dList, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + d.addCallback(error.collect_exceptions, 'clear_properties') + return d + + +components.registerAdapter(MultiEngine, + IControllerBase, + IMultiEngine) + + 
+#------------------------------------------------------------------------------- +# Interfaces for the Synchronous MultiEngine +#------------------------------------------------------------------------------- + +class ISynchronousEngineMultiplexer(Interface): + pass + + +class ISynchronousMultiEngine(ISynchronousEngineMultiplexer): + """Synchronous, two-phase version of IMultiEngine. + + Methods in this interface are identical to those of IMultiEngine, but they + take one additional argument: + + execute(lines, targets='all') -> execute(lines, targets='all, block=True) + + :Parameters: + block : boolean + Should the method return a deferred to a deferredID or the + actual result. If block=False a deferred to a deferredID is + returned and the user must call `get_pending_deferred` at a later + point. If block=True, a deferred to the actual result comes back. + """ + def get_pending_deferred(deferredID, block=True): + """""" + + def clear_pending_deferreds(): + """""" + + +#------------------------------------------------------------------------------- +# Implementation of the Synchronous MultiEngine +#------------------------------------------------------------------------------- + +class SynchronousMultiEngine(PendingDeferredManager): + """Adapt an `IMultiEngine` -> `ISynchronousMultiEngine` + + Warning, this class uses a decorator that currently uses **kwargs. + Because of this block must be passed as a kwarg, not positionally. 
+ """ + + implements(ISynchronousMultiEngine) + + def __init__(self, multiengine): + self.multiengine = multiengine + PendingDeferredManager.__init__(self) + + #--------------------------------------------------------------------------- + # Decorated pending deferred methods + #--------------------------------------------------------------------------- + + @two_phase + def execute(self, lines, targets='all'): + d = self.multiengine.execute(lines, targets) + return d + + @two_phase + def push(self, namespace, targets='all'): + return self.multiengine.push(namespace, targets) + + @two_phase + def pull(self, keys, targets='all'): + d = self.multiengine.pull(keys, targets) + return d + + @two_phase + def push_function(self, namespace, targets='all'): + return self.multiengine.push_function(namespace, targets) + + @two_phase + def pull_function(self, keys, targets='all'): + d = self.multiengine.pull_function(keys, targets) + return d + + @two_phase + def get_result(self, i=None, targets='all'): + return self.multiengine.get_result(i, targets='all') + + @two_phase + def reset(self, targets='all'): + return self.multiengine.reset(targets) + + @two_phase + def keys(self, targets='all'): + return self.multiengine.keys(targets) + + @two_phase + def kill(self, controller=False, targets='all'): + return self.multiengine.kill(controller, targets) + + @two_phase + def push_serialized(self, namespace, targets='all'): + return self.multiengine.push_serialized(namespace, targets) + + @two_phase + def pull_serialized(self, keys, targets='all'): + return self.multiengine.pull_serialized(keys, targets) + + @two_phase + def clear_queue(self, targets='all'): + return self.multiengine.clear_queue(targets) + + @two_phase + def queue_status(self, targets='all'): + return self.multiengine.queue_status(targets) + + @two_phase + def set_properties(self, properties, targets='all'): + return self.multiengine.set_properties(properties, targets) + + @two_phase + def get_properties(self, 
keys=None, targets='all'): + return self.multiengine.get_properties(keys, targets) + + @two_phase + def has_properties(self, keys, targets='all'): + return self.multiengine.has_properties(keys, targets) + + @two_phase + def del_properties(self, keys, targets='all'): + return self.multiengine.del_properties(keys, targets) + + @two_phase + def clear_properties(self, targets='all'): + return self.multiengine.clear_properties(targets) + + #--------------------------------------------------------------------------- + # IMultiEngine methods + #--------------------------------------------------------------------------- + + def get_ids(self): + """Return a list of registered engine ids. + + Never use the two phase block/non-block stuff for this. + """ + return self.multiengine.get_ids() + + +components.registerAdapter(SynchronousMultiEngine, IMultiEngine, ISynchronousMultiEngine) + + +#------------------------------------------------------------------------------- +# Various high-level interfaces that can be used as MultiEngine mix-ins +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# IMultiEngineCoordinator +#------------------------------------------------------------------------------- + +class IMultiEngineCoordinator(Interface): + """Methods that work on multiple engines explicitly.""" + + def scatter(key, seq, style='basic', flatten=False, targets='all'): + """Partition and distribute a sequence to targets. + + :Parameters: + key : str + The variable name to call the scattered sequence. + seq : list, tuple, array + The sequence to scatter. The type should be preserved. + style : string + A specification of how the sequence is partitioned. Currently + only 'basic' is implemented. + flatten : boolean + Should single element sequences be converted to scalars. + """ + + def gather(key, style='basic', targets='all'): + """Gather object key from targets. 
+ + :Parameters: + key : string + The name of a sequence on the targets to gather. + style : string + A specification of how the sequence is partitioned. Currently + only 'basic' is implemented. + """ + + def map(func, seq, style='basic', targets='all'): + """A parallelized version of Python's builtin map. + + This function implements the following pattern: + + 1. The sequence seq is scattered to the given targets. + 2. map(functionSource, seq) is called on each engine. + 3. The resulting sequences are gathered back to the local machine. + + :Parameters: + targets : int, list or 'all' + The engine ids the action will apply to. Call `get_ids` to see + a list of currently available engines. + func : str, function + An actual function object or a Python string that names a + callable defined on the engines. + seq : list, tuple or numpy array + The local sequence to be scattered. + style : str + Only 'basic' is supported for now. + + :Returns: A list of len(seq) with functionSource called on each element + of seq. + + Example + ======= + + >>> rc.mapAll('lambda x: x*x', range(10000)) + [0,2,4,9,25,36,...] + """ + + +class ISynchronousMultiEngineCoordinator(IMultiEngineCoordinator): + """Methods that work on multiple engines explicitly.""" + pass + + +#------------------------------------------------------------------------------- +# IMultiEngineExtras +#------------------------------------------------------------------------------- + +class IMultiEngineExtras(Interface): + + def zip_pull(targets, *keys): + """Pull, but return results in a different format from `pull`. + + This method basically returns zip(pull(targets, *keys)), with a few + edge cases handled differently. Users of chainsaw will find this format + familiar. + + :Parameters: + targets : int, list or 'all' + The engine ids the action will apply to. Call `get_ids` to see + a list of currently available engines. 
+ keys: list or tuple of str + A list of variable names as string of the Python objects to be pulled + back to the client. + + :Returns: A list of pulled Python objects for each target. + """ + + def run(targets, fname): + """Run a .py file on targets. + + :Parameters: + targets : int, list or 'all' + The engine ids the action will apply to. Call `get_ids` to see + a list of currently available engines. + fname : str + The filename of a .py file on the local system to be sent to and run + on the engines. + block : boolean + Should I block or not. If block=True, wait for the action to + complete and return the result. If block=False, return a + `PendingResult` object that can be used to later get the + result. If block is not specified, the block attribute + will be used instead. + """ + + +class ISynchronousMultiEngineExtras(IMultiEngineExtras): + pass + + +#------------------------------------------------------------------------------- +# The full MultiEngine interface +#------------------------------------------------------------------------------- + +class IFullMultiEngine(IMultiEngine, + IMultiEngineCoordinator, + IMultiEngineExtras): + pass + + +class IFullSynchronousMultiEngine(ISynchronousMultiEngine, + ISynchronousMultiEngineCoordinator, + ISynchronousMultiEngineExtras): + pass + diff --git a/IPython/kernel/multiengineclient.py b/IPython/kernel/multiengineclient.py new file mode 100644 index 0000000..1200f2a --- /dev/null +++ b/IPython/kernel/multiengineclient.py @@ -0,0 +1,840 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_multiengineclient -*- + +"""General Classes for IMultiEngine clients.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import sys +import cPickle as pickle +from types import FunctionType +import linecache + +from twisted.internet import reactor +from twisted.python import components, log +from twisted.python.failure import Failure +from zope.interface import Interface, implements, Attribute + +from IPython.ColorANSI import TermColors + +from IPython.kernel.twistedutil import blockingCallFromThread +from IPython.kernel import error +from IPython.kernel.parallelfunction import ParallelFunction +from IPython.kernel import map as Map +from IPython.kernel import multiengine as me +from IPython.kernel.multiengine import (IFullMultiEngine, + IFullSynchronousMultiEngine) + + +#------------------------------------------------------------------------------- +# Pending Result things +#------------------------------------------------------------------------------- + +class IPendingResult(Interface): + """A representation of a result that is pending. + + This class is similar to Twisted's `Deferred` object, but is designed to be + used in a synchronous context. + """ + + result_id=Attribute("ID of the deferred on the other side") + client=Attribute("A client that I came from") + r=Attribute("An attribute that is a property that calls and returns get_result") + + def get_result(default=None, block=True): + """ + Get a result that is pending. + + :Parameters: + default + The value to return if the result is not ready. + block : boolean + Should I block for the result. + + :Returns: The actual result or the default value. + """ + + def add_callback(f, *args, **kwargs): + """ + Add a callback that is called with the result. + + If the original result is foo, adding a callback will cause + f(foo, *args, **kwargs) to be returned instead. 
If multiple + callbacks are registered, they are chained together: the result of + one is passed to the next and so on. + + Unlike Twisted's Deferred object, there is no errback chain. Thus + any exception raised will not be caught and handled. User must + catch these by hand when calling `get_result`. + """ + + +class PendingResult(object): + """A representation of a result that is not yet ready. + + A user should not create a `PendingResult` instance by hand. + + Methods + ======= + + * `get_result` + * `add_callback` + + Properties + ========== + * `r` + """ + + def __init__(self, client, result_id): + """Create a PendingResult with a result_id and a client instance. + + The client should implement `_getPendingResult(result_id, block)`. + """ + self.client = client + self.result_id = result_id + self.called = False + self.raised = False + self.callbacks = [] + + def get_result(self, default=None, block=True): + """Get a result that is pending. + + This method will connect to an IMultiEngine adapted controller + and see if the result is ready. If the action triggers an exception + raise it and record it. This method records the result/exception once it is + retrieved. Calling `get_result` again will get this cached result or will + re-raise the exception. The .r attribute is a property that calls + `get_result` with block=True. + + :Parameters: + default + The value to return if the result is not ready. + block : boolean + Should I block for the result. + + :Returns: The actual result or the default value. + """ + + if self.called: + if self.raised: + raise self.result[0], self.result[1], self.result[2] + else: + return self.result + try: + result = self.client.get_pending_deferred(self.result_id, block) + except error.ResultNotCompleted: + return default + except: + # Reraise other error, but first record them so they can be reraised + # later if .r or get_result is called again. 
+ self.result = sys.exc_info() + self.called = True + self.raised = True + raise + else: + for cb in self.callbacks: + result = cb[0](result, *cb[1], **cb[2]) + self.result = result + self.called = True + return result + + def add_callback(self, f, *args, **kwargs): + """Add a callback that is called with the result. + + If the original result is result, adding a callback will cause + f(result, *args, **kwargs) to be returned instead. If multiple + callbacks are registered, they are chained together: the result of + one is passed to the next and so on. + + Unlike Twisted's Deferred object, there is no errback chain. Thus + any exception raised will not be caught and handled. User must + catch these by hand when calling `get_result`. + """ + assert callable(f) + self.callbacks.append((f, args, kwargs)) + + def __cmp__(self, other): + if self.result_id < other.result_id: + return -1 + else: + return 1 + + def _get_r(self): + return self.get_result(block=True) + + r = property(_get_r) + """This property is a shortcut to a `get_result(block=True)`.""" + + +#------------------------------------------------------------------------------- +# Pretty printing wrappers for certain lists +#------------------------------------------------------------------------------- + +class ResultList(list): + """A subclass of list that pretty prints the output of `execute`/`get_result`.""" + + def __repr__(self): + output = [] + blue = TermColors.Blue + normal = TermColors.Normal + red = TermColors.Red + green = TermColors.Green + output.append("\n") + for cmd in self: + if isinstance(cmd, Failure): + output.append(cmd) + else: + target = cmd.get('id',None) + cmd_num = cmd.get('number',None) + cmd_stdin = cmd.get('input',{}).get('translated','No Input') + cmd_stdout = cmd.get('stdout', None) + cmd_stderr = cmd.get('stderr', None) + output.append("%s[%i]%s In [%i]:%s %s\n" % \ + (green, target, + blue, cmd_num, normal, cmd_stdin)) + if cmd_stdout: + output.append("%s[%i]%s Out[%i]:%s %s\n" 
% \ + (green, target, + red, cmd_num, normal, cmd_stdout)) + if cmd_stderr: + output.append("%s[%i]%s Err[%i]:\n%s %s" % \ + (green, target, + red, cmd_num, normal, cmd_stderr)) + return ''.join(output) + + +def wrapResultList(result): + """A function that wraps the output of `execute`/`get_result` -> `ResultList`.""" + if len(result) == 0: + result = [result] + return ResultList(result) + + +class QueueStatusList(list): + """A subclass of list that pretty prints the output of `queue_status`.""" + + def __repr__(self): + output = [] + output.append("\n") + for e in self: + output.append("Engine: %s\n" % repr(e[0])) + output.append(" Pending: %s\n" % repr(e[1]['pending'])) + for q in e[1]['queue']: + output.append(" Command: %s\n" % repr(q)) + return ''.join(output) + + +#------------------------------------------------------------------------------- +# InteractiveMultiEngineClient +#------------------------------------------------------------------------------- + +class InteractiveMultiEngineClient(object): + """A mixin class that add a few methods to a multiengine client. + + The methods in this mixin class are designed for interactive usage. + """ + + def activate(self): + """Make this `MultiEngineClient` active for parallel magic commands. + + IPython has a magic command syntax to work with `MultiEngineClient` objects. + In a given IPython session there is a single active one. While + there can be many `MultiEngineClient` created and used by the user, + there is only one active one. The active `MultiEngineClient` is used whenever + the magic commands %px and %autopx are used. + + The activate() method is called on a given `MultiEngineClient` to make it + active. Once this has been done, the magic commands can be used. + """ + + try: + __IPYTHON__.activeController = self + except NameError: + print "The IPython Controller magics only work within IPython." + + def __setitem__(self, key, value): + """Add a dictionary interface for pushing/pulling. 
+ + This functions as a shorthand for `push`. + + :Parameters: + key : str + What to call the remote object. + value : object + The local Python object to push. + """ + targets, block = self._findTargetsAndBlock() + return self.push({key:value}, targets=targets, block=block) + + def __getitem__(self, key): + """Add a dictionary interface for pushing/pulling. + + This functions as a shorthand to `pull`. + + :Parameters: + - `key`: A string representing the key. + """ + if isinstance(key, str): + targets, block = self._findTargetsAndBlock() + return self.pull(key, targets=targets, block=block) + else: + raise TypeError("__getitem__ only takes strs") + + def __len__(self): + """Return the number of available engines.""" + return len(self.get_ids()) + + def parallelize(self, func, targets=None, block=None): + """Build a `ParallelFunction` object for functionName on engines. + + The returned object will implement a parallel version of functionName + that takes a local sequence as its only argument and calls (in + parallel) functionName on each element of that sequence. The + `ParallelFunction` object has a `targets` attribute that controls + which engines the function is run on. + + :Parameters: + targets : int, list or 'all' + The engine ids the action will apply to. Call `get_ids` to see + a list of currently available engines. + functionName : str + A Python string that names a callable defined on the engines. + + :Returns: A `ParallelFunction` object. + + Examples + ======== + + >>> psin = rc.parallelize('all','lambda x:sin(x)') + >>> psin(range(10000)) + [0,2,4,9,25,36,...] 
+ """ + targets, block = self._findTargetsAndBlock(targets, block) + return ParallelFunction(func, self, targets, block) + + #--------------------------------------------------------------------------- + # Make this a context manager for with + #--------------------------------------------------------------------------- + + def findsource_file(self,f): + linecache.checkcache() + s = findsource(f.f_code) + lnum = f.f_lineno + wsource = s[0][f.f_lineno:] + return strip_whitespace(wsource) + + def findsource_ipython(self,f): + from IPython import ipapi + self.ip = ipapi.get() + wsource = [l+'\n' for l in + self.ip.IP.input_hist_raw[-1].splitlines()[1:]] + return strip_whitespace(wsource) + + def __enter__(self): + f = sys._getframe(1) + local_ns = f.f_locals + global_ns = f.f_globals + if f.f_code.co_filename == '': + s = self.findsource_ipython(f) + else: + s = self.findsource_file(f) + + self._with_context_result = self.execute(s) + + def __exit__ (self, etype, value, tb): + if issubclass(etype,error.StopLocalExecution): + return True + + +def remote(): + m = 'Special exception to stop local execution of parallel code.' + raise error.StopLocalExecution(m) + +def strip_whitespace(source): + # Expand tabs to avoid any confusion. + wsource = [l.expandtabs(4) for l in source] + # Detect the indentation level + done = False + for line in wsource: + if line.isspace(): + continue + for col,char in enumerate(line): + if char != ' ': + done = True + break + if done: + break + # Now we know how much leading space there is in the code. Next, we + # extract up to the first line that has less indentation. + # WARNINGS: we skip comments that may be misindented, but we do NOT yet + # detect triple quoted strings that may have flush left text. 
+ for lno,line in enumerate(wsource): + lead = line[:col] + if lead.isspace(): + continue + else: + if not lead.lstrip().startswith('#'): + break + # The real 'with' source is up to lno + src_lines = [l[col:] for l in wsource[:lno+1]] + + # Finally, check that the source's first non-comment line begins with the + # special call 'remote()' + for nline,line in enumerate(src_lines): + if line.isspace() or line.startswith('#'): + continue + if 'remote()' in line: + break + else: + raise ValueError('remote() call missing at the start of code') + src = ''.join(src_lines[nline+1:]) + #print 'SRC:\n<<<<<<<>>>>>>>\n%s<<<<<>>>>>>' % src # dbg + return src + + +#------------------------------------------------------------------------------- +# The top-level MultiEngine client adaptor +#------------------------------------------------------------------------------- + + +class IFullBlockingMultiEngineClient(Interface): + pass + + +class FullBlockingMultiEngineClient(InteractiveMultiEngineClient): + """ + A blocking client to the `IMultiEngine` controller interface. + + This class allows users to use a set of engines for a parallel + computation through the `IMultiEngine` interface. In this interface, + each engine has a specific id (an int) that is used to refer to the + engine, run code on it, etc. 
+ """ + + implements(IFullBlockingMultiEngineClient) + + def __init__(self, smultiengine): + self.smultiengine = smultiengine + self.block = True + self.targets = 'all' + + def _findBlock(self, block=None): + if block is None: + return self.block + else: + if block in (True, False): + return block + else: + raise ValueError("block must be True or False") + + def _findTargets(self, targets=None): + if targets is None: + return self.targets + else: + if not isinstance(targets, (str,list,tuple,int)): + raise ValueError("targets must be a str, list, tuple or int") + return targets + + def _findTargetsAndBlock(self, targets=None, block=None): + return self._findTargets(targets), self._findBlock(block) + + def _blockFromThread(self, function, *args, **kwargs): + block = kwargs.get('block', None) + if block is None: + raise error.MissingBlockArgument("'block' keyword argument is missing") + result = blockingCallFromThread(function, *args, **kwargs) + if not block: + result = PendingResult(self, result) + return result + + def get_pending_deferred(self, deferredID, block): + return blockingCallFromThread(self.smultiengine.get_pending_deferred, deferredID, block) + + def barrier(self, pendingResults): + """Synchronize a set of `PendingResults`. + + This method is a synchronization primitive that waits for a set of + `PendingResult` objects to complete. More specifically, barier does + the following. + + * The `PendingResult`s are sorted by result_id. + * The `get_result` method is called for each `PendingResult` sequentially + with block=True. + * If a `PendingResult` gets a result that is an exception, it is + trapped and can be re-raised later by calling `get_result` again. + * The `PendingResult`s are flushed from the controller. + + After barrier has been called on a `PendingResult`, its results can + be retrieved by calling `get_result` again or accesing the `r` attribute + of the instance. 
+ """ + + # Convert to list for sorting and check class type + prList = list(pendingResults) + for pr in prList: + if not isinstance(pr, PendingResult): + raise error.NotAPendingResult("Objects passed to barrier must be PendingResult instances") + + # Sort the PendingResults so they are in order + prList.sort() + # Block on each PendingResult object + for pr in prList: + try: + result = pr.get_result(block=True) + except Exception: + pass + + def flush(self): + """ + Clear all pending deferreds/results from the controller. + + For each `PendingResult` that is created by this client, the controller + holds on to the result for that `PendingResult`. This can be a problem + if there are a large number of `PendingResult` objects that are created. + + Once the result of the `PendingResult` has been retrieved, the result + is removed from the controller, but if a user doesn't get a result ( + they just ignore the `PendingResult`) the result is kept forever on the + controller. This method allows the user to clear out all un-retrieved + results on the controller. + """ + r = blockingCallFromThread(self.smultiengine.clear_pending_deferreds) + return r + + clear_pending_results = flush + + #--------------------------------------------------------------------------- + # IEngineMultiplexer related methods + #--------------------------------------------------------------------------- + + def execute(self, lines, targets=None, block=None): + """ + Execute code on a set of engines. + + :Parameters: + lines : str + The Python code to execute as a string + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. 
+ """ + targets, block = self._findTargetsAndBlock(targets, block) + result = blockingCallFromThread(self.smultiengine.execute, lines, + targets=targets, block=block) + if block: + result = ResultList(result) + else: + result = PendingResult(self, result) + result.add_callback(wrapResultList) + return result + + def push(self, namespace, targets=None, block=None): + """ + Push a dictionary of keys and values to engines namespace. + + Each engine has a persistent namespace. This method is used to push + Python objects into that namespace. + + The objects in the namespace must be pickleable. + + :Parameters: + namespace : dict + A dict that contains Python objects to be injected into + the engine persistent namespace. + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.push, namespace, + targets=targets, block=block) + + def pull(self, keys, targets=None, block=None): + """ + Pull Python objects by key out of engines namespaces. + + :Parameters: + keys : str or list of str + The names of the variables to be pulled + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.pull, keys, targets=targets, block=block) + + def push_function(self, namespace, targets=None, block=None): + """ + Push a Python function to an engine. + + This method is used to push a Python function to an engine. This + method can then be used in code on the engines. Closures are not supported. 
+ + :Parameters: + namespace : dict + A dict whose values are the functions to be pushed. The keys give + that names that the function will appear as in the engines + namespace. + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.push_function, namespace, targets=targets, block=block) + + def pull_function(self, keys, targets=None, block=None): + """ + Pull a Python function from an engine. + + This method is used to pull a Python function from an engine. + Closures are not supported. + + :Parameters: + keys : str or list of str + The names of the functions to be pulled + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.pull_function, keys, targets=targets, block=block) + + def push_serialized(self, namespace, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.push_serialized, namespace, targets=targets, block=block) + + def pull_serialized(self, keys, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.pull_serialized, keys, targets=targets, block=block) + + def get_result(self, i=None, targets=None, block=None): + """ + Get a previous result. + + When code is executed in an engine, a dict is created and returned. This + method retrieves that dict for previous commands. 
+ + :Parameters: + i : int + The number of the result to get + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + result = blockingCallFromThread(self.smultiengine.get_result, i, targets=targets, block=block) + if block: + result = ResultList(result) + else: + result = PendingResult(self, result) + result.add_callback(wrapResultList) + return result + + def reset(self, targets=None, block=None): + """ + Reset an engine. + + This method clears out the namespace of an engine. + + :Parameters: + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.reset, targets=targets, block=block) + + def keys(self, targets=None, block=None): + """ + Get a list of all the variables in an engine's namespace. + + :Parameters: + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.keys, targets=targets, block=block) + + def kill(self, controller=False, targets=None, block=None): + """ + Kill the engines and controller. + + This method is used to stop the engine and controller by calling + `reactor.stop`. + + :Parameters: + controller : boolean + If True, kill the engines and controller. 
If False, just the +                engines +            targets : id or list of ids +                The engine to use for the execution +            block : boolean +                If True, this method will return the actual result.  If False, +                a `PendingResult` is returned which can be used to get the result +                at a later time. +        """ +        targets, block = self._findTargetsAndBlock(targets, block) +        return self._blockFromThread(self.smultiengine.kill, controller, targets=targets, block=block) +     +    def clear_queue(self, targets=None, block=None): +        """ +        Clear out the controller's queue for an engine. +         +        The controller maintains a queue for each engine.  This clears it out. +         +        :Parameters: +            targets : id or list of ids +                The engine to use for the execution +            block : boolean +                If True, this method will return the actual result.  If False, +                a `PendingResult` is returned which can be used to get the result +                at a later time. +        """ +        targets, block = self._findTargetsAndBlock(targets, block) +        return self._blockFromThread(self.smultiengine.clear_queue, targets=targets, block=block) +     +    def queue_status(self, targets=None, block=None): +        """ +        Get the status of an engine's queue. +         +        :Parameters: +            targets : id or list of ids +                The engine to use for the execution +            block : boolean +                If True, this method will return the actual result.  If False, +                a `PendingResult` is returned which can be used to get the result +                at a later time.
+ """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.queue_status, targets=targets, block=block) + + def set_properties(self, properties, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.set_properties, properties, targets=targets, block=block) + + def get_properties(self, keys=None, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.get_properties, keys, targets=targets, block=block) + + def has_properties(self, keys, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.has_properties, keys, targets=targets, block=block) + + def del_properties(self, keys, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.del_properties, keys, targets=targets, block=block) + + def clear_properties(self, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.clear_properties, targets=targets, block=block) + + #--------------------------------------------------------------------------- + # IMultiEngine related methods + #--------------------------------------------------------------------------- + + def get_ids(self): + """ + Returns the ids of currently registered engines. 
+ """ + result = blockingCallFromThread(self.smultiengine.get_ids) + return result + + #--------------------------------------------------------------------------- + # IMultiEngineCoordinator + #--------------------------------------------------------------------------- + + def scatter(self, key, seq, style='basic', flatten=False, targets=None, block=None): + """ + Partition a Python sequence and send the partitions to a set of engines. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.scatter, key, seq, + style, flatten, targets=targets, block=block) + + def gather(self, key, style='basic', targets=None, block=None): + """ + Gather a partitioned sequence on a set of engines as a single local seq. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.gather, key, style, + targets=targets, block=block) + + def map(self, func, seq, style='basic', targets=None, block=None): + """ + A parallelized version of Python's builtin map + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.map, func, seq, + style, targets=targets, block=block) + + #--------------------------------------------------------------------------- + # IMultiEngineExtras + #--------------------------------------------------------------------------- + + def zip_pull(self, keys, targets=None, block=None): + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.zip_pull, keys, + targets=targets, block=block) + + def run(self, filename, targets=None, block=None): + """ + Run a Python code in a file on the engines. + + :Parameters: + filename : str + The name of the local file to run + targets : id or list of ids + The engine to use for the execution + block : boolean + If False, this method will return the actual result. 
If False, + a `PendingResult` is returned which can be used to get the result + at a later time. + """ + targets, block = self._findTargetsAndBlock(targets, block) + return self._blockFromThread(self.smultiengine.run, filename, + targets=targets, block=block) + + + +components.registerAdapter(FullBlockingMultiEngineClient, + IFullSynchronousMultiEngine, IFullBlockingMultiEngineClient) + + + + diff --git a/IPython/kernel/multienginefc.py b/IPython/kernel/multienginefc.py new file mode 100644 index 0000000..df53f22 --- /dev/null +++ b/IPython/kernel/multienginefc.py @@ -0,0 +1,668 @@ +# encoding: utf-8 + +""" +Expose the multiengine controller over the Foolscap network protocol. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import cPickle as pickle +from types import FunctionType + +from zope.interface import Interface, implements +from twisted.internet import defer +from twisted.python import components, failure, log + +from foolscap import Referenceable + +from IPython.kernel import error +from IPython.kernel.util import printer +from IPython.kernel import map as Map +from IPython.kernel.twistedutil import gatherBoth +from IPython.kernel.multiengine import (MultiEngine, + IMultiEngine, + IFullSynchronousMultiEngine, + ISynchronousMultiEngine) +from IPython.kernel.multiengineclient import wrapResultList +from IPython.kernel.pendingdeferred import PendingDeferredManager +from IPython.kernel.pickleutil import (can, canDict, + canSequence, uncan, uncanDict, uncanSequence) + +from IPython.kernel.clientinterfaces import ( + IFCClientInterfaceProvider, + IBlockingClientAdaptor +) + +# Needed to access the true globals from __main__.__dict__ +import __main__ + +#------------------------------------------------------------------------------- +# The Controller side of things +#------------------------------------------------------------------------------- + +def packageResult(wrappedMethod): + + def wrappedPackageResult(self, *args, **kwargs): + d = wrappedMethod(self, *args, **kwargs) + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + return wrappedPackageResult + + +class IFCSynchronousMultiEngine(Interface): + """Foolscap interface to `ISynchronousMultiEngine`. + + The methods in this interface are similar to those of + `ISynchronousMultiEngine`, but their arguments and return values are pickled + if they are not already simple Python types that can be send over XML-RPC. 
+ + See the documentation of `ISynchronousMultiEngine` and `IMultiEngine` for + documentation about the methods. + + Most methods in this interface act like the `ISynchronousMultiEngine` + versions and can be called in blocking or non-blocking mode. + """ + pass + + +class FCSynchronousMultiEngineFromMultiEngine(Referenceable): + """Adapt `IMultiEngine` -> `ISynchronousMultiEngine` -> `IFCSynchronousMultiEngine`. + """ + + implements(IFCSynchronousMultiEngine, IFCClientInterfaceProvider) + + addSlash = True + + def __init__(self, multiengine): + # Adapt the raw multiengine to `ISynchronousMultiEngine` before saving + # it. This allow this class to do two adaptation steps. + self.smultiengine = ISynchronousMultiEngine(multiengine) + self._deferredIDCallbacks = {} + + #--------------------------------------------------------------------------- + # Non interface methods + #--------------------------------------------------------------------------- + + def packageFailure(self, f): + f.cleanFailure() + return self.packageSuccess(f) + + def packageSuccess(self, obj): + serial = pickle.dumps(obj, 2) + return serial + + #--------------------------------------------------------------------------- + # Things related to PendingDeferredManager + #--------------------------------------------------------------------------- + + @packageResult + def remote_get_pending_deferred(self, deferredID, block): + d = self.smultiengine.get_pending_deferred(deferredID, block) + try: + callback = self._deferredIDCallbacks.pop(deferredID) + except KeyError: + callback = None + if callback is not None: + d.addCallback(callback[0], *callback[1], **callback[2]) + return d + + @packageResult + def remote_clear_pending_deferreds(self): + return defer.maybeDeferred(self.smultiengine.clear_pending_deferreds) + + def _addDeferredIDCallback(self, did, callback, *args, **kwargs): + self._deferredIDCallbacks[did] = (callback, args, kwargs) + return did + + 
#--------------------------------------------------------------------------- + # IEngineMultiplexer related methods + #--------------------------------------------------------------------------- + + @packageResult + def remote_execute(self, lines, targets, block): + return self.smultiengine.execute(lines, targets=targets, block=block) + + @packageResult + def remote_push(self, binaryNS, targets, block): + try: + namespace = pickle.loads(binaryNS) + except: + d = defer.fail(failure.Failure()) + else: + d = self.smultiengine.push(namespace, targets=targets, block=block) + return d + + @packageResult + def remote_pull(self, keys, targets, block): + d = self.smultiengine.pull(keys, targets=targets, block=block) + return d + + @packageResult + def remote_push_function(self, binaryNS, targets, block): + try: + namespace = pickle.loads(binaryNS) + except: + d = defer.fail(failure.Failure()) + else: + namespace = uncanDict(namespace) + d = self.smultiengine.push_function(namespace, targets=targets, block=block) + return d + + def _canMultipleKeys(self, result): + return [canSequence(r) for r in result] + + @packageResult + def remote_pull_function(self, keys, targets, block): + def can_functions(r, keys): + if len(keys)==1 or isinstance(keys, str): + result = canSequence(r) + elif len(keys)>1: + result = [canSequence(s) for s in r] + return result + d = self.smultiengine.pull_function(keys, targets=targets, block=block) + if block: + d.addCallback(can_functions, keys) + else: + d.addCallback(lambda did: self._addDeferredIDCallback(did, can_functions, keys)) + return d + + @packageResult + def remote_push_serialized(self, binaryNS, targets, block): + try: + namespace = pickle.loads(binaryNS) + except: + d = defer.fail(failure.Failure()) + else: + d = self.smultiengine.push_serialized(namespace, targets=targets, block=block) + return d + + @packageResult + def remote_pull_serialized(self, keys, targets, block): + d = self.smultiengine.pull_serialized(keys, targets=targets, 
block=block) + return d + + @packageResult + def remote_get_result(self, i, targets, block): + if i == 'None': + i = None + return self.smultiengine.get_result(i, targets=targets, block=block) + + @packageResult + def remote_reset(self, targets, block): + return self.smultiengine.reset(targets=targets, block=block) + + @packageResult + def remote_keys(self, targets, block): + return self.smultiengine.keys(targets=targets, block=block) + + @packageResult + def remote_kill(self, controller, targets, block): + return self.smultiengine.kill(controller, targets=targets, block=block) + + @packageResult + def remote_clear_queue(self, targets, block): + return self.smultiengine.clear_queue(targets=targets, block=block) + + @packageResult + def remote_queue_status(self, targets, block): + return self.smultiengine.queue_status(targets=targets, block=block) + + @packageResult + def remote_set_properties(self, binaryNS, targets, block): + try: + ns = pickle.loads(binaryNS) + except: + d = defer.fail(failure.Failure()) + else: + d = self.smultiengine.set_properties(ns, targets=targets, block=block) + return d + + @packageResult + def remote_get_properties(self, keys, targets, block): + if keys=='None': + keys=None + return self.smultiengine.get_properties(keys, targets=targets, block=block) + + @packageResult + def remote_has_properties(self, keys, targets, block): + return self.smultiengine.has_properties(keys, targets=targets, block=block) + + @packageResult + def remote_del_properties(self, keys, targets, block): + return self.smultiengine.del_properties(keys, targets=targets, block=block) + + @packageResult + def remote_clear_properties(self, targets, block): + return self.smultiengine.clear_properties(targets=targets, block=block) + + #--------------------------------------------------------------------------- + # IMultiEngine related methods + #--------------------------------------------------------------------------- + + def remote_get_ids(self): + """Get the ids of the 
registered engines. + + This method always blocks. + """ + return self.smultiengine.get_ids() + + #--------------------------------------------------------------------------- + # IFCClientInterfaceProvider related methods + #--------------------------------------------------------------------------- + + def remote_get_client_name(self): + return 'IPython.kernel.multienginefc.FCFullSynchronousMultiEngineClient' + + +# The __init__ method of `FCMultiEngineFromMultiEngine` first adapts the +# `IMultiEngine` to `ISynchronousMultiEngine` so this is actually doing a +# two phase adaptation. +components.registerAdapter(FCSynchronousMultiEngineFromMultiEngine, + IMultiEngine, IFCSynchronousMultiEngine) + + +#------------------------------------------------------------------------------- +# The Client side of things +#------------------------------------------------------------------------------- + + +class FCFullSynchronousMultiEngineClient(object): + + implements(IFullSynchronousMultiEngine, IBlockingClientAdaptor) + + def __init__(self, remote_reference): + self.remote_reference = remote_reference + self._deferredIDCallbacks = {} + # This class manages some pending deferreds through this instance. This + # is required for methods like gather/scatter as it enables us to + # create our own pending deferreds for composite operations. 
+ self.pdm = PendingDeferredManager() + + #--------------------------------------------------------------------------- + # Non interface methods + #--------------------------------------------------------------------------- + + def unpackage(self, r): + return pickle.loads(r) + + #--------------------------------------------------------------------------- + # Things related to PendingDeferredManager + #--------------------------------------------------------------------------- + + def get_pending_deferred(self, deferredID, block=True): + + # Because we are managing some pending deferreds locally (through + # self.pdm) and some remotely (on the controller), we first try the + # local one and then the remote one. + if self.pdm.quick_has_id(deferredID): + d = self.pdm.get_pending_deferred(deferredID, block) + return d + else: + d = self.remote_reference.callRemote('get_pending_deferred', deferredID, block) + d.addCallback(self.unpackage) + try: + callback = self._deferredIDCallbacks.pop(deferredID) + except KeyError: + callback = None + if callback is not None: + d.addCallback(callback[0], *callback[1], **callback[2]) + return d + + def clear_pending_deferreds(self): + + # This clear both the local (self.pdm) and remote pending deferreds + self.pdm.clear_pending_deferreds() + d2 = self.remote_reference.callRemote('clear_pending_deferreds') + d2.addCallback(self.unpackage) + return d2 + + def _addDeferredIDCallback(self, did, callback, *args, **kwargs): + self._deferredIDCallbacks[did] = (callback, args, kwargs) + return did + + #--------------------------------------------------------------------------- + # IEngineMultiplexer related methods + #--------------------------------------------------------------------------- + + def execute(self, lines, targets='all', block=True): + d = self.remote_reference.callRemote('execute', lines, targets, block) + d.addCallback(self.unpackage) + return d + + def push(self, namespace, targets='all', block=True): + serial = 
pickle.dumps(namespace, 2) + d = self.remote_reference.callRemote('push', serial, targets, block) + d.addCallback(self.unpackage) + return d + + def pull(self, keys, targets='all', block=True): + d = self.remote_reference.callRemote('pull', keys, targets, block) + d.addCallback(self.unpackage) + return d + + def push_function(self, namespace, targets='all', block=True): + cannedNamespace = canDict(namespace) + serial = pickle.dumps(cannedNamespace, 2) + d = self.remote_reference.callRemote('push_function', serial, targets, block) + d.addCallback(self.unpackage) + return d + + def pull_function(self, keys, targets='all', block=True): + def uncan_functions(r, keys): + if len(keys)==1 or isinstance(keys, str): + return uncanSequence(r) + elif len(keys)>1: + return [uncanSequence(s) for s in r] + d = self.remote_reference.callRemote('pull_function', keys, targets, block) + if block: + d.addCallback(self.unpackage) + d.addCallback(uncan_functions, keys) + else: + d.addCallback(self.unpackage) + d.addCallback(lambda did: self._addDeferredIDCallback(did, uncan_functions, keys)) + return d + + def push_serialized(self, namespace, targets='all', block=True): + cannedNamespace = canDict(namespace) + serial = pickle.dumps(cannedNamespace, 2) + d = self.remote_reference.callRemote('push_serialized', serial, targets, block) + d.addCallback(self.unpackage) + return d + + def pull_serialized(self, keys, targets='all', block=True): + d = self.remote_reference.callRemote('pull_serialized', keys, targets, block) + d.addCallback(self.unpackage) + return d + + def get_result(self, i=None, targets='all', block=True): + if i is None: # This is because None cannot be marshalled by xml-rpc + i = 'None' + d = self.remote_reference.callRemote('get_result', i, targets, block) + d.addCallback(self.unpackage) + return d + + def reset(self, targets='all', block=True): + d = self.remote_reference.callRemote('reset', targets, block) + d.addCallback(self.unpackage) + return d + + def keys(self, 
targets='all', block=True): + d = self.remote_reference.callRemote('keys', targets, block) + d.addCallback(self.unpackage) + return d + + def kill(self, controller=False, targets='all', block=True): + d = self.remote_reference.callRemote('kill', controller, targets, block) + d.addCallback(self.unpackage) + return d + + def clear_queue(self, targets='all', block=True): + d = self.remote_reference.callRemote('clear_queue', targets, block) + d.addCallback(self.unpackage) + return d + + def queue_status(self, targets='all', block=True): + d = self.remote_reference.callRemote('queue_status', targets, block) + d.addCallback(self.unpackage) + return d + + def set_properties(self, properties, targets='all', block=True): + serial = pickle.dumps(properties, 2) + d = self.remote_reference.callRemote('set_properties', serial, targets, block) + d.addCallback(self.unpackage) + return d + + def get_properties(self, keys=None, targets='all', block=True): + if keys==None: + keys='None' + d = self.remote_reference.callRemote('get_properties', keys, targets, block) + d.addCallback(self.unpackage) + return d + + def has_properties(self, keys, targets='all', block=True): + d = self.remote_reference.callRemote('has_properties', keys, targets, block) + d.addCallback(self.unpackage) + return d + + def del_properties(self, keys, targets='all', block=True): + d = self.remote_reference.callRemote('del_properties', keys, targets, block) + d.addCallback(self.unpackage) + return d + + def clear_properties(self, targets='all', block=True): + d = self.remote_reference.callRemote('clear_properties', targets, block) + d.addCallback(self.unpackage) + return d + + #--------------------------------------------------------------------------- + # IMultiEngine related methods + #--------------------------------------------------------------------------- + + def get_ids(self): + d = self.remote_reference.callRemote('get_ids') + return d + + 
#--------------------------------------------------------------------------- +    # ISynchronousMultiEngineCoordinator related methods +    #--------------------------------------------------------------------------- +     +    def _process_targets(self, targets): +        def create_targets(ids): +            if isinstance(targets, int): +                engines = [targets] +            elif targets=='all': +                engines = ids +            elif isinstance(targets, (list, tuple)): +                engines = targets +            for t in engines: +                if not t in ids: +                    raise error.InvalidEngineID("engine with id %r does not exist"%t) +            return engines +         +        d = self.get_ids() +        d.addCallback(create_targets) +        return d +     +    def scatter(self, key, seq, style='basic', flatten=False, targets='all', block=True): +         +        # Note: scatter and gather handle pending deferreds locally through self.pdm. +        # This enables us to collect a bunch of deferred ids and make a secondary  +        # deferred id that corresponds to the entire group.  This logic is extremely +        # difficult to get right though. +        def do_scatter(engines): +            nEngines = len(engines) +            mapClass = Map.styles[style] +            mapObject = mapClass() +            d_list = [] +            # Loop through and push to each engine in non-blocking mode.
+ # This returns a set of deferreds to deferred_ids + for index, engineid in enumerate(engines): + partition = mapObject.getPartition(seq, index, nEngines) + if flatten and len(partition) == 1: + d = self.push({key: partition[0]}, targets=engineid, block=False) + else: + d = self.push({key: partition}, targets=engineid, block=False) + d_list.append(d) + # Collect the deferred to deferred_ids + d = gatherBoth(d_list, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + # Now d has a list of deferred_ids or Failures coming + d.addCallback(error.collect_exceptions, 'scatter') + def process_did_list(did_list): + """Turn a list of deferred_ids into a final result or failure.""" + new_d_list = [self.get_pending_deferred(did, True) for did in did_list] + final_d = gatherBoth(new_d_list, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + final_d.addCallback(error.collect_exceptions, 'scatter') + final_d.addCallback(lambda lop: [i[0] for i in lop]) + return final_d + # Now, depending on block, we need to handle the list deferred_ids + # coming down the pipe diferently. + if block: + # If we are blocking register a callback that will transform the + # list of deferred_ids into the final result. + d.addCallback(process_did_list) + return d + else: + # Here we are going to use a _local_ PendingDeferredManager. 
+                deferred_id = self.pdm.get_deferred_id() +                # This is the deferred we will return to the user that will fire +                # with the local deferred_id AFTER we have received the list of  +                # primary deferred_ids +                d_to_return = defer.Deferred() +                def do_it(did_list): +                    """Produce a deferred to the final result, but first fire the +                    deferred we will return to the user that has the local +                    deferred id.""" +                    d_to_return.callback(deferred_id) +                    return process_did_list(did_list) +                d.addCallback(do_it) +                # Now save the deferred to the final result +                self.pdm.save_pending_deferred(d, deferred_id) +                return d_to_return +         +        d = self._process_targets(targets) +        d.addCallback(do_scatter) +        return d +     +    def gather(self, key, style='basic', targets='all', block=True): +         +        # Note: scatter and gather handle pending deferreds locally through self.pdm. +        # This enables us to collect a bunch of deferred ids and make a secondary  +        # deferred id that corresponds to the entire group.  This logic is extremely +        # difficult to get right though. +        def do_gather(engines): +            nEngines = len(engines) +            mapClass = Map.styles[style] +            mapObject = mapClass() +            d_list = [] +            # Loop through and push to each engine in non-blocking mode.
+ # This returns a set of deferreds to deferred_ids + for index, engineid in enumerate(engines): + d = self.pull(key, targets=engineid, block=False) + d_list.append(d) + # Collect the deferred to deferred_ids + d = gatherBoth(d_list, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + # Now d has a list of deferred_ids or Failures coming + d.addCallback(error.collect_exceptions, 'scatter') + def process_did_list(did_list): + """Turn a list of deferred_ids into a final result or failure.""" + new_d_list = [self.get_pending_deferred(did, True) for did in did_list] + final_d = gatherBoth(new_d_list, + fireOnOneErrback=0, + consumeErrors=1, + logErrors=0) + final_d.addCallback(error.collect_exceptions, 'gather') + final_d.addCallback(lambda lop: [i[0] for i in lop]) + final_d.addCallback(mapObject.joinPartitions) + return final_d + # Now, depending on block, we need to handle the list deferred_ids + # coming down the pipe diferently. + if block: + # If we are blocking register a callback that will transform the + # list of deferred_ids into the final result. + d.addCallback(process_did_list) + return d + else: + # Here we are going to use a _local_ PendingDeferredManager. 
+ deferred_id = self.pdm.get_deferred_id() + # This is the deferred we will return to the user that will fire + # with the local deferred_id AFTER we have received the list of + # primary deferred_ids + d_to_return = defer.Deferred() + def do_it(did_list): + """Produce a deferred to the final result, but first fire the + deferred we will return to the user that has the local + deferred id.""" + d_to_return.callback(deferred_id) + return process_did_list(did_list) + d.addCallback(do_it) + # Now save the deferred to the final result + self.pdm.save_pending_deferred(d, deferred_id) + return d_to_return + + d = self._process_targets(targets) + d.addCallback(do_gather) + return d + + def map(self, func, seq, style='basic', targets='all', block=True): + d_list = [] + if isinstance(func, FunctionType): + d = self.push_function(dict(_ipython_map_func=func), targets=targets, block=False) + d.addCallback(lambda did: self.get_pending_deferred(did, True)) + sourceToRun = '_ipython_map_seq_result = map(_ipython_map_func, _ipython_map_seq)' + elif isinstance(func, str): + d = defer.succeed(None) + sourceToRun = \ + '_ipython_map_seq_result = map(%s, _ipython_map_seq)' % func + else: + raise TypeError("func must be a function or str") + + d.addCallback(lambda _: self.scatter('_ipython_map_seq', seq, style, targets=targets)) + d.addCallback(lambda _: self.execute(sourceToRun, targets=targets, block=False)) + d.addCallback(lambda did: self.get_pending_deferred(did, True)) + d.addCallback(lambda _: self.gather('_ipython_map_seq_result', style, targets=targets, block=block)) + return d + + #--------------------------------------------------------------------------- + # ISynchronousMultiEngineExtras related methods + #--------------------------------------------------------------------------- + + def _transformPullResult(self, pushResult, multitargets, lenKeys): + if not multitargets: + result = pushResult[0] + elif lenKeys > 1: + result = zip(*pushResult) + elif lenKeys is 1: + 
result = list(pushResult) + return result + + def zip_pull(self, keys, targets='all', block=True): + multitargets = not isinstance(targets, int) and len(targets) > 1 + lenKeys = len(keys) + d = self.pull(keys, targets=targets, block=block) + if block: + d.addCallback(self._transformPullResult, multitargets, lenKeys) + else: + d.addCallback(lambda did: self._addDeferredIDCallback(did, self._transformPullResult, multitargets, lenKeys)) + return d + + def run(self, fname, targets='all', block=True): + fileobj = open(fname,'r') + source = fileobj.read() + fileobj.close() + # if the compilation blows, we get a local error right away + try: + code = compile(source,fname,'exec') + except: + return defer.fail(failure.Failure()) + # Now run the code + d = self.execute(source, targets=targets, block=block) + return d + + #--------------------------------------------------------------------------- + # IBlockingClientAdaptor related methods + #--------------------------------------------------------------------------- + + def adapt_to_blocking_client(self): + from IPython.kernel.multiengineclient import IFullBlockingMultiEngineClient + return IFullBlockingMultiEngineClient(self) diff --git a/IPython/kernel/newserialized.py b/IPython/kernel/newserialized.py new file mode 100644 index 0000000..38ba3c7 --- /dev/null +++ b/IPython/kernel/newserialized.py @@ -0,0 +1,163 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.test.test_newserialized -*- + +"""Refactored serialization classes and interfaces.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

import cPickle as pickle

from zope.interface import Interface, implements
from twisted.python import components

# numpy is optional: buffer-based ndarray serialization is only enabled when
# it is importable.
try:
    import numpy
except ImportError:
    pass

from IPython.kernel.error import SerializationError


class ISerialized(Interface):
    """A serialized object: raw data plus a type descriptor and metadata."""

    def getData():
        """Return the raw serialized data (a buffer or pickle string)."""

    def getDataSize(units=10.0**6):
        """Return the size of the data in `units` bytes (default: MB)."""

    def getTypeDescriptor():
        """Return the string naming how the data was serialized."""

    def getMetadata():
        """Return a dict of metadata needed to reconstruct the object."""


class IUnSerialized(Interface):

    def getObject():
        """Return the live Python object."""


class Serialized(object):
    """Concrete `ISerialized` wrapping already-serialized data."""

    implements(ISerialized)

    def __init__(self, data, typeDescriptor, metadata=None):
        # BUG FIX: `metadata={}` was a mutable default dict shared by every
        # instance that did not pass metadata explicitly.
        if metadata is None:
            metadata = {}
        self.data = data
        self.typeDescriptor = typeDescriptor
        self.metadata = metadata

    def getData(self):
        return self.data

    def getDataSize(self, units=10.0**6):
        return len(self.data)/units

    def getTypeDescriptor(self):
        return self.typeDescriptor

    def getMetadata(self):
        return self.metadata


class UnSerialized(object):
    """Concrete `IUnSerialized` wrapping a live Python object."""

    implements(IUnSerialized)

    def __init__(self, obj):
        self.obj = obj

    def getObject(self):
        return self.obj


class SerializeIt(object):
    """Adapt an `IUnSerialized` to `ISerialized` by serializing its object.

    ndarrays are sent as raw buffers (with shape/dtype metadata); everything
    else is pickled.
    """

    implements(ISerialized)

    def __init__(self, unSerialized):
        self.data = None
        self.obj = unSerialized.getObject()
        # `dict.has_key` is deprecated; membership test is equivalent.
        if 'numpy' in globals() and isinstance(self.obj, numpy.ndarray):
            if len(self.obj) == 0:  # length 0 arrays can't be reconstructed
                raise SerializationError("You cannot send a length 0 array")
            self.obj = numpy.ascontiguousarray(self.obj, dtype=None)
            self.typeDescriptor = 'ndarray'
            self.metadata = {'shape':self.obj.shape,
                             'dtype':self.obj.dtype.str}
        else:
            self.typeDescriptor = 'pickle'
            self.metadata = {}
        self._generateData()

    def _generateData(self):
        if self.typeDescriptor == 'ndarray':
            self.data = numpy.getbuffer(self.obj)
        elif self.typeDescriptor == 'pickle':
            self.data = pickle.dumps(self.obj, 2)
        else:
            # FIX: message read 'wierd'.
            raise SerializationError("Really weird serialization error.")
        # Drop the original object; only the serialized form is kept.
        del self.obj

    def getData(self):
        return self.data

    def getDataSize(self, units=10.0**6):
        return len(self.data)/units

    def getTypeDescriptor(self):
        return self.typeDescriptor

    def getMetadata(self):
        return self.metadata


class UnSerializeIt(UnSerialized):
    """Adapt an `ISerialized` back to a live Python object."""

    implements(IUnSerialized)

    def __init__(self, serialized):
        self.serialized = serialized

    def getObject(self):
        # The original duplicated the pickle branch inside and outside the
        # numpy check; this collapses the two while preserving behavior
        # ('ndarray' without numpy available is still an error).
        typeDescriptor = self.serialized.getTypeDescriptor()
        if typeDescriptor == 'pickle':
            result = pickle.loads(self.serialized.getData())
        elif typeDescriptor == 'ndarray' and 'numpy' in globals():
            result = numpy.frombuffer(self.serialized.getData(),
                                      dtype=self.serialized.metadata['dtype'])
            result.shape = self.serialized.metadata['shape']
            # This is a hack to make the array writable.  We are working with
            # the numpy folks to address this issue.
            result = result.copy()
        else:
            raise SerializationError("Really weird serialization error.")
        return result


components.registerAdapter(UnSerializeIt, ISerialized, IUnSerialized)

components.registerAdapter(SerializeIt, IUnSerialized, ISerialized)


def serialize(obj):
    """Serialize a live object into an `ISerialized`."""
    return ISerialized(UnSerialized(obj))


def unserialize(serialized):
    """Reconstruct a live object from an `ISerialized`."""
    return IUnSerialized(serialized).getObject()
+ """ + assert isinstance(func, (str, FunctionType)), "func must be a fuction or str" + self.func = func + self.multiengine = multiengine + self.targets = targets + self.block = block + + def __call__(self, sequence): + return self.multiengine.map(self.func, sequence, targets=self.targets, block=self.block) \ No newline at end of file diff --git a/IPython/kernel/pbconfig.py b/IPython/kernel/pbconfig.py new file mode 100644 index 0000000..5fb5e4d --- /dev/null +++ b/IPython/kernel/pbconfig.py @@ -0,0 +1,34 @@ +# encoding: utf-8 + +"""Low level configuration for Twisted's Perspective Broker protocol.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from twisted.spread import banana + + +#------------------------------------------------------------------------------- +# This is where you configure the size limit of the banana protocol that +# PB uses. WARNING, this only works if you are NOT using cBanana, which is +# faster than banana.py. +#------------------------------------------------------------------------------- + + + +#banana.SIZE_LIMIT = 640*1024 # The default of 640 kB +banana.SIZE_LIMIT = 10*1024*1024 # 10 MB +#banana.SIZE_LIMIT = 50*1024*1024 # 50 MB + +# This sets the size of chunks used when paging is used. 
#-------------------------------------------------------------------------------
# The actual utilities
#-------------------------------------------------------------------------------

def packageFailure(f):
    """Clean and pickle a Failure, prepending the string FAILURE:"""

    f.cleanFailure()
    # This is sometimes helpful in debugging
    #f.raiseException()
    try:
        pString = pickle.dumps(f, 2)
    except pickle.PicklingError:
        # Certain types of exceptions are not pickleable, for instance ones
        # from Boost.Python.  We try to wrap them in something that is.
        f.type = UnpickleableException
        f.value = UnpickleableException(str(f.type) + ": " + str(f.value))
        pString = pickle.dumps(f, 2)
    return 'FAILURE:' + pString

def unpackageFailure(r):
    """
    See if a returned value is a pickled Failure object.

    To distinguish between general pickled objects and pickled Failures, the
    other side should prepend the string FAILURE: to any pickled Failure.
    Anything that is not such a string is returned unchanged.
    """
    if isinstance(r, str):
        if r.startswith('FAILURE:'):
            try:
                result = pickle.loads(r[8:])
            except pickle.PickleError:
                # BUG FIX: this referenced `FailureUnpickleable`, a name that
                # is never imported or defined in this module and would have
                # raised NameError exactly when unpickling failed.  Use the
                # imported UnpickleableException instead.
                return failure.Failure(
                    UnpickleableException("Could not unpickle failure."))
            else:
                return result
    return r

def checkMessageSize(m, info):
    """Check string m to see if it violates banana.SIZE_LIMIT.

    This should be used on the client side of things for push, scatter
    and push_serialized and on the other end for pull, gather and
    pull_serialized.

    :Parameters:
        `m` : string
            Message whose size will be checked.
        `info` : string
            String describing what object the message refers to.

    :Exceptions:
        - `PBMessageSizeError`: Raised if the message is > banana.SIZE_LIMIT

    :returns: The original message or a Failure wrapping a PBMessageSizeError
    """

    if len(m) > pbconfig.banana.SIZE_LIMIT:
        s = """Objects too big to transfer:
Names:            %s
Actual Size (kB): %d
SIZE_LIMIT (kB):  %d
* SIZE_LIMIT can be set in kernel.pbconfig""" \
            % (info, len(m)/1024, pbconfig.banana.SIZE_LIMIT/1024)
        return Failure(PBMessageSizeError(s))
    else:
        return m
+""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from twisted.application import service +from twisted.internet import defer, reactor +from twisted.python import log, components, failure +from zope.interface import Interface, implements, Attribute + +from IPython.kernel.twistedutil import gatherBoth +from IPython.kernel import error +from IPython.external import guid +from IPython.tools import growl + +class PendingDeferredManager(object): + """A class to track pending deferreds. + + To track a pending deferred, the user of this class must first + get a deferredID by calling `get_next_deferred_id`. Then the user + calls `save_pending_deferred` passing that id and the deferred to + be tracked. To later retrieve it, the user calls + `get_pending_deferred` passing the id. 
+ """ + + def __init__(self): + """Manage pending deferreds.""" + + self.results = {} # Populated when results are ready + self.deferred_ids = [] # List of deferred ids I am managing + self.deferreds_to_callback = {} # dict of lists of deferreds to callback + + def get_deferred_id(self): + return guid.generate() + + def quick_has_id(self, deferred_id): + return deferred_id in self.deferred_ids + + def _save_result(self, result, deferred_id): + if self.quick_has_id(deferred_id): + self.results[deferred_id] = result + self._trigger_callbacks(deferred_id) + + def _trigger_callbacks(self, deferred_id): + # Go through and call the waiting callbacks + result = self.results.get(deferred_id) + if result is not None: # Only trigger if there is a result + try: + d = self.deferreds_to_callback.pop(deferred_id) + except KeyError: + d = None + if d is not None: + if isinstance(result, failure.Failure): + d.errback(result) + else: + d.callback(result) + self.delete_pending_deferred(deferred_id) + + def save_pending_deferred(self, d, deferred_id=None): + """Save the result of a deferred for later retrieval. + + This works even if the deferred has not fired. + + Only callbacks and errbacks applied to d before this method + is called will be called no the final result. + """ + if deferred_id is None: + deferred_id = self.get_deferred_id() + self.deferred_ids.append(deferred_id) + d.addBoth(self._save_result, deferred_id) + return deferred_id + + def _protected_del(self, key, container): + try: + del container[key] + except Exception: + pass + + def delete_pending_deferred(self, deferred_id): + """Remove a deferred I am tracking and add a null Errback. + + :Parameters: + deferredID : str + The id of a deferred that I am tracking. 
+ """ + if self.quick_has_id(deferred_id): + # First go through a errback any deferreds that are still waiting + d = self.deferreds_to_callback.get(deferred_id) + if d is not None: + d.errback(failure.Failure(error.AbortedPendingDeferredError("pending deferred has been deleted: %r"%deferred_id))) + # Now delete all references to this deferred_id + ind = self.deferred_ids.index(deferred_id) + self._protected_del(ind, self.deferred_ids) + self._protected_del(deferred_id, self.deferreds_to_callback) + self._protected_del(deferred_id, self.results) + else: + raise error.InvalidDeferredID('invalid deferred_id: %r' % deferred_id) + + def clear_pending_deferreds(self): + """Remove all the deferreds I am tracking.""" + for did in self.deferred_ids: + self.delete_pending_deferred(did) + + def _delete_and_pass_through(self, r, deferred_id): + self.delete_pending_deferred(deferred_id) + return r + + def get_pending_deferred(self, deferred_id, block): + if not self.quick_has_id(deferred_id) or self.deferreds_to_callback.get(deferred_id) is not None: + return defer.fail(failure.Failure(error.InvalidDeferredID('invalid deferred_id: %r' + deferred_id))) + result = self.results.get(deferred_id) + if result is not None: + self.delete_pending_deferred(deferred_id) + if isinstance(result, failure.Failure): + return defer.fail(result) + else: + return defer.succeed(result) + else: # Result is not ready + if block: + d = defer.Deferred() + self.deferreds_to_callback[deferred_id] = d + return d + else: + return defer.fail(failure.Failure(error.ResultNotCompleted("result not completed: %r" % deferred_id))) + +def two_phase(wrapped_method): + """Wrap methods that return a deferred into a two phase process. + + This transforms:: + + foo(arg1, arg2, ...) -> foo(arg1, arg2,...,block=True). + + The wrapped method will then return a deferred to a deferred id. 
This will + only work on method of classes that inherit from `PendingDeferredManager`, + as that class provides an API for + + block is a boolean to determine if we should use the two phase process or + just simply call the wrapped method. At this point block does not have a + default and it probably won't. + """ + + def wrapper_two_phase(pdm, *args, **kwargs): + try: + block = kwargs.pop('block') + except KeyError: + block = True # The default if not specified + if block: + return wrapped_method(pdm, *args, **kwargs) + else: + d = wrapped_method(pdm, *args, **kwargs) + deferred_id=pdm.save_pending_deferred(d) + return defer.succeed(deferred_id) + + return wrapper_two_phase + + + + + diff --git a/IPython/kernel/pickleutil.py b/IPython/kernel/pickleutil.py new file mode 100644 index 0000000..087a61c --- /dev/null +++ b/IPython/kernel/pickleutil.py @@ -0,0 +1,83 @@ +# encoding: utf-8 + +"""Pickle related utilities.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
#-------------------------------------------------------------------------------
# Imports
#-------------------------------------------------------------------------------

from types import FunctionType
# NOTE: the original also imported twisted.python.log, which is unused
# anywhere in this module; it has been removed.


class CannedObject(object):
    """Marker base class for objects 'canned' for transport."""
    pass


class CannedFunction(CannedObject):
    """Wrap a plain function so its code object can be pickled and shipped."""

    def __init__(self, f):
        self._checkType(f)
        self.code = f.func_code

    def _checkType(self, obj):
        """Raise TypeError if obj is not a plain function."""
        # BUG FIX: was `assert`, which is stripped under -O and raises the
        # less precise AssertionError.
        if not isinstance(obj, FunctionType):
            raise TypeError("Not a function type")

    def getFunction(self, g=None):
        """Rebuild the function, binding it to globals `g` (default: here)."""
        if g is None:
            g = globals()
        newFunc = FunctionType(self.code, g)
        return newFunc


def can(obj):
    """Return a transport-safe version of obj (functions get canned)."""
    if isinstance(obj, FunctionType):
        return CannedFunction(obj)
    else:
        return obj


def canDict(obj):
    """Can every value of a dict in place; pass anything else through."""
    if isinstance(obj, dict):
        for k, v in obj.iteritems():
            obj[k] = can(v)
        return obj
    else:
        return obj


def canSequence(obj):
    """Can every element of a list/tuple; pass anything else through."""
    if isinstance(obj, (list, tuple)):
        t = type(obj)
        return t([can(i) for i in obj])
    else:
        return obj


def uncan(obj, g=None):
    """Inverse of `can`: revive canned functions, pass others through."""
    if isinstance(obj, CannedFunction):
        return obj.getFunction(g)
    else:
        return obj


def uncanDict(obj, g=None):
    """Uncan every value of a dict in place; pass anything else through."""
    if isinstance(obj, dict):
        for k, v in obj.iteritems():
            obj[k] = uncan(v, g)
        return obj
    else:
        return obj


def uncanSequence(obj, g=None):
    """Uncan every element of a list/tuple; pass anything else through."""
    if isinstance(obj, (list, tuple)):
        t = type(obj)
        return t([uncan(i, g) for i in obj])
    else:
        return obj


def rebindFunctionGlobals(f, glbls):
    """Return a copy of f whose global namespace is glbls."""
    return FunctionType(f.func_code, glbls)
BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- \ No newline at end of file diff --git a/IPython/kernel/scripts/ipcluster b/IPython/kernel/scripts/ipcluster new file mode 100644 index 0000000..362a725 --- /dev/null +++ b/IPython/kernel/scripts/ipcluster @@ -0,0 +1,22 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""ipcluster script""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +if __name__ == '__main__': + from IPython.kernel.scripts import ipcluster + ipcluster.main() + diff --git a/IPython/kernel/scripts/ipcluster.py b/IPython/kernel/scripts/ipcluster.py new file mode 100755 index 0000000..5f5795d --- /dev/null +++ b/IPython/kernel/scripts/ipcluster.py @@ -0,0 +1,323 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""Start an IPython cluster conveniently, either locally or remotely. + +Basic usage +----------- + +For local operation, the simplest mode of usage is: + + %prog -n N + +where N is the number of engines you want started. + +For remote operation, you must call it with a cluster description file: + + %prog -f clusterfile.py + +The cluster file is a normal Python script which gets run via execfile(). 
You +can have arbitrary logic in it, but all that matters is that at the end of the +execution, it declares the variables 'controller', 'engines', and optionally +'sshx'. See the accompanying examples for details on what these variables must +contain. + + +Notes +----- + +WARNING: this code is still UNFINISHED and EXPERIMENTAL! It is incomplete, +some listed options are not really implemented, and all of its interfaces are +subject to change. + +When operating over SSH for a remote cluster, this program relies on the +existence of a particular script called 'sshx'. This script must live in the +target systems where you'll be running your controller and engines, and is +needed to configure your PATH and PYTHONPATH variables for further execution of +python code at the other end of an SSH connection. The script can be as simple +as: + +#!/bin/sh +. $HOME/.bashrc +"$@" + +which is the default one provided by IPython. You can modify this or provide +your own. Since it's quite likely that for different clusters you may need +this script to configure things differently or that it may live in different +locations, its full path can be set in the same file where you define the +cluster setup. IPython's order of evaluation for this variable is the +following: + + a) Internal default: 'sshx'. This only works if it is in the default system + path which SSH sets up in non-interactive mode. + + b) Environment variable: if $IPYTHON_SSHX is defined, this overrides the + internal default. + + c) Variable 'sshx' in the cluster configuration file: finally, this will + override the previous two values. + +This code is Unix-only, with precious little hope of any of this ever working +under Windows, since we need SSH from the ground up, we background processes, +etc. Ports of this functionality to Windows are welcome. 
+ + +Call summary +------------ + + %prog [options] +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Stdlib imports +#------------------------------------------------------------------------------- + +import os +import signal +import sys +import time + +from optparse import OptionParser +from subprocess import Popen,call + +#--------------------------------------------------------------------------- +# IPython imports +#--------------------------------------------------------------------------- +from IPython.tools import utils +from IPython.config import cutils + +#--------------------------------------------------------------------------- +# Normal code begins +#--------------------------------------------------------------------------- + +def parse_args(): + """Parse command line and return opts,args.""" + + parser = OptionParser(usage=__doc__) + newopt = parser.add_option # shorthand + + newopt("--controller-port", type="int", dest="controllerport", + help="the TCP port the controller is listening on") + + newopt("--controller-ip", type="string", dest="controllerip", + help="the TCP ip address of the controller") + + newopt("-n", "--num", type="int", dest="n",default=2, + help="the number of engines to start") + + newopt("--engine-port", type="int", dest="engineport", + help="the TCP port the controller will listen on for engine " + "connections") + + newopt("--engine-ip", type="string", dest="engineip", + help="the TCP ip address the controller will listen on " + "for engine connections") + + newopt("--mpi", type="string", 
dest="mpi", + help="use mpi with package: for instance --mpi=mpi4py") + + newopt("-l", "--logfile", type="string", dest="logfile", + help="log file name") + + newopt('-f','--cluster-file',dest='clusterfile', + help='file describing a remote cluster') + + return parser.parse_args() + +def numAlive(controller,engines): + """Return the number of processes still alive.""" + retcodes = [controller.poll()] + \ + [e.poll() for e in engines] + return retcodes.count(None) + +stop = lambda pid: os.kill(pid,signal.SIGINT) +kill = lambda pid: os.kill(pid,signal.SIGTERM) + +def cleanup(clean,controller,engines): + """Stop the controller and engines with the given cleanup method.""" + + for e in engines: + if e.poll() is None: + print 'Stopping engine, pid',e.pid + clean(e.pid) + if controller.poll() is None: + print 'Stopping controller, pid',controller.pid + clean(controller.pid) + + +def ensureDir(path): + """Ensure a directory exists or raise an exception.""" + if not os.path.isdir(path): + os.makedirs(path) + + +def startMsg(control_host,control_port=10105): + """Print a startup message""" + print + print 'Your cluster is up and running.' 
+ print + print 'For interactive use, you can make a MultiEngineClient with:' + print + print 'from IPython.kernel import client' + print "mec = client.MultiEngineClient((%r,%s))" % \ + (control_host,control_port) + print + print 'You can then cleanly stop the cluster from IPython using:' + print + print 'mec.kill(controller=True)' + print + + +def clusterLocal(opt,arg): + """Start a cluster on the local machine.""" + + # Store all logs inside the ipython directory + ipdir = cutils.get_ipython_dir() + pjoin = os.path.join + + logfile = opt.logfile + if logfile is None: + logdir_base = pjoin(ipdir,'log') + ensureDir(logdir_base) + logfile = pjoin(logdir_base,'ipcluster-') + + print 'Starting controller:', + controller = Popen(['ipcontroller','--logfile',logfile]) + print 'Controller PID:',controller.pid + + print 'Starting engines: ', + time.sleep(3) + + englogfile = '%s%s-' % (logfile,controller.pid) + mpi = opt.mpi + if mpi: # start with mpi - killing the engines with sigterm will not work if you do this + engines = [Popen(['mpirun', '-np', str(opt.n), 'ipengine', '--mpi', mpi, '--logfile',englogfile])] + else: # do what we would normally do + engines = [ Popen(['ipengine','--logfile',englogfile]) + for i in range(opt.n) ] + eids = [e.pid for e in engines] + print 'Engines PIDs: ',eids + print 'Log files: %s*' % englogfile + + proc_ids = eids + [controller.pid] + procs = engines + [controller] + + grpid = os.getpgrp() + try: + startMsg('127.0.0.1') + print 'You can also hit Ctrl-C to stop it, or use from the cmd line:' + print + print 'kill -INT',grpid + print + try: + while True: + time.sleep(5) + except: + pass + finally: + print 'Stopping cluster. Cleaning up...' + cleanup(stop,controller,engines) + for i in range(4): + time.sleep(i+2) + nZombies = numAlive(controller,engines) + if nZombies== 0: + print 'OK: All processes cleaned up.' + break + print 'Trying again, %d processes did not stop...' 
% nZombies + cleanup(kill,controller,engines) + if numAlive(controller,engines) == 0: + print 'OK: All processes cleaned up.' + break + else: + print '*'*75 + print 'ERROR: could not kill some processes, try to do it', + print 'manually.' + zombies = [] + if controller.returncode is None: + print 'Controller is alive: pid =',controller.pid + zombies.append(controller.pid) + liveEngines = [ e for e in engines if e.returncode is None ] + for e in liveEngines: + print 'Engine is alive: pid =',e.pid + zombies.append(e.pid) + print + print 'Zombie summary:',' '.join(map(str,zombies)) + +def clusterRemote(opt,arg): + """Start a remote cluster over SSH""" + + # Load the remote cluster configuration + clConfig = {} + execfile(opt.clusterfile,clConfig) + contConfig = clConfig['controller'] + engConfig = clConfig['engines'] + # Determine where to find sshx: + sshx = clConfig.get('sshx',os.environ.get('IPYTHON_SSHX','sshx')) + + # Store all logs inside the ipython directory + ipdir = cutils.get_ipython_dir() + pjoin = os.path.join + + logfile = opt.logfile + if logfile is None: + logdir_base = pjoin(ipdir,'log') + ensureDir(logdir_base) + logfile = pjoin(logdir_base,'ipcluster') + + # Append this script's PID to the logfile name always + logfile = '%s-%s' % (logfile,os.getpid()) + + print 'Starting controller:' + # Controller data: + xsys = os.system + + contHost = contConfig['host'] + contLog = '%s-con-%s-' % (logfile,contHost) + cmd = "ssh %s '%s' 'ipcontroller --logfile %s' &" % \ + (contHost,sshx,contLog) + #print 'cmd:<%s>' % cmd # dbg + xsys(cmd) + time.sleep(2) + + print 'Starting engines: ' + for engineHost,engineData in engConfig.iteritems(): + if isinstance(engineData,int): + numEngines = engineData + else: + raise NotImplementedError('port configuration not finished for engines') + + print 'Sarting %d engines on %s' % (numEngines,engineHost) + engLog = '%s-eng-%s-' % (logfile,engineHost) + for i in range(numEngines): + cmd = "ssh %s '%s' 'ipengine --controller-ip 
%s --logfile %s' &" % \ + (engineHost,sshx,contHost,engLog) + #print 'cmd:<%s>' % cmd # dbg + xsys(cmd) + # Wait after each host a little bit + time.sleep(1) + + startMsg(contConfig['host']) + +def main(): + """Main driver for the two big options: local or remote cluster.""" + + opt,arg = parse_args() + + clusterfile = opt.clusterfile + if clusterfile: + clusterRemote(opt,arg) + else: + clusterLocal(opt,arg) + + +if __name__=='__main__': + main() diff --git a/IPython/kernel/scripts/ipcontroller b/IPython/kernel/scripts/ipcontroller new file mode 100644 index 0000000..b4c86a4 --- /dev/null +++ b/IPython/kernel/scripts/ipcontroller @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +if __name__ == '__main__': + from IPython.kernel.scripts import ipcontroller + ipcontroller.main() + diff --git a/IPython/kernel/scripts/ipcontroller.py b/IPython/kernel/scripts/ipcontroller.py new file mode 100755 index 0000000..26736c4 --- /dev/null +++ b/IPython/kernel/scripts/ipcontroller.py @@ -0,0 +1,366 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""The IPython controller.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +# Python looks for an empty string at the beginning of sys.path to enable +# importing from the cwd. +import sys +sys.path.insert(0, '') + +import sys, time, os +from optparse import OptionParser + +from twisted.application import internet, service +from twisted.internet import reactor, error, defer +from twisted.python import log + +from IPython.kernel.fcutil import Tub, UnauthenticatedTub, have_crypto + +# from IPython.tools import growl +# growl.start("IPython1 Controller") + +from IPython.kernel.error import SecurityError +from IPython.kernel import controllerservice +from IPython.kernel.fcutil import check_furl_file_security + +from IPython.kernel.config import config_manager as kernel_config_manager +from IPython.config.cutils import import_item + + +#------------------------------------------------------------------------------- +# Code +#------------------------------------------------------------------------------- + +def make_tub(ip, port, secure, cert_file): + """ + Create a listening tub given an ip, port, and cert_file location. + + :Parameters: + ip : str + The ip address that the tub should listen on. Empty means all + port : int + The port that the tub should listen on. 
A value of 0 means + pick a random port + secure: boolean + Will the connection be secure (in the foolscap sense) + cert_file: + A filename of a file to be used for theSSL certificate + """ + if secure: + if have_crypto: + tub = Tub(certFile=cert_file) + else: + raise SecurityError("OpenSSL is not available, so we can't run in secure mode, aborting") + else: + tub = UnauthenticatedTub() + + # Set the strport based on the ip and port and start listening + if ip == '': + strport = "tcp:%i" % port + else: + strport = "tcp:%i:interface=%s" % (port, ip) + listener = tub.listenOn(strport) + + return tub, listener + +def make_client_service(controller_service, config): + """ + Create a service that will listen for clients. + + This service is simply a `foolscap.Tub` instance that has a set of Referenceables + registered with it. + """ + + # Now create the foolscap tub + ip = config['controller']['client_tub']['ip'] + port = config['controller']['client_tub'].as_int('port') + location = config['controller']['client_tub']['location'] + secure = config['controller']['client_tub']['secure'] + cert_file = config['controller']['client_tub']['cert_file'] + client_tub, client_listener = make_tub(ip, port, secure, cert_file) + + # Set the location in the trivial case of localhost + if ip == 'localhost' or ip == '127.0.0.1': + location = "127.0.0.1" + + if not secure: + log.msg("WARNING: you are running the controller with no client security") + + def set_location_and_register(): + """Set the location for the tub and return a deferred.""" + + def register(empty, ref, furl_file): + client_tub.registerReference(ref, furlFile=furl_file) + + if location == '': + d = client_tub.setLocationAutomatically() + else: + d = defer.maybeDeferred(client_tub.setLocation, "%s:%i" % (location, client_listener.getPortnum())) + + for ciname, ci in config['controller']['controller_interfaces'].iteritems(): + log.msg("Adapting Controller to interface: %s" % ciname) + furl_file = ci['furl_file'] + 
log.msg("Saving furl for interface [%s] to file: %s" % (ciname, furl_file)) + check_furl_file_security(furl_file, secure) + adapted_controller = import_item(ci['controller_interface'])(controller_service) + d.addCallback(register, import_item(ci['fc_interface'])(adapted_controller), + furl_file=ci['furl_file']) + + reactor.callWhenRunning(set_location_and_register) + return client_tub + + +def make_engine_service(controller_service, config): + """ + Create a service that will listen for engines. + + This service is simply a `foolscap.Tub` instance that has a set of Referenceables + registered with it. + """ + + # Now create the foolscap tub + ip = config['controller']['engine_tub']['ip'] + port = config['controller']['engine_tub'].as_int('port') + location = config['controller']['engine_tub']['location'] + secure = config['controller']['engine_tub']['secure'] + cert_file = config['controller']['engine_tub']['cert_file'] + engine_tub, engine_listener = make_tub(ip, port, secure, cert_file) + + # Set the location in the trivial case of localhost + if ip == 'localhost' or ip == '127.0.0.1': + location = "127.0.0.1" + + if not secure: + log.msg("WARNING: you are running the controller with no engine security") + + def set_location_and_register(): + """Set the location for the tub and return a deferred.""" + + def register(empty, ref, furl_file): + engine_tub.registerReference(ref, furlFile=furl_file) + + if location == '': + d = engine_tub.setLocationAutomatically() + else: + d = defer.maybeDeferred(engine_tub.setLocation, "%s:%i" % (location, engine_listener.getPortnum())) + + furl_file = config['controller']['engine_furl_file'] + engine_fc_interface = import_item(config['controller']['engine_fc_interface']) + log.msg("Saving furl for the engine to file: %s" % furl_file) + check_furl_file_security(furl_file, secure) + fc_controller = engine_fc_interface(controller_service) + d.addCallback(register, fc_controller, furl_file=furl_file) + + 
reactor.callWhenRunning(set_location_and_register) + return engine_tub + +def start_controller(): + """ + Start the controller by creating the service hierarchy and starting the reactor. + + This method does the following: + + * It starts the controller logging + * In execute an import statement for the controller + * It creates 2 `foolscap.Tub` instances for the client and the engines + and registers `foolscap.Referenceables` with the tubs to expose the + controller to engines and clients. + """ + config = kernel_config_manager.get_config_obj() + + # Start logging + logfile = config['controller']['logfile'] + if logfile: + logfile = logfile + str(os.getpid()) + '.log' + try: + openLogFile = open(logfile, 'w') + except: + openLogFile = sys.stdout + else: + openLogFile = sys.stdout + log.startLogging(openLogFile) + + # Execute any user defined import statements + cis = config['controller']['import_statement'] + if cis: + try: + exec cis in globals(), locals() + except: + log.msg("Error running import_statement: %s" % cis) + + # Create the service hierarchy + main_service = service.MultiService() + # The controller service + controller_service = controllerservice.ControllerService() + controller_service.setServiceParent(main_service) + # The client tub and all its refereceables + client_service = make_client_service(controller_service, config) + client_service.setServiceParent(main_service) + # The engine tub + engine_service = make_engine_service(controller_service, config) + engine_service.setServiceParent(main_service) + # Start the controller service and set things running + main_service.startService() + reactor.run() + +def init_config(): + """ + Initialize the configuration using default and command line options. 
+ """ + + parser = OptionParser() + + # Client related options + parser.add_option( + "--client-ip", + type="string", + dest="client_ip", + help="the IP address or hostname the controller will listen on for client connections" + ) + parser.add_option( + "--client-port", + type="int", + dest="client_port", + help="the port the controller will listen on for client connections" + ) + parser.add_option( + '--client-location', + type="string", + dest="client_location", + help="hostname or ip for clients to connect to" + ) + parser.add_option( + "-x", + action="store_false", + dest="client_secure", + help="turn off all client security" + ) + parser.add_option( + '--client-cert-file', + type="string", + dest="client_cert_file", + help="file to store the client SSL certificate" + ) + parser.add_option( + '--task-furl-file', + type="string", + dest="task_furl_file", + help="file to store the FURL for task clients to connect with" + ) + parser.add_option( + '--multiengine-furl-file', + type="string", + dest="multiengine_furl_file", + help="file to store the FURL for multiengine clients to connect with" + ) + # Engine related options + parser.add_option( + "--engine-ip", + type="string", + dest="engine_ip", + help="the IP address or hostname the controller will listen on for engine connections" + ) + parser.add_option( + "--engine-port", + type="int", + dest="engine_port", + help="the port the controller will listen on for engine connections" + ) + parser.add_option( + '--engine-location', + type="string", + dest="engine_location", + help="hostname or ip for engines to connect to" + ) + parser.add_option( + "-y", + action="store_false", + dest="engine_secure", + help="turn off all engine security" + ) + parser.add_option( + '--engine-cert-file', + type="string", + dest="engine_cert_file", + help="file to store the engine SSL certificate" + ) + parser.add_option( + '--engine-furl-file', + type="string", + dest="engine_furl_file", + help="file to store the FURL for engines to 
connect with" + ) + parser.add_option( + "-l", "--logfile", + type="string", + dest="logfile", + help="log file name (default is stdout)" + ) + parser.add_option( + "--ipythondir", + type="string", + dest="ipythondir", + help="look for config files and profiles in this directory" + ) + + (options, args) = parser.parse_args() + + kernel_config_manager.update_config_obj_from_default_file(options.ipythondir) + config = kernel_config_manager.get_config_obj() + + # Update with command line options + if options.client_ip is not None: + config['controller']['client_tub']['ip'] = options.client_ip + if options.client_port is not None: + config['controller']['client_tub']['port'] = options.client_port + if options.client_location is not None: + config['controller']['client_tub']['location'] = options.client_location + if options.client_secure is not None: + config['controller']['client_tub']['secure'] = options.client_secure + if options.client_cert_file is not None: + config['controller']['client_tub']['cert_file'] = options.client_cert_file + if options.task_furl_file is not None: + config['controller']['controller_interfaces']['task']['furl_file'] = options.task_furl_file + if options.multiengine_furl_file is not None: + config['controller']['controller_interfaces']['multiengine']['furl_file'] = options.multiengine_furl_file + if options.engine_ip is not None: + config['controller']['engine_tub']['ip'] = options.engine_ip + if options.engine_port is not None: + config['controller']['engine_tub']['port'] = options.engine_port + if options.engine_location is not None: + config['controller']['engine_tub']['location'] = options.engine_location + if options.engine_secure is not None: + config['controller']['engine_tub']['secure'] = options.engine_secure + if options.engine_cert_file is not None: + config['controller']['engine_tub']['cert_file'] = options.engine_cert_file + if options.engine_furl_file is not None: + config['controller']['engine_furl_file'] = 
options.engine_furl_file + + if options.logfile is not None: + config['controller']['logfile'] = options.logfile + + kernel_config_manager.update_config_obj(config) + +def main(): + """ + After creating the configuration information, start the controller. + """ + init_config() + start_controller() + +if __name__ == "__main__": + main() diff --git a/IPython/kernel/scripts/ipengine b/IPython/kernel/scripts/ipengine new file mode 100644 index 0000000..92eab1c --- /dev/null +++ b/IPython/kernel/scripts/ipengine @@ -0,0 +1,20 @@ +#!/usr/bin/env python +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +if __name__ == '__main__': + from IPython.kernel.scripts import ipengine + ipengine.main() + diff --git a/IPython/kernel/scripts/ipengine.py b/IPython/kernel/scripts/ipengine.py new file mode 100755 index 0000000..9f02ea1 --- /dev/null +++ b/IPython/kernel/scripts/ipengine.py @@ -0,0 +1,169 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""Start the IPython Engine.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +# Python looks for an empty string at the beginning of sys.path to enable +# importing from the cwd. +import sys +sys.path.insert(0, '') + +import sys, os +from optparse import OptionParser + +from twisted.application import service +from twisted.internet import reactor +from twisted.python import log + +from IPython.kernel.fcutil import Tub, UnauthenticatedTub + +from IPython.kernel.core.config import config_manager as core_config_manager +from IPython.config.cutils import import_item +from IPython.kernel.engineservice import EngineService +from IPython.kernel.config import config_manager as kernel_config_manager +from IPython.kernel.engineconnector import EngineConnector + + +#------------------------------------------------------------------------------- +# Code +#------------------------------------------------------------------------------- + +def start_engine(): + """ + Start the engine, by creating it and starting the Twisted reactor. + + This method does: + + * If it exists, runs the `mpi_import_statement` to call `MPI_Init` + * Starts the engine logging + * Creates an IPython shell and wraps it in an `EngineService` + * Creates a `foolscap.Tub` to use in connecting to a controller. 
+ * Uses the tub and the `EngineService` along with a Foolscap URL + (or FURL) to connect to the controller and register the engine + with the controller + """ + kernel_config = kernel_config_manager.get_config_obj() + core_config = core_config_manager.get_config_obj() + + # Execute the mpi import statement that needs to call MPI_Init + mpikey = kernel_config['mpi']['default'] + mpi_import_statement = kernel_config['mpi'].get(mpikey, None) + if mpi_import_statement is not None: + try: + exec mpi_import_statement in locals(), globals() + except: + mpi = None + else: + mpi = None + + # Start logging + logfile = kernel_config['engine']['logfile'] + if logfile: + logfile = logfile + str(os.getpid()) + '.log' + try: + openLogFile = open(logfile, 'w') + except: + openLogFile = sys.stdout + else: + openLogFile = sys.stdout + log.startLogging(openLogFile) + + # Create the underlying shell class and EngineService + shell_class = import_item(core_config['shell']['shell_class']) + engine_service = EngineService(shell_class, mpi=mpi) + shell_import_statement = core_config['shell']['import_statement'] + if shell_import_statement: + try: + engine_service.execute(shell_import_statement) + except: + log.msg("Error running import_statement: %s" % sis) + + # Create the service hierarchy + main_service = service.MultiService() + engine_service.setServiceParent(main_service) + tub_service = Tub() + tub_service.setServiceParent(main_service) + # This needs to be called before the connection is initiated + main_service.startService() + + # This initiates the connection to the controller and calls + # register_engine to tell the controller we are ready to do work + engine_connector = EngineConnector(tub_service) + furl_file = kernel_config['engine']['furl_file'] + d = engine_connector.connect_to_controller(engine_service, furl_file) + d.addErrback(lambda _: reactor.stop()) + + reactor.run() + + +def init_config(): + """ + Initialize the configuration using default and command line 
options. + """ + + parser = OptionParser() + + parser.add_option( + "--furl-file", + type="string", + dest="furl_file", + help="The filename containing the FURL of the controller" + ) + parser.add_option( + "--mpi", + type="string", + dest="mpi", + help="How to enable MPI (mpi4py, pytrilinos, or empty string to disable)" + ) + parser.add_option( + "-l", + "--logfile", + type="string", + dest="logfile", + help="log file name (default is stdout)" + ) + parser.add_option( + "--ipythondir", + type="string", + dest="ipythondir", + help="look for config files and profiles in this directory" + ) + + (options, args) = parser.parse_args() + + kernel_config_manager.update_config_obj_from_default_file(options.ipythondir) + core_config_manager.update_config_obj_from_default_file(options.ipythondir) + + kernel_config = kernel_config_manager.get_config_obj() + # Now override with command line options + if options.furl_file is not None: + kernel_config['engine']['furl_file'] = options.furl_file + if options.logfile is not None: + kernel_config['engine']['logfile'] = options.logfile + if options.mpi is not None: + kernel_config['mpi']['default'] = options.mpi + + +def main(): + """ + After creating the configuration information, start the engine. + """ + init_config() + start_engine() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/IPython/kernel/task.py b/IPython/kernel/task.py new file mode 100644 index 0000000..18ecc6d --- /dev/null +++ b/IPython/kernel/task.py @@ -0,0 +1,799 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.tests.test_task -*- + +"""Task farming representation of the ControllerService.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import copy, time +from types import FunctionType as function + +import zope.interface as zi, string +from twisted.internet import defer, reactor +from twisted.python import components, log, failure + +# from IPython.genutils import time + +from IPython.kernel import engineservice as es, error +from IPython.kernel import controllerservice as cs +from IPython.kernel.twistedutil import gatherBoth, DeferredList + +from IPython.kernel.pickleutil import can,uncan, CannedFunction + +def canTask(task): + t = copy.copy(task) + t.depend = can(t.depend) + if t.recovery_task: + t.recovery_task = canTask(t.recovery_task) + return t + +def uncanTask(task): + t = copy.copy(task) + t.depend = uncan(t.depend) + if t.recovery_task and t.recovery_task is not task: + t.recovery_task = uncanTask(t.recovery_task) + return t + +time_format = '%Y/%m/%d %H:%M:%S' + +class Task(object): + """Our representation of a task for the `TaskController` interface. + + The user should create instances of this class to represent a task that + needs to be done. + + :Parameters: + expression : str + A str that is valid python code that is the task. + pull : str or list of str + The names of objects to be pulled as results. If not specified, + will return {'result', None} + push : dict + A dict of objects to be pushed into the engines namespace before + execution of the expression. + clear_before : boolean + Should the engine's namespace be cleared before the task is run. + Default=False. + clear_after : boolean + Should the engine's namespace be cleared after the task is run. + Default=False. + retries : int + The number of times to resumbit the task if it fails. Default=0. 
+ recovery_task : Task + This is the Task to be run when the task has exhausted its retries + Default=None. + depend : bool function(properties) + This is the dependency function for the Task, which determines + whether a task can be run on a Worker. `depend` is called with + one argument, the worker's properties dict, and should return + True if the worker meets the dependencies or False if it does + not. + Default=None - run on any worker + options : dict + Any other keyword options for more elaborate uses of tasks + + Examples + -------- + + >>> t = Task('dostuff(args)') + >>> t = Task('a=5', pull='a') + >>> t = Task('a=5\nb=4', pull=['a','b']) + >>> t = Task('os.kill(os.getpid(),9)', retries=100) # this is a bad idea + # A dependency case: + >>> def hasMPI(props): + ... return props.get('mpi') is not None + >>> t = Task('mpi.send(blah,blah)', depend = hasMPI) + """ + + def __init__(self, expression, pull=None, push=None, + clear_before=False, clear_after=False, retries=0, + recovery_task=None, depend=None, **options): + self.expression = expression + if isinstance(pull, str): + self.pull = [pull] + else: + self.pull = pull + self.push = push + self.clear_before = clear_before + self.clear_after = clear_after + self.retries=retries + self.recovery_task = recovery_task + self.depend = depend + self.options = options + self.taskid = None + +class ResultNS: + """The result namespace object for use in TaskResult objects as tr.ns. + It builds an object from a dictionary, such that it has attributes + according to the key,value pairs of the dictionary. + + This works by calling setattr on ALL key,value pairs in the dict. If a user + chooses to overwrite the `__repr__` or `__getattr__` attributes, they can. + This can be a bad idea, as it may corrupt standard behavior of the + ns object. 
+ + Example + -------- + + >>> ns = ResultNS({'a':17,'foo':range(3)}) + >>> print ns + NS{'a':17,'foo':range(3)} + >>> ns.a + 17 + >>> ns['foo'] + [0,1,2] + """ + def __init__(self, dikt): + for k,v in dikt.iteritems(): + setattr(self,k,v) + + def __repr__(self): + l = dir(self) + d = {} + for k in l: + # do not print private objects + if k[:2] != '__' and k[-2:] != '__': + d[k] = getattr(self, k) + return "NS"+repr(d) + + def __getitem__(self, key): + return getattr(self, key) + +class TaskResult(object): + """ + An object for returning task results. + + This object encapsulates the results of a task. On task + success it will have a keys attribute that will have a list + of the variables that have been pulled back. These variables + are accessible as attributes of this class as well. On + success the failure attribute will be None. + + In task failure, keys will be empty, but failure will contain + the failure object that encapsulates the remote exception. + One can also simply call the raiseException() method of + this class to re-raise any remote exception in the local + session. + + The TaskResult has a .ns member, which is a property for access + to the results. If the Task had pull=['a', 'b'], then the + Task Result will have attributes tr.ns.a, tr.ns.b for those values. + Accessing tr.ns will raise the remote failure if the task failed. + + The engineid attribute should have the engineid of the engine + that ran the task. But, because engines can come and go in + the ipython task system, the engineid may not continue to be + valid or accurate. + + The taskid attribute simply gives the taskid that the task + is tracked under. 
+ """ + taskid = None + + def _getNS(self): + if isinstance(self.failure, failure.Failure): + return self.failure.raiseException() + else: + return self._ns + + def _setNS(self, v): + raise Exception("I am protected!") + + ns = property(_getNS, _setNS) + + def __init__(self, results, engineid): + self.engineid = engineid + if isinstance(results, failure.Failure): + self.failure = results + self.results = {} + else: + self.results = results + self.failure = None + + self._ns = ResultNS(self.results) + + self.keys = self.results.keys() + + def __repr__(self): + if self.failure is not None: + contents = self.failure + else: + contents = self.results + return "TaskResult[ID:%r]:%r"%(self.taskid, contents) + + def __getitem__(self, key): + if self.failure is not None: + self.raiseException() + return self.results[key] + + def raiseException(self): + """Re-raise any remote exceptions in the local python session.""" + if self.failure is not None: + self.failure.raiseException() + + +class IWorker(zi.Interface): + """The Basic Worker Interface. + + A worked is a representation of an Engine that is ready to run tasks. + """ + + zi.Attribute("workerid", "the id of the worker") + + def run(task): + """Run task in worker's namespace. + + :Parameters: + task : a `Task` object + + :Returns: `Deferred` to a `TaskResult` object. + """ + + +class WorkerFromQueuedEngine(object): + """Adapt an `IQueuedEngine` to an `IWorker` object""" + zi.implements(IWorker) + + def __init__(self, qe): + self.queuedEngine = qe + self.workerid = None + + def _get_properties(self): + return self.queuedEngine.properties + + properties = property(_get_properties, lambda self, _:None) + + def run(self, task): + """Run task in worker's namespace. + + :Parameters: + task : a `Task` object + + :Returns: `Deferred` to a `TaskResult` object. 
+ """ + if task.clear_before: + d = self.queuedEngine.reset() + else: + d = defer.succeed(None) + + if task.push is not None: + d.addCallback(lambda r: self.queuedEngine.push(task.push)) + + d.addCallback(lambda r: self.queuedEngine.execute(task.expression)) + + if task.pull is not None: + d.addCallback(lambda r: self.queuedEngine.pull(task.pull)) + else: + d.addCallback(lambda r: None) + + def reseter(result): + self.queuedEngine.reset() + return result + + if task.clear_after: + d.addBoth(reseter) + + return d.addBoth(self._zipResults, task.pull, time.time(), time.localtime()) + + def _zipResults(self, result, names, start, start_struct): + """Callback for construting the TaskResult object.""" + if isinstance(result, failure.Failure): + tr = TaskResult(result, self.queuedEngine.id) + else: + if names is None: + resultDict = {} + elif len(names) == 1: + resultDict = {names[0]:result} + else: + resultDict = dict(zip(names, result)) + tr = TaskResult(resultDict, self.queuedEngine.id) + # the time info + tr.submitted = time.strftime(time_format, start_struct) + tr.completed = time.strftime(time_format) + tr.duration = time.time()-start + return tr + + +components.registerAdapter(WorkerFromQueuedEngine, es.IEngineQueued, IWorker) + +class IScheduler(zi.Interface): + """The interface for a Scheduler. + """ + zi.Attribute("nworkers", "the number of unassigned workers") + zi.Attribute("ntasks", "the number of unscheduled tasks") + zi.Attribute("workerids", "a list of the worker ids") + zi.Attribute("taskids", "a list of the task ids") + + def add_task(task, **flags): + """Add a task to the queue of the Scheduler. + + :Parameters: + task : a `Task` object + The task to be queued. + flags : dict + General keywords for more sophisticated scheduling + """ + + def pop_task(id=None): + """Pops a Task object. + + This gets the next task to be run. If no `id` is requested, the highest priority + task is returned. + + :Parameters: + id + The id of the task to be popped. 
The default (None) is to return + the highest priority task. + + :Returns: a `Task` object + + :Exceptions: + IndexError : raised if no taskid in queue + """ + + def add_worker(worker, **flags): + """Add a worker to the worker queue. + + :Parameters: + worker : an IWorker implementing object + flags : General keywords for more sophisticated scheduling + """ + + def pop_worker(id=None): + """Pops an IWorker object that is ready to do work. + + This gets the next IWorker that is ready to do work. + + :Parameters: + id : if specified, will pop worker with workerid=id, else pops + highest priority worker. Defaults to None. + + :Returns: + an IWorker object + + :Exceptions: + IndexError : raised if no workerid in queue + """ + + def ready(): + """Returns True if there is something to do, False otherwise""" + + def schedule(): + """Returns a tuple of the worker and task pair for the next + task to be run. + """ + + +class FIFOScheduler(object): + """A basic First-In-First-Out (Queue) Scheduler. + This is the default Scheduler for the TaskController. + See the docstrings for IScheduler for interface details. 
+ """ + + zi.implements(IScheduler) + + def __init__(self): + self.tasks = [] + self.workers = [] + + def _ntasks(self): + return len(self.tasks) + + def _nworkers(self): + return len(self.workers) + + ntasks = property(_ntasks, lambda self, _:None) + nworkers = property(_nworkers, lambda self, _:None) + + def _taskids(self): + return [t.taskid for t in self.tasks] + + def _workerids(self): + return [w.workerid for w in self.workers] + + taskids = property(_taskids, lambda self,_:None) + workerids = property(_workerids, lambda self,_:None) + + def add_task(self, task, **flags): + self.tasks.append(task) + + def pop_task(self, id=None): + if id is None: + return self.tasks.pop(0) + else: + for i in range(len(self.tasks)): + taskid = self.tasks[i].taskid + if id == taskid: + return self.tasks.pop(i) + raise IndexError("No task #%i"%id) + + def add_worker(self, worker, **flags): + self.workers.append(worker) + + def pop_worker(self, id=None): + if id is None: + return self.workers.pop(0) + else: + for i in range(len(self.workers)): + workerid = self.workers[i].workerid + if id == workerid: + return self.workers.pop(i) + raise IndexError("No worker #%i"%id) + + def schedule(self): + for t in self.tasks: + for w in self.workers: + try:# do not allow exceptions to break this + cando = t.depend is None or t.depend(w.properties) + except: + cando = False + if cando: + return self.pop_worker(w.workerid), self.pop_task(t.taskid) + return None, None + + + +class LIFOScheduler(FIFOScheduler): + """A Last-In-First-Out (Stack) Scheduler. This scheduler should naively + reward fast engines by giving them more jobs. This risks starvation, but + only in cases with low load, where starvation does not really matter. 
+ """ + + def add_task(self, task, **flags): + # self.tasks.reverse() + self.tasks.insert(0, task) + # self.tasks.reverse() + + def add_worker(self, worker, **flags): + # self.workers.reverse() + self.workers.insert(0, worker) + # self.workers.reverse() + + +class ITaskController(cs.IControllerBase): + """The Task based interface to a `ControllerService` object + + This adapts a `ControllerService` to the ITaskController interface. + """ + + def run(task): + """Run a task. + + :Parameters: + task : an IPython `Task` object + + :Returns: the integer ID of the task + """ + + def get_task_result(taskid, block=False): + """Get the result of a task by its ID. + + :Parameters: + taskid : int + the id of the task whose result is requested + + :Returns: `Deferred` to (taskid, actualResult) if the task is done, and None + if not. + + :Exceptions: + actualResult will be an `IndexError` if no such task has been submitted + """ + + def abort(taskid): + """Remove task from queue if task is has not been submitted. + + If the task has already been submitted, wait for it to finish and discard + results and prevent resubmission. + + :Parameters: + taskid : the id of the task to be aborted + + :Returns: + `Deferred` to abort attempt completion. Will be None on success. + + :Exceptions: + deferred will fail with `IndexError` if no such task has been submitted + or the task has already completed. + """ + + def barrier(taskids): + """Block until the list of taskids are completed. + + Returns None on success. + """ + + def spin(): + """touch the scheduler, to resume scheduling without submitting + a task. + """ + + def queue_status(self, verbose=False): + """Get a dictionary with the current state of the task queue. + + If verbose is True, then return lists of taskids, otherwise, + return the number of tasks with each status. + """ + + +class TaskController(cs.ControllerAdapterBase): + """The Task based interface to a Controller object. 
+ + If you want to use a different scheduler, just subclass this and set + the `SchedulerClass` member to the *class* of your chosen scheduler. + """ + + zi.implements(ITaskController) + SchedulerClass = FIFOScheduler + + timeout = 30 + + def __init__(self, controller): + self.controller = controller + self.controller.on_register_engine_do(self.registerWorker, True) + self.controller.on_unregister_engine_do(self.unregisterWorker, True) + self.taskid = 0 + self.failurePenalty = 1 # the time in seconds to penalize + # a worker for failing a task + self.pendingTasks = {} # dict of {workerid:(taskid, task)} + self.deferredResults = {} # dict of {taskid:deferred} + self.finishedResults = {} # dict of {taskid:actualResult} + self.workers = {} # dict of {workerid:worker} + self.abortPending = [] # dict of {taskid:abortDeferred} + self.idleLater = None # delayed call object for timeout + self.scheduler = self.SchedulerClass() + + for id in self.controller.engines.keys(): + self.workers[id] = IWorker(self.controller.engines[id]) + self.workers[id].workerid = id + self.schedule.add_worker(self.workers[id]) + + def registerWorker(self, id): + """Called by controller.register_engine.""" + if self.workers.get(id): + raise "We already have one! This should not happen." 
+ self.workers[id] = IWorker(self.controller.engines[id]) + self.workers[id].workerid = id + if not self.pendingTasks.has_key(id):# if not working + self.scheduler.add_worker(self.workers[id]) + self.distributeTasks() + + def unregisterWorker(self, id): + """Called by controller.unregister_engine""" + + if self.workers.has_key(id): + try: + self.scheduler.pop_worker(id) + except IndexError: + pass + self.workers.pop(id) + + def _pendingTaskIDs(self): + return [t.taskid for t in self.pendingTasks.values()] + + #--------------------------------------------------------------------------- + # Interface methods + #--------------------------------------------------------------------------- + + def run(self, task): + """Run a task and return `Deferred` to its taskid.""" + task.taskid = self.taskid + task.start = time.localtime() + self.taskid += 1 + d = defer.Deferred() + self.scheduler.add_task(task) + # log.msg('Queuing task: %i' % task.taskid) + + self.deferredResults[task.taskid] = [] + self.distributeTasks() + return defer.succeed(task.taskid) + + def get_task_result(self, taskid, block=False): + """Returns a `Deferred` to a TaskResult tuple or None.""" + # log.msg("Getting task result: %i" % taskid) + if self.finishedResults.has_key(taskid): + tr = self.finishedResults[taskid] + return defer.succeed(tr) + elif self.deferredResults.has_key(taskid): + if block: + d = defer.Deferred() + self.deferredResults[taskid].append(d) + return d + else: + return defer.succeed(None) + else: + return defer.fail(IndexError("task ID not registered: %r" % taskid)) + + def abort(self, taskid): + """Remove a task from the queue if it has not been run already.""" + if not isinstance(taskid, int): + return defer.fail(failure.Failure(TypeError("an integer task id expected: %r" % taskid))) + try: + self.scheduler.pop_task(taskid) + except IndexError, e: + if taskid in self.finishedResults.keys(): + d = defer.fail(IndexError("Task Already Completed")) + elif taskid in self.abortPending: + d 
= defer.fail(IndexError("Task Already Aborted")) + elif taskid in self._pendingTaskIDs():# task is pending + self.abortPending.append(taskid) + d = defer.succeed(None) + else: + d = defer.fail(e) + else: + d = defer.execute(self._doAbort, taskid) + + return d + + def barrier(self, taskids): + dList = [] + if isinstance(taskids, int): + taskids = [taskids] + for id in taskids: + d = self.get_task_result(id, block=True) + dList.append(d) + d = DeferredList(dList, consumeErrors=1) + d.addCallbacks(lambda r: None) + return d + + def spin(self): + return defer.succeed(self.distributeTasks()) + + def queue_status(self, verbose=False): + pending = self._pendingTaskIDs() + failed = [] + succeeded = [] + for k,v in self.finishedResults.iteritems(): + if not isinstance(v, failure.Failure): + if hasattr(v,'failure'): + if v.failure is None: + succeeded.append(k) + else: + failed.append(k) + scheduled = self.scheduler.taskids + if verbose: + result = dict(pending=pending, failed=failed, + succeeded=succeeded, scheduled=scheduled) + else: + result = dict(pending=len(pending),failed=len(failed), + succeeded=len(succeeded),scheduled=len(scheduled)) + return defer.succeed(result) + + #--------------------------------------------------------------------------- + # Queue methods + #--------------------------------------------------------------------------- + + def _doAbort(self, taskid): + """Helper function for aborting a pending task.""" + # log.msg("Task aborted: %i" % taskid) + result = failure.Failure(error.TaskAborted()) + self._finishTask(taskid, result) + if taskid in self.abortPending: + self.abortPending.remove(taskid) + + def _finishTask(self, taskid, result): + dlist = self.deferredResults.pop(taskid) + result.taskid = taskid # The TaskResult should save the taskid + self.finishedResults[taskid] = result + for d in dlist: + d.callback(result) + + def distributeTasks(self): + """Distribute tasks while self.scheduler has things to do.""" + # log.msg("distributing Tasks") + 
worker, task = self.scheduler.schedule() + if not worker and not task: + if self.idleLater and self.idleLater.called:# we are inside failIdle + self.idleLater = None + else: + self.checkIdle() + return False + # else something to do: + while worker and task: + # get worker and task + # add to pending + self.pendingTasks[worker.workerid] = task + # run/link callbacks + d = worker.run(task) + # log.msg("Running task %i on worker %i" %(task.taskid, worker.workerid)) + d.addBoth(self.taskCompleted, task.taskid, worker.workerid) + worker, task = self.scheduler.schedule() + # check for idle timeout: + self.checkIdle() + return True + + def checkIdle(self): + if self.idleLater and not self.idleLater.called: + self.idleLater.cancel() + if self.scheduler.ntasks and self.workers and \ + self.scheduler.nworkers == len(self.workers): + self.idleLater = reactor.callLater(self.timeout, self.failIdle) + else: + self.idleLater = None + + def failIdle(self): + if not self.distributeTasks(): + while self.scheduler.ntasks: + t = self.scheduler.pop_task() + msg = "task %i failed to execute due to unmet dependencies"%t.taskid + msg += " for %i seconds"%self.timeout + # log.msg("Task aborted by timeout: %i" % t.taskid) + f = failure.Failure(error.TaskTimeout(msg)) + self._finishTask(t.taskid, f) + self.idleLater = None + + + def taskCompleted(self, result, taskid, workerid): + """This is the err/callback for a completed task.""" + try: + task = self.pendingTasks.pop(workerid) + except: + # this should not happen + log.msg("Tried to pop bad pending task %i from worker %i"%(taskid, workerid)) + log.msg("Result: %r"%result) + log.msg("Pending tasks: %s"%self.pendingTasks) + return + + # Check if aborted while pending + aborted = False + if taskid in self.abortPending: + self._doAbort(taskid) + aborted = True + + if not aborted: + if result.failure is not None and isinstance(result.failure, failure.Failure): # we failed + log.msg("Task %i failed on worker %i"% (taskid, workerid)) + if 
task.retries > 0: # resubmit + task.retries -= 1 + self.scheduler.add_task(task) + s = "Resubmitting task %i, %i retries remaining" %(taskid, task.retries) + log.msg(s) + self.distributeTasks() + elif isinstance(task.recovery_task, Task) and \ + task.recovery_task.retries > -1: + # retries = -1 is to prevent infinite recovery_task loop + task.retries = -1 + task.recovery_task.taskid = taskid + task = task.recovery_task + self.scheduler.add_task(task) + s = "Recovering task %i, %i retries remaining" %(taskid, task.retries) + log.msg(s) + self.distributeTasks() + else: # done trying + self._finishTask(taskid, result) + # wait a second before readmitting a worker that failed + # it may have died, and not yet been unregistered + reactor.callLater(self.failurePenalty, self.readmitWorker, workerid) + else: # we succeeded + # log.msg("Task completed: %i"% taskid) + self._finishTask(taskid, result) + self.readmitWorker(workerid) + else:# we aborted the task + if result.failure is not None and isinstance(result.failure, failure.Failure): # it failed, penalize worker + reactor.callLater(self.failurePenalty, self.readmitWorker, workerid) + else: + self.readmitWorker(workerid) + + def readmitWorker(self, workerid): + """Readmit a worker to the scheduler. + + This is outside `taskCompleted` because of the `failurePenalty` being + implemented through `reactor.callLater`. + """ + + if workerid in self.workers.keys() and workerid not in self.pendingTasks.keys(): + self.scheduler.add_worker(self.workers[workerid]) + self.distributeTasks() + + +components.registerAdapter(TaskController, cs.IControllerBase, ITaskController) diff --git a/IPython/kernel/taskclient.py b/IPython/kernel/taskclient.py new file mode 100644 index 0000000..405407a --- /dev/null +++ b/IPython/kernel/taskclient.py @@ -0,0 +1,161 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.tests.test_taskcontrollerxmlrpc -*- + +"""The Generic Task Client object. 
+ +This must be subclassed based on your connection method. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from zope.interface import Interface, implements +from twisted.python import components, log + +from IPython.kernel.twistedutil import blockingCallFromThread +from IPython.kernel import task, error + +#------------------------------------------------------------------------------- +# Connecting Task Client +#------------------------------------------------------------------------------- + +class InteractiveTaskClient(object): + + def irun(self, *args, **kwargs): + """Run a task on the `TaskController`. + + This method is a shorthand for run(task) and its arguments are simply + passed onto a `Task` object: + + irun(*args, **kwargs) -> run(Task(*args, **kwargs)) + + :Parameters: + expression : str + A str that is valid python code that is the task. + pull : str or list of str + The names of objects to be pulled as results. + push : dict + A dict of objects to be pushed into the engines namespace before + execution of the expression. + clear_before : boolean + Should the engine's namespace be cleared before the task is run. + Default=False. + clear_after : boolean + Should the engine's namespace be cleared after the task is run. + Default=False. + retries : int + The number of times to resumbit the task if it fails. Default=0. + options : dict + Any other keyword options for more elaborate uses of tasks + + :Returns: A `TaskResult` object. 
+ """ + block = kwargs.pop('block', False) + if len(args) == 1 and isinstance(args[0], task.Task): + t = args[0] + else: + t = task.Task(*args, **kwargs) + taskid = self.run(t) + print "TaskID = %i"%taskid + if block: + return self.get_task_result(taskid, block) + else: + return taskid + +class IBlockingTaskClient(Interface): + """ + An interface for blocking task clients. + """ + pass + + +class BlockingTaskClient(InteractiveTaskClient): + """ + This class provides a blocking task client. + """ + + implements(IBlockingTaskClient) + + def __init__(self, task_controller): + self.task_controller = task_controller + self.block = True + + def run(self, task): + """ + Run a task and return a task id that can be used to get the task result. + + :Parameters: + task : `Task` + The `Task` object to run + """ + return blockingCallFromThread(self.task_controller.run, task) + + def get_task_result(self, taskid, block=False): + """ + Get or poll for a task result. + + :Parameters: + taskid : int + The id of the task whose result to get + block : boolean + If True, wait until the task is done and then result the + `TaskResult` object. If False, just poll for the result and + return None if the task is not done. + """ + return blockingCallFromThread(self.task_controller.get_task_result, + taskid, block) + + def abort(self, taskid): + """ + Abort a task by task id if it has not been started. + """ + return blockingCallFromThread(self.task_controller.abort, taskid) + + def barrier(self, taskids): + """ + Wait for a set of tasks to finish. + + :Parameters: + taskids : list of ints + A list of task ids to wait for. + """ + return blockingCallFromThread(self.task_controller.barrier, taskids) + + def spin(self): + """ + Cause the scheduler to schedule tasks. + + This method only needs to be called in unusual situations where the + scheduler is idle for some reason. 
+ """ + return blockingCallFromThread(self.task_controller.spin) + + def queue_status(self, verbose=False): + """ + Get a dictionary with the current state of the task queue. + + :Parameters: + verbose : boolean + If True, return a list of taskids. If False, simply give + the number of tasks with each status. + + :Returns: + A dict with the queue status. + """ + return blockingCallFromThread(self.task_controller.queue_status, verbose) + + +components.registerAdapter(BlockingTaskClient, + task.ITaskController, IBlockingTaskClient) + + diff --git a/IPython/kernel/taskfc.py b/IPython/kernel/taskfc.py new file mode 100644 index 0000000..b4096e7 --- /dev/null +++ b/IPython/kernel/taskfc.py @@ -0,0 +1,267 @@ +# encoding: utf-8 +# -*- test-case-name: IPython.kernel.tests.test_taskxmlrpc -*- +"""A Foolscap interface to a TaskController. + +This class lets Foolscap clients talk to a TaskController. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
class IFCTaskController(Interface):
    """Foolscap interface to task controller.

    See the documentation of ITaskController for documentation about
    the methods.
    """
    def remote_run(request, binTask):
        """Run a pickled, canned task; fires with a pickled taskid."""

    def remote_abort(request, taskid):
        """Abort the task with the given taskid."""

    def remote_get_task_result(request, taskid, block=False):
        """Fetch (or, with block, wait for) a pickled TaskResult."""

    def remote_barrier(request, taskids):
        """Wait until all listed taskids have completed."""

    def remote_spin(request):
        """Touch the scheduler without submitting a task."""

    def remote_queue_status(request, verbose):
        """Report the current state of the task queue."""
+ """ + implements(IFCTaskController, IFCClientInterfaceProvider) + + def __init__(self, taskController): + self.taskController = taskController + + #--------------------------------------------------------------------------- + # Non interface methods + #--------------------------------------------------------------------------- + + def packageFailure(self, f): + f.cleanFailure() + return self.packageSuccess(f) + + def packageSuccess(self, obj): + serial = pickle.dumps(obj, 2) + return serial + + #--------------------------------------------------------------------------- + # ITaskController related methods + #--------------------------------------------------------------------------- + + def remote_run(self, ptask): + try: + ctask = pickle.loads(ptask) + task = taskmodule.uncanTask(ctask) + except: + d = defer.fail(pickle.UnpickleableError("Could not unmarshal task")) + else: + d = self.taskController.run(task) + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + + def remote_abort(self, taskid): + d = self.taskController.abort(taskid) + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + + def remote_get_task_result(self, taskid, block=False): + d = self.taskController.get_task_result(taskid, block) + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + + def remote_barrier(self, taskids): + d = self.taskController.barrier(taskids) + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + + def remote_spin(self): + d = self.taskController.spin() + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + + def remote_queue_status(self, verbose): + d = self.taskController.queue_status(verbose) + d.addCallback(self.packageSuccess) + d.addErrback(self.packageFailure) + return d + + def remote_get_client_name(self): + return 'IPython.kernel.taskfc.FCTaskClient' + 
class FCTaskClient(object):
    """Foolscap-based TaskController client that implements ITaskController.

    :Parameters:
        remote_reference : a foolscap `RemoteReference` to the remote
            task controller.
    """
    implements(taskmodule.ITaskController, IBlockingClientAdaptor)

    def __init__(self, remote_reference):
        self.remote_reference = remote_reference

    #---------------------------------------------------------------------------
    # Non interface methods
    #---------------------------------------------------------------------------

    def unpackage(self, r):
        """Unpickle a wire payload back into a result or Failure."""
        return pickle.loads(r)

    #---------------------------------------------------------------------------
    # ITaskController related methods
    #---------------------------------------------------------------------------

    def run(self, task):
        """Run a task on the `TaskController`.

        :Parameters:
            task : a `Task` object

        The Task object is created using the following signature:

            Task(expression, pull=None, push={}, clear_before=False,
                clear_after=False, retries=0, **options)

        :Task Parameters:
            expression : str
                A str that is valid python code that is the task.
            pull : str or list of str
                The names of objects to be pulled as results.
            push : dict
                A dict of objects to be pushed into the engines namespace
                before execution of the expression.
            clear_before : boolean
                Should the engine's namespace be cleared before the task
                is run.  Default=False.
            clear_after : boolean
                Should the engine's namespace be cleared after the task
                is run.  Default=False.
            retries : int
                The number of times to resubmit the task if it fails.
                Default=0.  (DOC FIX: "resumbit" typo corrected.)
            options : dict
                Any other keyword options for more elaborate uses of tasks

        :Returns: The int taskid of the submitted task.  Pass this to
            `get_task_result` to get the `TaskResult` object.
        """
        assert isinstance(task, taskmodule.Task), "task must be a Task object!"
        ctask = taskmodule.canTask(task) # handles arbitrary function in .depend
                                # as well as arbitrary recovery_task chains
        ptask = pickle.dumps(ctask, 2)
        d = self.remote_reference.callRemote('run', ptask)
        d.addCallback(self.unpackage)
        return d

    def get_task_result(self, taskid, block=False):
        """The task result by taskid.

        :Parameters:
            taskid : int
                The taskid of the task to be retrieved.
            block : boolean
                Should I block until the task is done?

        :Returns: A `TaskResult` object that encapsulates the task result.
        """
        d = self.remote_reference.callRemote('get_task_result', taskid, block)
        d.addCallback(self.unpackage)
        return d

    def abort(self, taskid):
        """Abort a task by taskid.

        :Parameters:
            taskid : int
                The taskid of the task to be aborted.

        (DOC FIX: removed documentation of a `block` parameter that this
        method has never accepted.)
        """
        d = self.remote_reference.callRemote('abort', taskid)
        d.addCallback(self.unpackage)
        return d

    def barrier(self, taskids):
        """Block until all tasks are completed.

        :Parameters:
            taskids : list, tuple
                A sequence of taskids to block on.
        """
        d = self.remote_reference.callRemote('barrier', taskids)
        d.addCallback(self.unpackage)
        return d
+ """ + d = self.remote_reference.callRemote('spin') + d.addCallback(self.unpackage) + return d + + def queue_status(self, verbose=False): + """Return a dict with the status of the task queue.""" + d = self.remote_reference.callRemote('queue_status', verbose) + d.addCallback(self.unpackage) + return d + + def adapt_to_blocking_client(self): + from IPython.kernel.taskclient import IBlockingTaskClient + return IBlockingTaskClient(self) + diff --git a/IPython/kernel/tests/__init__.py b/IPython/kernel/tests/__init__.py new file mode 100644 index 0000000..bef7dcc --- /dev/null +++ b/IPython/kernel/tests/__init__.py @@ -0,0 +1,10 @@ +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- diff --git a/IPython/kernel/tests/controllertest.py b/IPython/kernel/tests/controllertest.py new file mode 100644 index 0000000..5ebfaab --- /dev/null +++ b/IPython/kernel/tests/controllertest.py @@ -0,0 +1,102 @@ +# encoding: utf-8 + +"""This file contains unittests for the kernel.engineservice.py module. + +Things that should be tested: + + - Should the EngineService return Deferred objects? + - Run the same tests that are run in shell.py. + - Make sure that the Interface is really implemented. + - The startService and stopService methods. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
class IControllerCoreTestCase(object):
    """Tests for objects that implement IControllerCore.

    This test assumes that self.controller is defined and implements
    IControllerCore.
    """

    def testIControllerCoreInterface(self):
        """Does self.controller claim to implement IControllerCore?"""
        # DOC FIX: docstring previously referenced self.engine/IEngineCore,
        # copy-pasted from the engine tests; this test checks the controller.
        self.assert_(IControllerCore.providedBy(self.controller))

    def testIControllerCoreInterfaceMethods(self):
        """Does self.controller have the methods and attributes of
        IControllerCore.  (DOC FIX: was "IEngireCore".)
        """
        for m in list(IControllerCore):
            self.assert_(hasattr(self.controller, m))

    def testRegisterUnregisterEngine(self):
        """A registered engine gets its requested id and can be removed."""
        engine = es.EngineService()
        qengine = es.QueuedEngine(engine)
        regDict = self.controller.register_engine(qengine, 0)
        self.assert_(isinstance(regDict, dict))
        self.assert_(regDict.has_key('id'))
        self.assert_(regDict['id']==0)
        self.controller.unregister_engine(0)
        self.assert_(self.controller.engines.get(0, None) == None)
self.assertEquals(rd2['id'], 1) + self.controller.unregister_engine(0) + self.controller.unregister_engine(1) + self.assertEquals(self.controller.engines,{}) + + def testRegisterCallables(self): + e1 = es.EngineService() + qe1 = es.QueuedEngine(e1) + self.registerCallableCalled = ';lkj' + self.unregisterCallableCalled = ';lkj' + self.controller.on_register_engine_do(self._registerCallable, False) + self.controller.on_unregister_engine_do(self._unregisterCallable, False) + self.controller.register_engine(qe1, 0) + self.assertEquals(self.registerCallableCalled, 'asdf') + self.controller.unregister_engine(0) + self.assertEquals(self.unregisterCallableCalled, 'asdf') + self.controller.on_register_engine_do_not(self._registerCallable) + self.controller.on_unregister_engine_do_not(self._unregisterCallable) + + def _registerCallable(self): + self.registerCallableCalled = 'asdf' + + def _unregisterCallable(self): + self.unregisterCallableCalled = 'asdf' + + def testBadUnregister(self): + self.assertRaises(AssertionError, self.controller.unregister_engine, 'foo') \ No newline at end of file diff --git a/IPython/kernel/tests/engineservicetest.py b/IPython/kernel/tests/engineservicetest.py new file mode 100644 index 0000000..b4a6d21 --- /dev/null +++ b/IPython/kernel/tests/engineservicetest.py @@ -0,0 +1,373 @@ +# encoding: utf-8 + +"""Test template for complete engine object""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import cPickle as pickle + +from twisted.internet import defer, reactor +from twisted.python import failure +from twisted.application import service +import zope.interface as zi + +from IPython.kernel import newserialized +from IPython.kernel import error +from IPython.kernel.pickleutil import can, uncan +import IPython.kernel.engineservice as es +from IPython.kernel.core.interpreter import Interpreter +from IPython.testing.parametric import Parametric, parametric + +#------------------------------------------------------------------------------- +# Tests +#------------------------------------------------------------------------------- + + +# A sequence of valid commands run through execute +validCommands = ['a=5', + 'b=10', + 'a=5; b=10; c=a+b', + 'import math; 2.0*math.pi', + """def f(): + result = 0.0 + for i in range(10): + result += i +""", + 'if 1<2: a=5', + """import time +time.sleep(0.1)""", + """from math import cos; +x = 1.0*cos(0.5)""", # Semicolons lead to Discard ast nodes that should be discarded + """from sets import Set +s = Set() + """, # Trailing whitespace should be allowed. 
+ """import math +math.cos(1.0)""", # Test a method call with a discarded return value + """x=1.0234 +a=5; b=10""", # Test an embedded semicolon + """x=1.0234 +a=5; b=10;""" # Test both an embedded and trailing semicolon + ] + +# A sequence of commands that raise various exceptions +invalidCommands = [('a=1/0',ZeroDivisionError), + ('print v',NameError), + ('l=[];l[0]',IndexError), + ("d={};d['a']",KeyError), + ("assert 1==0",AssertionError), + ("import abababsdbfsbaljasdlja",ImportError), + ("raise Exception()",Exception)] + +def testf(x): + return 2.0*x + +globala = 99 + +def testg(x): + return globala*x + +class IEngineCoreTestCase(object): + """Test an IEngineCore implementer.""" + + def createShell(self): + return Interpreter() + + def catchQueueCleared(self, f): + try: + f.raiseException() + except error.QueueCleared: + pass + + def testIEngineCoreInterface(self): + """Does self.engine claim to implement IEngineCore?""" + self.assert_(es.IEngineCore.providedBy(self.engine)) + + def testIEngineCoreInterfaceMethods(self): + """Does self.engine have the methods and attributes in IEngineCore.""" + for m in list(es.IEngineCore): + self.assert_(hasattr(self.engine, m)) + + def testIEngineCoreDeferreds(self): + d = self.engine.execute('a=5') + d.addCallback(lambda _: self.engine.pull('a')) + d.addCallback(lambda _: self.engine.get_result()) + d.addCallback(lambda _: self.engine.keys()) + d.addCallback(lambda _: self.engine.push(dict(a=10))) + return d + + def runTestExecute(self, cmd): + self.shell = Interpreter() + actual = self.shell.execute(cmd) + def compare(computed): + actual['id'] = computed['id'] + self.assertEquals(actual, computed) + d = self.engine.execute(cmd) + d.addCallback(compare) + return d + + @parametric + def testExecute(cls): + return [(cls.runTestExecute, cmd) for cmd in validCommands] + + def runTestExecuteFailures(self, cmd, exc): + def compare(f): + self.assertRaises(exc, f.raiseException) + d = self.engine.execute(cmd) + 
d.addErrback(compare) + return d + + @parametric + def testExecuteFailures(cls): + return [(cls.runTestExecuteFailures, cmd, exc) for cmd, exc in invalidCommands] + + def runTestPushPull(self, o): + d = self.engine.push(dict(a=o)) + d.addCallback(lambda r: self.engine.pull('a')) + d.addCallback(lambda r: self.assertEquals(o,r)) + return d + + @parametric + def testPushPull(cls): + objs = [10,"hi there",1.2342354,{"p":(1,2)},None] + return [(cls.runTestPushPull, o) for o in objs] + + def testPullNameError(self): + d = self.engine.push(dict(a=5)) + d.addCallback(lambda _:self.engine.reset()) + d.addCallback(lambda _: self.engine.pull("a")) + d.addErrback(lambda f: self.assertRaises(NameError, f.raiseException)) + return d + + def testPushPullFailures(self): + d = self.engine.pull('a') + d.addErrback(lambda f: self.assertRaises(NameError, f.raiseException)) + d.addCallback(lambda _: self.engine.execute('l = lambda x: x')) + d.addCallback(lambda _: self.engine.pull('l')) + d.addErrback(lambda f: self.assertRaises(pickle.PicklingError, f.raiseException)) + d.addCallback(lambda _: self.engine.push(dict(l=lambda x: x))) + d.addErrback(lambda f: self.assertRaises(pickle.PicklingError, f.raiseException)) + return d + + def testPushPullArray(self): + try: + import numpy + except: + print 'no numpy, ', + return + a = numpy.random.random(1000) + d = self.engine.push(dict(a=a)) + d.addCallback(lambda _: self.engine.pull('a')) + d.addCallback(lambda b: b==a) + d.addCallback(lambda c: c.all()) + return self.assertDeferredEquals(d, True) + + def testPushFunction(self): + + d = self.engine.push_function(dict(f=testf)) + d.addCallback(lambda _: self.engine.execute('result = f(10)')) + d.addCallback(lambda _: self.engine.pull('result')) + d.addCallback(lambda r: self.assertEquals(r, testf(10))) + return d + + def testPullFunction(self): + d = self.engine.push_function(dict(f=testf, g=testg)) + d.addCallback(lambda _: self.engine.pull_function(('f','g'))) + d.addCallback(lambda r: 
self.assertEquals(r[0](10), testf(10))) + return d + + def testPushFunctionGlobal(self): + """Make sure that pushed functions pick up the user's namespace for globals.""" + d = self.engine.push(dict(globala=globala)) + d.addCallback(lambda _: self.engine.push_function(dict(g=testg))) + d.addCallback(lambda _: self.engine.execute('result = g(10)')) + d.addCallback(lambda _: self.engine.pull('result')) + d.addCallback(lambda r: self.assertEquals(r, testg(10))) + return d + + def testGetResultFailure(self): + d = self.engine.get_result(None) + d.addErrback(lambda f: self.assertRaises(IndexError, f.raiseException)) + d.addCallback(lambda _: self.engine.get_result(10)) + d.addErrback(lambda f: self.assertRaises(IndexError, f.raiseException)) + return d + + def runTestGetResult(self, cmd): + self.shell = Interpreter() + actual = self.shell.execute(cmd) + def compare(computed): + actual['id'] = computed['id'] + self.assertEquals(actual, computed) + d = self.engine.execute(cmd) + d.addCallback(lambda r: self.engine.get_result(r['number'])) + d.addCallback(compare) + return d + + @parametric + def testGetResult(cls): + return [(cls.runTestGetResult, cmd) for cmd in validCommands] + + def testGetResultDefault(self): + cmd = 'a=5' + shell = self.createShell() + shellResult = shell.execute(cmd) + def popit(dikt, key): + dikt.pop(key) + return dikt + d = self.engine.execute(cmd) + d.addCallback(lambda _: self.engine.get_result()) + d.addCallback(lambda r: self.assertEquals(shellResult, popit(r,'id'))) + return d + + def testKeys(self): + d = self.engine.keys() + d.addCallback(lambda s: isinstance(s, list)) + d.addCallback(lambda r: self.assertEquals(r, True)) + return d + +Parametric(IEngineCoreTestCase) + +class IEngineSerializedTestCase(object): + """Test an IEngineCore implementer.""" + + def testIEngineSerializedInterface(self): + """Does self.engine claim to implement IEngineCore?""" + self.assert_(es.IEngineSerialized.providedBy(self.engine)) + + def 
testIEngineSerializedInterfaceMethods(self): + """Does self.engine have the methods and attributes in IEngireCore.""" + for m in list(es.IEngineSerialized): + self.assert_(hasattr(self.engine, m)) + + def testIEngineSerializedDeferreds(self): + dList = [] + d = self.engine.push_serialized(dict(key=newserialized.serialize(12345))) + self.assert_(isinstance(d, defer.Deferred)) + dList.append(d) + d = self.engine.pull_serialized('key') + self.assert_(isinstance(d, defer.Deferred)) + dList.append(d) + D = defer.DeferredList(dList) + return D + + def testPushPullSerialized(self): + objs = [10,"hi there",1.2342354,{"p":(1,2)}] + d = defer.succeed(None) + for o in objs: + self.engine.push_serialized(dict(key=newserialized.serialize(o))) + value = self.engine.pull_serialized('key') + value.addCallback(lambda serial: newserialized.IUnSerialized(serial).getObject()) + d = self.assertDeferredEquals(value,o,d) + return d + + def testPullSerializedFailures(self): + d = self.engine.pull_serialized('a') + d.addErrback(lambda f: self.assertRaises(NameError, f.raiseException)) + d.addCallback(lambda _: self.engine.execute('l = lambda x: x')) + d.addCallback(lambda _: self.engine.pull_serialized('l')) + d.addErrback(lambda f: self.assertRaises(pickle.PicklingError, f.raiseException)) + return d + +Parametric(IEngineSerializedTestCase) + +class IEngineQueuedTestCase(object): + """Test an IEngineQueued implementer.""" + + def testIEngineQueuedInterface(self): + """Does self.engine claim to implement IEngineQueued?""" + self.assert_(es.IEngineQueued.providedBy(self.engine)) + + def testIEngineQueuedInterfaceMethods(self): + """Does self.engine have the methods and attributes in IEngireQueued.""" + for m in list(es.IEngineQueued): + self.assert_(hasattr(self.engine, m)) + + def testIEngineQueuedDeferreds(self): + dList = [] + d = self.engine.clear_queue() + self.assert_(isinstance(d, defer.Deferred)) + dList.append(d) + d = self.engine.queue_status() + self.assert_(isinstance(d, 
defer.Deferred)) + dList.append(d) + D = defer.DeferredList(dList) + return D + + def testClearQueue(self): + result = self.engine.clear_queue() + d1 = self.assertDeferredEquals(result, None) + d1.addCallback(lambda _: self.engine.queue_status()) + d2 = self.assertDeferredEquals(d1, {'queue':[], 'pending':'None'}) + return d2 + + def testQueueStatus(self): + result = self.engine.queue_status() + result.addCallback(lambda r: 'queue' in r and 'pending' in r) + d = self.assertDeferredEquals(result, True) + return d + +Parametric(IEngineQueuedTestCase) + +class IEnginePropertiesTestCase(object): + """Test an IEngineProperties implementor.""" + + def testIEnginePropertiesInterface(self): + """Does self.engine claim to implement IEngineProperties?""" + self.assert_(es.IEngineProperties.providedBy(self.engine)) + + def testIEnginePropertiesInterfaceMethods(self): + """Does self.engine have the methods and attributes in IEngireProperties.""" + for m in list(es.IEngineProperties): + self.assert_(hasattr(self.engine, m)) + + def testGetSetProperties(self): + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d = self.engine.set_properties(dikt) + d.addCallback(lambda r: self.engine.get_properties()) + d = self.assertDeferredEquals(d, dikt) + d.addCallback(lambda r: self.engine.get_properties(('c',))) + d = self.assertDeferredEquals(d, {'c': dikt['c']}) + d.addCallback(lambda r: self.engine.set_properties(dict(c=False))) + d.addCallback(lambda r: self.engine.get_properties(('c', 'd'))) + d = self.assertDeferredEquals(d, dict(c=False, d=None)) + return d + + def testClearProperties(self): + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d = self.engine.set_properties(dikt) + d.addCallback(lambda r: self.engine.clear_properties()) + d.addCallback(lambda r: self.engine.get_properties()) + d = self.assertDeferredEquals(d, {}) + return d + + def testDelHasProperties(self): + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d = 
self.engine.set_properties(dikt) + d.addCallback(lambda r: self.engine.del_properties(('b','e'))) + d.addCallback(lambda r: self.engine.has_properties(('a','b','c','d','e'))) + d = self.assertDeferredEquals(d, [True, False, True, True, False]) + return d + + def testStrictDict(self): + s = """from IPython.kernel.engineservice import get_engine +p = get_engine(%s).properties"""%self.engine.id + d = self.engine.execute(s) + d.addCallback(lambda r: self.engine.execute("p['a'] = lambda _:None")) + d = self.assertDeferredRaises(d, error.InvalidProperty) + d.addCallback(lambda r: self.engine.execute("p['a'] = range(5)")) + d.addCallback(lambda r: self.engine.execute("p['a'].append(5)")) + d.addCallback(lambda r: self.engine.get_properties('a')) + d = self.assertDeferredEquals(d, dict(a=range(5))) + return d + +Parametric(IEnginePropertiesTestCase) diff --git a/IPython/kernel/tests/multienginetest.py b/IPython/kernel/tests/multienginetest.py new file mode 100644 index 0000000..30a2df7 --- /dev/null +++ b/IPython/kernel/tests/multienginetest.py @@ -0,0 +1,838 @@ +# encoding: utf-8 + +"""""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from twisted.internet import defer + +from IPython.kernel import engineservice as es +from IPython.kernel import multiengine as me +from IPython.kernel import newserialized +from IPython.kernel.error import NotDefined +from IPython.testing import util +from IPython.testing.parametric import parametric, Parametric +from IPython.kernel import newserialized +from IPython.kernel.util import printer +from IPython.kernel.error import (InvalidEngineID, + NoEnginesRegistered, + CompositeError, + InvalidDeferredID) +from IPython.kernel.tests.engineservicetest import validCommands, invalidCommands +from IPython.kernel.core.interpreter import Interpreter + + +#------------------------------------------------------------------------------- +# Base classes and utilities +#------------------------------------------------------------------------------- + +class IMultiEngineBaseTestCase(object): + """Basic utilities for working with multiengine tests. 
+ + Some subclass should define: + + * self.multiengine + * self.engines to keep track of engines for clean up""" + + def createShell(self): + return Interpreter() + + def addEngine(self, n=1): + for i in range(n): + e = es.EngineService() + e.startService() + regDict = self.controller.register_engine(es.QueuedEngine(e), None) + e.id = regDict['id'] + self.engines.append(e) + + +def testf(x): + return 2.0*x + + +globala = 99 + + +def testg(x): + return globala*x + + +def isdid(did): + if not isinstance(did, str): + return False + if not len(did)==40: + return False + return True + + +def _raise_it(f): + try: + f.raiseException() + except CompositeError, e: + e.raise_exception() + +#------------------------------------------------------------------------------- +# IMultiEngineTestCase +#------------------------------------------------------------------------------- + +class IMultiEngineTestCase(IMultiEngineBaseTestCase): + """A test for any object that implements IEngineMultiplexer. + + self.multiengine must be defined and implement IEngineMultiplexer. 
+ """ + + def testIMultiEngineInterface(self): + """Does self.engine claim to implement IEngineCore?""" + self.assert_(me.IEngineMultiplexer.providedBy(self.multiengine)) + self.assert_(me.IMultiEngine.providedBy(self.multiengine)) + + def testIEngineMultiplexerInterfaceMethods(self): + """Does self.engine have the methods and attributes in IEngineCore.""" + for m in list(me.IEngineMultiplexer): + self.assert_(hasattr(self.multiengine, m)) + + def testIEngineMultiplexerDeferreds(self): + self.addEngine(1) + d= self.multiengine.execute('a=5', targets=0) + d.addCallback(lambda _: self.multiengine.push(dict(a=5),targets=0)) + d.addCallback(lambda _: self.multiengine.push(dict(a=5, b='asdf', c=[1,2,3]),targets=0)) + d.addCallback(lambda _: self.multiengine.pull(('a','b','c'),targets=0)) + d.addCallback(lambda _: self.multiengine.get_result(targets=0)) + d.addCallback(lambda _: self.multiengine.reset(targets=0)) + d.addCallback(lambda _: self.multiengine.keys(targets=0)) + d.addCallback(lambda _: self.multiengine.push_serialized(dict(a=newserialized.serialize(10)),targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized('a',targets=0)) + d.addCallback(lambda _: self.multiengine.clear_queue(targets=0)) + d.addCallback(lambda _: self.multiengine.queue_status(targets=0)) + return d + + def testInvalidEngineID(self): + self.addEngine(1) + badID = 100 + d = self.multiengine.execute('a=5', targets=badID) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.push(dict(a=5), targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.pull('a', targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.reset(targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: 
self.multiengine.keys(targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.push_serialized(dict(a=newserialized.serialize(10)), targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.queue_status(targets=badID)) + d.addErrback(lambda f: self.assertRaises(InvalidEngineID, f.raiseException)) + return d + + def testNoEnginesRegistered(self): + badID = 'all' + d= self.multiengine.execute('a=5', targets=badID) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.push(dict(a=5), targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.pull('a', targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.get_result(targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.reset(targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.keys(targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.push_serialized(dict(a=newserialized.serialize(10)), targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + d.addCallback(lambda _: 
self.multiengine.queue_status(targets=badID)) + d.addErrback(lambda f: self.assertRaises(NoEnginesRegistered, f.raiseException)) + return d + + def runExecuteAll(self, d, cmd, shell): + actual = shell.execute(cmd) + d.addCallback(lambda _: self.multiengine.execute(cmd)) + def compare(result): + for r in result: + actual['id'] = r['id'] + self.assertEquals(r, actual) + d.addCallback(compare) + + def testExecuteAll(self): + self.addEngine(4) + d= defer.Deferred() + shell = Interpreter() + for cmd in validCommands: + self.runExecuteAll(d, cmd, shell) + d.callback(None) + return d + + # The following two methods show how to do parametrized + # tests. This is really slick! Same is used above. + def runExecuteFailures(self, cmd, exc): + self.addEngine(4) + d= self.multiengine.execute(cmd) + d.addErrback(lambda f: self.assertRaises(exc, _raise_it, f)) + return d + + @parametric + def testExecuteFailures(cls): + return [(cls.runExecuteFailures,cmd,exc) for + cmd,exc in invalidCommands] + + def testPushPull(self): + self.addEngine(1) + objs = [10,"hi there",1.2342354,{"p":(1,2)}] + d= self.multiengine.push(dict(key=objs[0]), targets=0) + d.addCallback(lambda _: self.multiengine.pull('key', targets=0)) + d.addCallback(lambda r: self.assertEquals(r, [objs[0]])) + d.addCallback(lambda _: self.multiengine.push(dict(key=objs[1]), targets=0)) + d.addCallback(lambda _: self.multiengine.pull('key', targets=0)) + d.addCallback(lambda r: self.assertEquals(r, [objs[1]])) + d.addCallback(lambda _: self.multiengine.push(dict(key=objs[2]), targets=0)) + d.addCallback(lambda _: self.multiengine.pull('key', targets=0)) + d.addCallback(lambda r: self.assertEquals(r, [objs[2]])) + d.addCallback(lambda _: self.multiengine.push(dict(key=objs[3]), targets=0)) + d.addCallback(lambda _: self.multiengine.pull('key', targets=0)) + d.addCallback(lambda r: self.assertEquals(r, [objs[3]])) + d.addCallback(lambda _: self.multiengine.reset(targets=0)) + d.addCallback(lambda _: self.multiengine.pull('a', 
targets=0)) + d.addErrback(lambda f: self.assertRaises(NameError, _raise_it, f)) + d.addCallback(lambda _: self.multiengine.push(dict(a=10,b=20))) + d.addCallback(lambda _: self.multiengine.pull(('a','b'))) + d.addCallback(lambda r: self.assertEquals(r, [[10,20]])) + return d + + def testPushPullAll(self): + self.addEngine(4) + d= self.multiengine.push(dict(a=10)) + d.addCallback(lambda _: self.multiengine.pull('a')) + d.addCallback(lambda r: self.assert_(r==[10,10,10,10])) + d.addCallback(lambda _: self.multiengine.push(dict(a=10, b=20))) + d.addCallback(lambda _: self.multiengine.pull(('a','b'))) + d.addCallback(lambda r: self.assert_(r==4*[[10,20]])) + d.addCallback(lambda _: self.multiengine.push(dict(a=10, b=20), targets=0)) + d.addCallback(lambda _: self.multiengine.pull(('a','b'), targets=0)) + d.addCallback(lambda r: self.assert_(r==[[10,20]])) + d.addCallback(lambda _: self.multiengine.push(dict(a=None, b=None), targets=0)) + d.addCallback(lambda _: self.multiengine.pull(('a','b'), targets=0)) + d.addCallback(lambda r: self.assert_(r==[[None,None]])) + return d + + def testPushPullSerialized(self): + self.addEngine(1) + objs = [10,"hi there",1.2342354,{"p":(1,2)}] + d= self.multiengine.push_serialized(dict(key=newserialized.serialize(objs[0])), targets=0) + d.addCallback(lambda _: self.multiengine.pull_serialized('key', targets=0)) + d.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d.addCallback(lambda r: self.assertEquals(r, objs[0])) + d.addCallback(lambda _: self.multiengine.push_serialized(dict(key=newserialized.serialize(objs[1])), targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized('key', targets=0)) + d.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d.addCallback(lambda r: self.assertEquals(r, objs[1])) + d.addCallback(lambda _: self.multiengine.push_serialized(dict(key=newserialized.serialize(objs[2])), targets=0)) + d.addCallback(lambda _: 
self.multiengine.pull_serialized('key', targets=0)) + d.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d.addCallback(lambda r: self.assertEquals(r, objs[2])) + d.addCallback(lambda _: self.multiengine.push_serialized(dict(key=newserialized.serialize(objs[3])), targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized('key', targets=0)) + d.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d.addCallback(lambda r: self.assertEquals(r, objs[3])) + d.addCallback(lambda _: self.multiengine.push(dict(a=10,b=range(5)), targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized(('a','b'), targets=0)) + d.addCallback(lambda serial: [newserialized.IUnSerialized(s).getObject() for s in serial[0]]) + d.addCallback(lambda r: self.assertEquals(r, [10, range(5)])) + d.addCallback(lambda _: self.multiengine.reset(targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=0)) + d.addErrback(lambda f: self.assertRaises(NameError, _raise_it, f)) + return d + + objs = [10,"hi there",1.2342354,{"p":(1,2)}] + d= defer.succeed(None) + for o in objs: + self.multiengine.push_serialized(0, key=newserialized.serialize(o)) + value = self.multiengine.pull_serialized(0, 'key') + value.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d = self.assertDeferredEquals(value,o,d) + return d + + def runGetResultAll(self, d, cmd, shell): + actual = shell.execute(cmd) + d.addCallback(lambda _: self.multiengine.execute(cmd)) + d.addCallback(lambda _: self.multiengine.get_result()) + def compare(result): + for r in result: + actual['id'] = r['id'] + self.assertEquals(r, actual) + d.addCallback(compare) + + def testGetResultAll(self): + self.addEngine(4) + d= defer.Deferred() + shell = Interpreter() + for cmd in validCommands: + self.runGetResultAll(d, cmd, shell) + d.callback(None) + return d + + def testGetResultDefault(self): + self.addEngine(1) + target = 0 + 
cmd = 'a=5' + shell = self.createShell() + shellResult = shell.execute(cmd) + def popit(dikt, key): + dikt.pop(key) + return dikt + d= self.multiengine.execute(cmd, targets=target) + d.addCallback(lambda _: self.multiengine.get_result(targets=target)) + d.addCallback(lambda r: self.assertEquals(shellResult, popit(r[0],'id'))) + return d + + def testGetResultFailure(self): + self.addEngine(1) + d= self.multiengine.get_result(None, targets=0) + d.addErrback(lambda f: self.assertRaises(IndexError, _raise_it, f)) + d.addCallback(lambda _: self.multiengine.get_result(10, targets=0)) + d.addErrback(lambda f: self.assertRaises(IndexError, _raise_it, f)) + return d + + def testPushFunction(self): + self.addEngine(1) + d= self.multiengine.push_function(dict(f=testf), targets=0) + d.addCallback(lambda _: self.multiengine.execute('result = f(10)', targets=0)) + d.addCallback(lambda _: self.multiengine.pull('result', targets=0)) + d.addCallback(lambda r: self.assertEquals(r[0], testf(10))) + d.addCallback(lambda _: self.multiengine.push(dict(globala=globala), targets=0)) + d.addCallback(lambda _: self.multiengine.push_function(dict(g=testg), targets=0)) + d.addCallback(lambda _: self.multiengine.execute('result = g(10)', targets=0)) + d.addCallback(lambda _: self.multiengine.pull('result', targets=0)) + d.addCallback(lambda r: self.assertEquals(r[0], testg(10))) + return d + + def testPullFunction(self): + self.addEngine(1) + d= self.multiengine.push(dict(a=globala), targets=0) + d.addCallback(lambda _: self.multiengine.push_function(dict(f=testf), targets=0)) + d.addCallback(lambda _: self.multiengine.pull_function('f', targets=0)) + d.addCallback(lambda r: self.assertEquals(r[0](10), testf(10))) + d.addCallback(lambda _: self.multiengine.execute("def g(x): return x*x", targets=0)) + d.addCallback(lambda _: self.multiengine.pull_function(('f','g'),targets=0)) + d.addCallback(lambda r: self.assertEquals((r[0][0](10),r[0][1](10)), (testf(10), 100))) + return d + + def 
testPushFunctionAll(self): + self.addEngine(4) + d= self.multiengine.push_function(dict(f=testf)) + d.addCallback(lambda _: self.multiengine.execute('result = f(10)')) + d.addCallback(lambda _: self.multiengine.pull('result')) + d.addCallback(lambda r: self.assertEquals(r, 4*[testf(10)])) + d.addCallback(lambda _: self.multiengine.push(dict(globala=globala))) + d.addCallback(lambda _: self.multiengine.push_function(dict(testg=testg))) + d.addCallback(lambda _: self.multiengine.execute('result = testg(10)')) + d.addCallback(lambda _: self.multiengine.pull('result')) + d.addCallback(lambda r: self.assertEquals(r, 4*[testg(10)])) + return d + + def testPullFunctionAll(self): + self.addEngine(4) + d= self.multiengine.push_function(dict(f=testf)) + d.addCallback(lambda _: self.multiengine.pull_function('f')) + d.addCallback(lambda r: self.assertEquals([func(10) for func in r], 4*[testf(10)])) + return d + + def testGetIDs(self): + self.addEngine(1) + d= self.multiengine.get_ids() + d.addCallback(lambda r: self.assertEquals(r, [0])) + d.addCallback(lambda _: self.addEngine(3)) + d.addCallback(lambda _: self.multiengine.get_ids()) + d.addCallback(lambda r: self.assertEquals(r, [0,1,2,3])) + return d + + def testClearQueue(self): + self.addEngine(4) + d= self.multiengine.clear_queue() + d.addCallback(lambda r: self.assertEquals(r,4*[None])) + return d + + def testQueueStatus(self): + self.addEngine(4) + d= self.multiengine.queue_status(targets=0) + d.addCallback(lambda r: self.assert_(isinstance(r[0],tuple))) + return d + + def testGetSetProperties(self): + self.addEngine(4) + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d= self.multiengine.set_properties(dikt) + d.addCallback(lambda r: self.multiengine.get_properties()) + d.addCallback(lambda r: self.assertEquals(r, 4*[dikt])) + d.addCallback(lambda r: self.multiengine.get_properties(('c',))) + d.addCallback(lambda r: self.assertEquals(r, 4*[{'c': dikt['c']}])) + d.addCallback(lambda r: 
self.multiengine.set_properties(dict(c=False))) + d.addCallback(lambda r: self.multiengine.get_properties(('c', 'd'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[dict(c=False, d=None)])) + return d + + def testClearProperties(self): + self.addEngine(4) + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d= self.multiengine.set_properties(dikt) + d.addCallback(lambda r: self.multiengine.clear_properties()) + d.addCallback(lambda r: self.multiengine.get_properties()) + d.addCallback(lambda r: self.assertEquals(r, 4*[{}])) + return d + + def testDelHasProperties(self): + self.addEngine(4) + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d= self.multiengine.set_properties(dikt) + d.addCallback(lambda r: self.multiengine.del_properties(('b','e'))) + d.addCallback(lambda r: self.multiengine.has_properties(('a','b','c','d','e'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[[True, False, True, True, False]])) + return d + +Parametric(IMultiEngineTestCase) + +#------------------------------------------------------------------------------- +# ISynchronousMultiEngineTestCase +#------------------------------------------------------------------------------- + +class ISynchronousMultiEngineTestCase(IMultiEngineBaseTestCase): + + def testISynchronousMultiEngineInterface(self): + """Does self.engine claim to implement IEngineCore?""" + self.assert_(me.ISynchronousEngineMultiplexer.providedBy(self.multiengine)) + self.assert_(me.ISynchronousMultiEngine.providedBy(self.multiengine)) + + def testExecute(self): + self.addEngine(4) + execute = self.multiengine.execute + d= execute('a=5', targets=0, block=True) + d.addCallback(lambda r: self.assert_(len(r)==1)) + d.addCallback(lambda _: execute('b=10')) + d.addCallback(lambda r: self.assert_(len(r)==4)) + d.addCallback(lambda _: execute('c=30', block=False)) + d.addCallback(lambda did: self.assert_(isdid(did))) + d.addCallback(lambda _: execute('d=[0,1,2]', block=False)) + d.addCallback(lambda did: 
self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assert_(len(r)==4)) + return d + + def testPushPull(self): + data = dict(a=10, b=1.05, c=range(10), d={'e':(1,2),'f':'hi'}) + self.addEngine(4) + push = self.multiengine.push + pull = self.multiengine.pull + d= push({'data':data}, targets=0) + d.addCallback(lambda r: pull('data', targets=0)) + d.addCallback(lambda r: self.assertEqual(r,[data])) + d.addCallback(lambda _: push({'data':data})) + d.addCallback(lambda r: pull('data')) + d.addCallback(lambda r: self.assertEqual(r,4*[data])) + d.addCallback(lambda _: push({'data':data}, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda _: pull('data', block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEqual(r,4*[data])) + d.addCallback(lambda _: push(dict(a=10,b=20))) + d.addCallback(lambda _: pull(('a','b'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[[10,20]])) + return d + + def testPushPullFunction(self): + self.addEngine(4) + pushf = self.multiengine.push_function + pullf = self.multiengine.pull_function + push = self.multiengine.push + pull = self.multiengine.pull + execute = self.multiengine.execute + d= pushf({'testf':testf}, targets=0) + d.addCallback(lambda r: pullf('testf', targets=0)) + d.addCallback(lambda r: self.assertEqual(r[0](1.0), testf(1.0))) + d.addCallback(lambda _: execute('r = testf(10)', targets=0)) + d.addCallback(lambda _: pull('r', targets=0)) + d.addCallback(lambda r: self.assertEquals(r[0], testf(10))) + d.addCallback(lambda _: pushf({'testf':testf}, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda _: pullf('testf', block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEqual(r[0](1.0), testf(1.0))) + d.addCallback(lambda _: 
execute("def g(x): return x*x", targets=0)) + d.addCallback(lambda _: pullf(('testf','g'),targets=0)) + d.addCallback(lambda r: self.assertEquals((r[0][0](10),r[0][1](10)), (testf(10), 100))) + return d + + def testGetResult(self): + shell = Interpreter() + result1 = shell.execute('a=10') + result1['id'] = 0 + result2 = shell.execute('b=20') + result2['id'] = 0 + execute= self.multiengine.execute + get_result = self.multiengine.get_result + self.addEngine(1) + d= execute('a=10') + d.addCallback(lambda _: get_result()) + d.addCallback(lambda r: self.assertEquals(r[0], result1)) + d.addCallback(lambda _: execute('b=20')) + d.addCallback(lambda _: get_result(1)) + d.addCallback(lambda r: self.assertEquals(r[0], result1)) + d.addCallback(lambda _: get_result(2, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r[0], result2)) + return d + + def testResetAndKeys(self): + self.addEngine(1) + + #Blocking mode + d= self.multiengine.push(dict(a=10, b=20, c=range(10)), targets=0) + d.addCallback(lambda _: self.multiengine.keys(targets=0)) + def keys_found(keys): + self.assert_('a' in keys[0]) + self.assert_('b' in keys[0]) + self.assert_('b' in keys[0]) + d.addCallback(keys_found) + d.addCallback(lambda _: self.multiengine.reset(targets=0)) + d.addCallback(lambda _: self.multiengine.keys(targets=0)) + def keys_not_found(keys): + self.assert_('a' not in keys[0]) + self.assert_('b' not in keys[0]) + self.assert_('b' not in keys[0]) + d.addCallback(keys_not_found) + + #Non-blocking mode + d.addCallback(lambda _: self.multiengine.push(dict(a=10, b=20, c=range(10)), targets=0)) + d.addCallback(lambda _: self.multiengine.keys(targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + def keys_found(keys): + self.assert_('a' in keys[0]) + self.assert_('b' in keys[0]) + self.assert_('b' in keys[0]) + d.addCallback(keys_found) + d.addCallback(lambda _: 
self.multiengine.reset(targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda _: self.multiengine.keys(targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + def keys_not_found(keys): + self.assert_('a' not in keys[0]) + self.assert_('b' not in keys[0]) + self.assert_('b' not in keys[0]) + d.addCallback(keys_not_found) + + return d + + def testPushPullSerialized(self): + self.addEngine(1) + dikt = dict(a=10,b='hi there',c=1.2345,d={'p':(1,2)}) + sdikt = {} + for k,v in dikt.iteritems(): + sdikt[k] = newserialized.serialize(v) + d= self.multiengine.push_serialized(dict(a=sdikt['a']), targets=0) + d.addCallback(lambda _: self.multiengine.pull('a',targets=0)) + d.addCallback(lambda r: self.assertEquals(r[0], dikt['a'])) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=0)) + d.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d.addCallback(lambda r: self.assertEquals(r, dikt['a'])) + d.addCallback(lambda _: self.multiengine.push_serialized(sdikt, targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized(sdikt.keys(), targets=0)) + d.addCallback(lambda serial: [newserialized.IUnSerialized(s).getObject() for s in serial[0]]) + d.addCallback(lambda r: self.assertEquals(r, dikt.values())) + d.addCallback(lambda _: self.multiengine.reset(targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=0)) + d.addErrback(lambda f: self.assertRaises(NameError, _raise_it, f)) + + #Non-blocking mode + d.addCallback(lambda r: self.multiengine.push_serialized(dict(a=sdikt['a']), targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda _: self.multiengine.pull('a',targets=0)) + d.addCallback(lambda r: self.assertEquals(r[0], dikt['a'])) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=0, 
block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda serial: newserialized.IUnSerialized(serial[0]).getObject()) + d.addCallback(lambda r: self.assertEquals(r, dikt['a'])) + d.addCallback(lambda _: self.multiengine.push_serialized(sdikt, targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda _: self.multiengine.pull_serialized(sdikt.keys(), targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda serial: [newserialized.IUnSerialized(s).getObject() for s in serial[0]]) + d.addCallback(lambda r: self.assertEquals(r, dikt.values())) + d.addCallback(lambda _: self.multiengine.reset(targets=0)) + d.addCallback(lambda _: self.multiengine.pull_serialized('a', targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addErrback(lambda f: self.assertRaises(NameError, _raise_it, f)) + return d + + def testClearQueue(self): + self.addEngine(4) + d= self.multiengine.clear_queue() + d.addCallback(lambda r: self.multiengine.clear_queue(block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r,4*[None])) + return d + + def testQueueStatus(self): + self.addEngine(4) + d= self.multiengine.queue_status(targets=0) + d.addCallback(lambda r: self.assert_(isinstance(r[0],tuple))) + d.addCallback(lambda r: self.multiengine.queue_status(targets=0, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assert_(isinstance(r[0],tuple))) + return d + + def testGetIDs(self): + self.addEngine(1) + d= self.multiengine.get_ids() + d.addCallback(lambda r: self.assertEquals(r, [0])) + d.addCallback(lambda _: self.addEngine(3)) + d.addCallback(lambda _: self.multiengine.get_ids()) + d.addCallback(lambda r: 
self.assertEquals(r, [0,1,2,3])) + return d + + def testGetSetProperties(self): + self.addEngine(4) + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d= self.multiengine.set_properties(dikt) + d.addCallback(lambda r: self.multiengine.get_properties()) + d.addCallback(lambda r: self.assertEquals(r, 4*[dikt])) + d.addCallback(lambda r: self.multiengine.get_properties(('c',))) + d.addCallback(lambda r: self.assertEquals(r, 4*[{'c': dikt['c']}])) + d.addCallback(lambda r: self.multiengine.set_properties(dict(c=False))) + d.addCallback(lambda r: self.multiengine.get_properties(('c', 'd'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[dict(c=False, d=None)])) + + #Non-blocking + d.addCallback(lambda r: self.multiengine.set_properties(dikt, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.get_properties(block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r, 4*[dikt])) + d.addCallback(lambda r: self.multiengine.get_properties(('c',), block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r, 4*[{'c': dikt['c']}])) + d.addCallback(lambda r: self.multiengine.set_properties(dict(c=False), block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.get_properties(('c', 'd'), block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r, 4*[dict(c=False, d=None)])) + return d + + def testClearProperties(self): + self.addEngine(4) + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d= self.multiengine.set_properties(dikt) + d.addCallback(lambda r: self.multiengine.clear_properties()) + d.addCallback(lambda r: self.multiengine.get_properties()) + 
d.addCallback(lambda r: self.assertEquals(r, 4*[{}])) + + #Non-blocking + d.addCallback(lambda r: self.multiengine.set_properties(dikt, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.clear_properties(block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.get_properties(block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r, 4*[{}])) + return d + + def testDelHasProperties(self): + self.addEngine(4) + dikt = dict(a=5, b='asdf', c=True, d=None, e=range(5)) + d= self.multiengine.set_properties(dikt) + d.addCallback(lambda r: self.multiengine.del_properties(('b','e'))) + d.addCallback(lambda r: self.multiengine.has_properties(('a','b','c','d','e'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[[True, False, True, True, False]])) + + #Non-blocking + d.addCallback(lambda r: self.multiengine.set_properties(dikt, block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.del_properties(('b','e'), block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.has_properties(('a','b','c','d','e'), block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r, 4*[[True, False, True, True, False]])) + return d + + def test_clear_pending_deferreds(self): + self.addEngine(4) + did_list = [] + d= self.multiengine.execute('a=10',block=False) + d.addCallback(lambda did: did_list.append(did)) + d.addCallback(lambda _: self.multiengine.push(dict(b=10),block=False)) + d.addCallback(lambda did: did_list.append(did)) + d.addCallback(lambda _: self.multiengine.pull(('a','b'),block=False)) + d.addCallback(lambda did: 
did_list.append(did)) + d.addCallback(lambda _: self.multiengine.clear_pending_deferreds()) + d.addCallback(lambda _: self.multiengine.get_pending_deferred(did_list[0],True)) + d.addErrback(lambda f: self.assertRaises(InvalidDeferredID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.get_pending_deferred(did_list[1],True)) + d.addErrback(lambda f: self.assertRaises(InvalidDeferredID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.get_pending_deferred(did_list[2],True)) + d.addErrback(lambda f: self.assertRaises(InvalidDeferredID, f.raiseException)) + return d + +#------------------------------------------------------------------------------- +# Coordinator test cases +#------------------------------------------------------------------------------- + +class IMultiEngineCoordinatorTestCase(object): + + def testScatterGather(self): + self.addEngine(4) + d= self.multiengine.scatter('a', range(16)) + d.addCallback(lambda r: self.multiengine.gather('a')) + d.addCallback(lambda r: self.assertEquals(r, range(16))) + d.addCallback(lambda _: self.multiengine.gather('asdf')) + d.addErrback(lambda f: self.assertRaises(NameError, _raise_it, f)) + return d + + def testScatterGatherNumpy(self): + try: + import numpy + from numpy.testing.utils import assert_array_equal, assert_array_almost_equal + except: + return + else: + self.addEngine(4) + a = numpy.arange(16) + d = self.multiengine.scatter('a', a) + d.addCallback(lambda r: self.multiengine.gather('a')) + d.addCallback(lambda r: assert_array_equal(r, a)) + return d + + def testMap(self): + self.addEngine(4) + def f(x): + return x**2 + data = range(16) + d= self.multiengine.map(f, data) + d.addCallback(lambda r: self.assertEquals(r,[f(x) for x in data])) + return d + + +class ISynchronousMultiEngineCoordinatorTestCase(IMultiEngineCoordinatorTestCase): + + def testScatterGatherNonblocking(self): + self.addEngine(4) + d= self.multiengine.scatter('a', range(16), block=False) + d.addCallback(lambda did: 
self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.gather('a', block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r, range(16))) + return d + + def testScatterGatherNumpyNonblocking(self): + try: + import numpy + from numpy.testing.utils import assert_array_equal, assert_array_almost_equal + except: + return + else: + self.addEngine(4) + a = numpy.arange(16) + d = self.multiengine.scatter('a', a, block=False) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.gather('a', block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: assert_array_equal(r, a)) + return d + + def testMapNonblocking(self): + self.addEngine(4) + def f(x): + return x**2 + data = range(16) + d= self.multiengine.map(f, data, block=False) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assertEquals(r,[f(x) for x in data])) + return d + + def test_clear_pending_deferreds(self): + self.addEngine(4) + did_list = [] + d= self.multiengine.scatter('a',range(16),block=False) + d.addCallback(lambda did: did_list.append(did)) + d.addCallback(lambda _: self.multiengine.gather('a',block=False)) + d.addCallback(lambda did: did_list.append(did)) + d.addCallback(lambda _: self.multiengine.map(lambda x: x, range(16),block=False)) + d.addCallback(lambda did: did_list.append(did)) + d.addCallback(lambda _: self.multiengine.clear_pending_deferreds()) + d.addCallback(lambda _: self.multiengine.get_pending_deferred(did_list[0],True)) + d.addErrback(lambda f: self.assertRaises(InvalidDeferredID, f.raiseException)) + d.addCallback(lambda _: self.multiengine.get_pending_deferred(did_list[1],True)) + d.addErrback(lambda f: self.assertRaises(InvalidDeferredID, f.raiseException)) + d.addCallback(lambda _: 
self.multiengine.get_pending_deferred(did_list[2],True)) + d.addErrback(lambda f: self.assertRaises(InvalidDeferredID, f.raiseException)) + return d + +#------------------------------------------------------------------------------- +# Extras test cases +#------------------------------------------------------------------------------- + +class IMultiEngineExtrasTestCase(object): + + def testZipPull(self): + self.addEngine(4) + d= self.multiengine.push(dict(a=10,b=20)) + d.addCallback(lambda r: self.multiengine.zip_pull(('a','b'))) + d.addCallback(lambda r: self.assert_(r, [4*[10],4*[20]])) + return d + + def testRun(self): + self.addEngine(4) + import tempfile + fname = tempfile.mktemp('foo.py') + f= open(fname, 'w') + f.write('a = 10\nb=30') + f.close() + d= self.multiengine.run(fname) + d.addCallback(lambda r: self.multiengine.pull(('a','b'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[[10,30]])) + return d + + +class ISynchronousMultiEngineExtrasTestCase(IMultiEngineExtrasTestCase): + + def testZipPullNonblocking(self): + self.addEngine(4) + d= self.multiengine.push(dict(a=10,b=20)) + d.addCallback(lambda r: self.multiengine.zip_pull(('a','b'), block=False)) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.assert_(r, [4*[10],4*[20]])) + return d + + def testRunNonblocking(self): + self.addEngine(4) + import tempfile + fname = tempfile.mktemp('foo.py') + f= open(fname, 'w') + f.write('a = 10\nb=30') + f.close() + d= self.multiengine.run(fname, block=False) + d.addCallback(lambda did: self.multiengine.get_pending_deferred(did, True)) + d.addCallback(lambda r: self.multiengine.pull(('a','b'))) + d.addCallback(lambda r: self.assertEquals(r, 4*[[10,30]])) + return d + + +#------------------------------------------------------------------------------- +# IFullSynchronousMultiEngineTestCase +#------------------------------------------------------------------------------- + +class 
IFullSynchronousMultiEngineTestCase(ISynchronousMultiEngineTestCase, + ISynchronousMultiEngineCoordinatorTestCase, + ISynchronousMultiEngineExtrasTestCase): + pass diff --git a/IPython/kernel/tests/tasktest.py b/IPython/kernel/tests/tasktest.py new file mode 100644 index 0000000..10a0a35 --- /dev/null +++ b/IPython/kernel/tests/tasktest.py @@ -0,0 +1,158 @@ +#!/usr/bin/env python +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import time + +from IPython.kernel import task, engineservice as es +from IPython.kernel.util import printer +from IPython.kernel import error + +#------------------------------------------------------------------------------- +# Tests +#------------------------------------------------------------------------------- + +def _raise_it(f): + try: + f.raiseException() + except CompositeError, e: + e.raise_exception() + +class TaskTestBase(object): + + def addEngine(self, n=1): + for i in range(n): + e = es.EngineService() + e.startService() + regDict = self.controller.register_engine(es.QueuedEngine(e), None) + e.id = regDict['id'] + self.engines.append(e) + + +class ITaskControllerTestCase(TaskTestBase): + + def testTaskIDs(self): + self.addEngine(1) + d = self.tc.run(task.Task('a=5')) + d.addCallback(lambda r: self.assertEquals(r, 0)) + d.addCallback(lambda r: self.tc.run(task.Task('a=5'))) + d.addCallback(lambda r: self.assertEquals(r, 1)) + d.addCallback(lambda r: self.tc.run(task.Task('a=5'))) + 
d.addCallback(lambda r: self.assertEquals(r, 2)) + d.addCallback(lambda r: self.tc.run(task.Task('a=5'))) + d.addCallback(lambda r: self.assertEquals(r, 3)) + return d + + def testAbort(self): + """Cannot do a proper abort test, because blocking execution prevents + abort from being called before task completes""" + self.addEngine(1) + t = task.Task('a=5') + d = self.tc.abort(0) + d.addErrback(lambda f: self.assertRaises(IndexError, f.raiseException)) + d.addCallback(lambda _:self.tc.run(t)) + d.addCallback(self.tc.abort) + d.addErrback(lambda f: self.assertRaises(IndexError, f.raiseException)) + return d + + def testAbortType(self): + self.addEngine(1) + d = self.tc.abort('asdfadsf') + d.addErrback(lambda f: self.assertRaises(TypeError, f.raiseException)) + return d + + def testClears(self): + self.addEngine(1) + t = task.Task('a=1', clear_before=True, pull='b', clear_after=True) + d = self.multiengine.execute('b=1', targets=0) + d.addCallback(lambda _: self.tc.run(t)) + d.addCallback(lambda tid: self.tc.get_task_result(tid,block=True)) + d.addCallback(lambda tr: tr.failure) + d.addErrback(lambda f: self.assertRaises(NameError, f.raiseException)) + d.addCallback(lambda _:self.multiengine.pull('a', targets=0)) + d.addErrback(lambda f: self.assertRaises(NameError, _raise_it, f)) + return d + + def testSimpleRetries(self): + self.addEngine(1) + t = task.Task("i += 1\nassert i == 16", pull='i',retries=10) + t2 = task.Task("i += 1\nassert i == 16", pull='i',retries=10) + d = self.multiengine.execute('i=0', targets=0) + d.addCallback(lambda r: self.tc.run(t)) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: tr.ns.i) + d.addErrback(lambda f: self.assertRaises(AssertionError, f.raiseException)) + + d.addCallback(lambda r: self.tc.run(t2)) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: tr.ns.i) + d.addCallback(lambda r: self.assertEquals(r, 16)) + return d + + def testRecoveryTasks(self): + self.addEngine(1) + 
t = task.Task("i=16", pull='i') + t2 = task.Task("raise Exception", recovery_task=t, retries = 2) + + d = self.tc.run(t2) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: tr.ns.i) + d.addCallback(lambda r: self.assertEquals(r, 16)) + return d + + # def testInfiniteRecoveryLoop(self): + # self.addEngine(1) + # t = task.Task("raise Exception", retries = 5) + # t2 = task.Task("assert True", retries = 2, recovery_task = t) + # t.recovery_task = t2 + # + # d = self.tc.run(t) + # d.addCallback(self.tc.get_task_result, block=True) + # d.addCallback(lambda tr: tr.ns.i) + # d.addBoth(printer) + # d.addErrback(lambda f: self.assertRaises(AssertionError, f.raiseException)) + # return d + # + def testSetupNS(self): + self.addEngine(1) + d = self.multiengine.execute('a=0', targets=0) + ns = dict(a=1, b=0) + t = task.Task("", push=ns, pull=['a','b']) + d.addCallback(lambda r: self.tc.run(t)) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: {'a':tr.ns.a, 'b':tr['b']}) + d.addCallback(lambda r: self.assertEquals(r, ns)) + return d + + def testTaskResults(self): + self.addEngine(1) + t1 = task.Task('a=5', pull='a') + d = self.tc.run(t1) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: (tr.ns.a,tr['a'],tr.failure, tr.raiseException())) + d.addCallback(lambda r: self.assertEquals(r, (5,5,None,None))) + + t2 = task.Task('7=5') + d.addCallback(lambda r: self.tc.run(t2)) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: tr.ns) + d.addErrback(lambda f: self.assertRaises(SyntaxError, f.raiseException)) + + t3 = task.Task('', pull='b') + d.addCallback(lambda r: self.tc.run(t3)) + d.addCallback(self.tc.get_task_result, block=True) + d.addCallback(lambda tr: tr.ns) + d.addErrback(lambda f: self.assertRaises(NameError, f.raiseException)) + return d diff --git a/IPython/kernel/tests/test_controllerservice.py b/IPython/kernel/tests/test_controllerservice.py new file mode 
100644 index 0000000..58e27f2 --- /dev/null +++ b/IPython/kernel/tests/test_controllerservice.py @@ -0,0 +1,43 @@ +# encoding: utf-8 + +"""This file contains unittests for the kernel.engineservice.py module. + +Things that should be tested: + + - Should the EngineService return Deferred objects? + - Run the same tests that are run in shell.py. + - Make sure that the Interface is really implemented. + - The startService and stopService methods. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + from twisted.application.service import IService + from IPython.kernel.controllerservice import ControllerService + from IPython.kernel.tests import multienginetest as met + from controllertest import IControllerCoreTestCase + from IPython.testing.util import DeferredTestCase +except ImportError: + pass +else: + class BasicControllerServiceTest(DeferredTestCase, + IControllerCoreTestCase): + + def setUp(self): + self.controller = ControllerService() + self.controller.startService() + + def tearDown(self): + self.controller.stopService() diff --git a/IPython/kernel/tests/test_enginefc.py b/IPython/kernel/tests/test_enginefc.py new file mode 100644 index 0000000..be99c4c --- /dev/null +++ b/IPython/kernel/tests/test_enginefc.py @@ -0,0 +1,92 @@ +# encoding: utf-8 + +"""This file contains unittests for the enginepb.py module.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# 
Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + from twisted.python import components + from twisted.internet import reactor, defer + from twisted.spread import pb + from twisted.internet.base import DelayedCall + DelayedCall.debug = True + + import zope.interface as zi + + from IPython.kernel.fcutil import Tub, UnauthenticatedTub + from IPython.kernel import engineservice as es + from IPython.testing.util import DeferredTestCase + from IPython.kernel.controllerservice import IControllerBase + from IPython.kernel.enginefc import FCRemoteEngineRefFromService, IEngineBase + from IPython.kernel.engineservice import IEngineQueued + from IPython.kernel.engineconnector import EngineConnector + + from IPython.kernel.tests.engineservicetest import \ + IEngineCoreTestCase, \ + IEngineSerializedTestCase, \ + IEngineQueuedTestCase +except ImportError: + print "we got an error!!!" 
+ pass +else: + class EngineFCTest(DeferredTestCase, + IEngineCoreTestCase, + IEngineSerializedTestCase, + IEngineQueuedTestCase + ): + + zi.implements(IControllerBase) + + def setUp(self): + + # Start a server and append to self.servers + self.controller_reference = FCRemoteEngineRefFromService(self) + self.controller_tub = Tub() + self.controller_tub.listenOn('tcp:10105:interface=127.0.0.1') + self.controller_tub.setLocation('127.0.0.1:10105') + + furl = self.controller_tub.registerReference(self.controller_reference) + self.controller_tub.startService() + + # Start an EngineService and append to services/client + self.engine_service = es.EngineService() + self.engine_service.startService() + self.engine_tub = Tub() + self.engine_tub.startService() + engine_connector = EngineConnector(self.engine_tub) + d = engine_connector.connect_to_controller(self.engine_service, furl) + # This deferred doesn't fire until after register_engine has returned and + # thus, self.engine has been defined and the tets can proceed. 
+ return d + + def tearDown(self): + dlist = [] + # Shut down the engine + d = self.engine_tub.stopService() + dlist.append(d) + # Shut down the controller + d = self.controller_tub.stopService() + dlist.append(d) + return defer.DeferredList(dlist) + + #--------------------------------------------------------------------------- + # Make me look like a basic controller + #--------------------------------------------------------------------------- + + def register_engine(self, engine_ref, id=None, ip=None, port=None, pid=None): + self.engine = IEngineQueued(IEngineBase(engine_ref)) + return {'id':id} + + def unregister_engine(self, id): + pass \ No newline at end of file diff --git a/IPython/kernel/tests/test_engineservice.py b/IPython/kernel/tests/test_engineservice.py new file mode 100644 index 0000000..40c047f --- /dev/null +++ b/IPython/kernel/tests/test_engineservice.py @@ -0,0 +1,66 @@ +# encoding: utf-8 + +"""This file contains unittests for the kernel.engineservice.py module. + +Things that should be tested: + + - Should the EngineService return Deferred objects? + - Run the same tests that are run in shell.py. + - Make sure that the Interface is really implemented. + - The startService and stopService methods. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + from twisted.internet import defer + from twisted.application.service import IService + + from IPython.kernel import engineservice as es + from IPython.testing.util import DeferredTestCase + from IPython.kernel.tests.engineservicetest import \ + IEngineCoreTestCase, \ + IEngineSerializedTestCase, \ + IEngineQueuedTestCase, \ + IEnginePropertiesTestCase +except ImportError: + pass +else: + class BasicEngineServiceTest(DeferredTestCase, + IEngineCoreTestCase, + IEngineSerializedTestCase, + IEnginePropertiesTestCase): + + def setUp(self): + self.engine = es.EngineService() + self.engine.startService() + + def tearDown(self): + return self.engine.stopService() + + class QueuedEngineServiceTest(DeferredTestCase, + IEngineCoreTestCase, + IEngineSerializedTestCase, + IEnginePropertiesTestCase, + IEngineQueuedTestCase): + + def setUp(self): + self.rawEngine = es.EngineService() + self.rawEngine.startService() + self.engine = es.IEngineQueued(self.rawEngine) + + def tearDown(self): + return self.rawEngine.stopService() + + diff --git a/IPython/kernel/tests/test_multiengine.py b/IPython/kernel/tests/test_multiengine.py new file mode 100644 index 0000000..97510f2 --- /dev/null +++ b/IPython/kernel/tests/test_multiengine.py @@ -0,0 +1,54 @@ +# encoding: utf-8 + +"""""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + from twisted.internet import defer + from IPython.testing.util import DeferredTestCase + from IPython.kernel.controllerservice import ControllerService + from IPython.kernel import multiengine as me + from IPython.kernel.tests.multienginetest import (IMultiEngineTestCase, + ISynchronousMultiEngineTestCase) +except ImportError: + pass +else: + class BasicMultiEngineTestCase(DeferredTestCase, IMultiEngineTestCase): + + def setUp(self): + self.controller = ControllerService() + self.controller.startService() + self.multiengine = me.IMultiEngine(self.controller) + self.engines = [] + + def tearDown(self): + self.controller.stopService() + for e in self.engines: + e.stopService() + + + class SynchronousMultiEngineTestCase(DeferredTestCase, ISynchronousMultiEngineTestCase): + + def setUp(self): + self.controller = ControllerService() + self.controller.startService() + self.multiengine = me.ISynchronousMultiEngine(me.IMultiEngine(self.controller)) + self.engines = [] + + def tearDown(self): + self.controller.stopService() + for e in self.engines: + e.stopService() + diff --git a/IPython/kernel/tests/test_multienginefc.py b/IPython/kernel/tests/test_multienginefc.py new file mode 100644 index 0000000..610ada1 --- /dev/null +++ b/IPython/kernel/tests/test_multienginefc.py @@ -0,0 +1,70 @@ +#!/usr/bin/env python +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + from twisted.internet import defer, reactor + + from IPython.kernel.fcutil import Tub, UnauthenticatedTub + + from IPython.testing.util import DeferredTestCase + from IPython.kernel.controllerservice import ControllerService + from IPython.kernel.multiengine import IMultiEngine + from IPython.kernel.tests.multienginetest import IFullSynchronousMultiEngineTestCase + from IPython.kernel.multienginefc import IFCSynchronousMultiEngine + from IPython.kernel import multiengine as me + from IPython.kernel.clientconnector import ClientConnector +except ImportError: + pass +else: + class FullSynchronousMultiEngineTestCase(DeferredTestCase, IFullSynchronousMultiEngineTestCase): + + def setUp(self): + + self.engines = [] + + self.controller = ControllerService() + self.controller.startService() + self.imultiengine = IMultiEngine(self.controller) + self.mec_referenceable = IFCSynchronousMultiEngine(self.imultiengine) + + self.controller_tub = Tub() + self.controller_tub.listenOn('tcp:10105:interface=127.0.0.1') + self.controller_tub.setLocation('127.0.0.1:10105') + + furl = self.controller_tub.registerReference(self.mec_referenceable) + self.controller_tub.startService() + + self.client_tub = ClientConnector() + d = self.client_tub.get_multiengine_client(furl) + d.addCallback(self.handle_got_client) + return d + + def handle_got_client(self, client): + self.multiengine = client + + def tearDown(self): + dlist = [] + # Shut down the multiengine client + d = self.client_tub.tub.stopService() + dlist.append(d) + # Shut down the engines + for e in self.engines: + e.stopService() + # Shut down the controller + d = self.controller_tub.stopService() + d.addBoth(lambda _: self.controller.stopService()) + dlist.append(d) 
+ return defer.DeferredList(dlist) diff --git a/IPython/kernel/tests/test_newserialized.py b/IPython/kernel/tests/test_newserialized.py new file mode 100644 index 0000000..09de5a6 --- /dev/null +++ b/IPython/kernel/tests/test_newserialized.py @@ -0,0 +1,102 @@ +# encoding: utf-8 + +"""This file contains unittests for the shell.py module.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + import zope.interface as zi + from twisted.trial import unittest + from IPython.testing.util import DeferredTestCase + + from IPython.kernel.newserialized import \ + ISerialized, \ + IUnSerialized, \ + Serialized, \ + UnSerialized, \ + SerializeIt, \ + UnSerializeIt +except ImportError: + pass +else: + #------------------------------------------------------------------------------- + # Tests + #------------------------------------------------------------------------------- + + class SerializedTestCase(unittest.TestCase): + + def setUp(self): + pass + + def tearDown(self): + pass + + def testSerializedInterfaces(self): + + us = UnSerialized({'a':10, 'b':range(10)}) + s = ISerialized(us) + uss = IUnSerialized(s) + self.assert_(ISerialized.providedBy(s)) + self.assert_(IUnSerialized.providedBy(us)) + self.assert_(IUnSerialized.providedBy(uss)) + for m in list(ISerialized): + self.assert_(hasattr(s, m)) + for m in list(IUnSerialized): + self.assert_(hasattr(us, m)) + for m in list(IUnSerialized): + self.assert_(hasattr(uss, m)) + + def testPickleSerialized(self): 
+ obj = {'a':1.45345, 'b':'asdfsdf', 'c':10000L} + original = UnSerialized(obj) + originalSer = ISerialized(original) + firstData = originalSer.getData() + firstTD = originalSer.getTypeDescriptor() + firstMD = originalSer.getMetadata() + self.assert_(firstTD == 'pickle') + self.assert_(firstMD == {}) + unSerialized = IUnSerialized(originalSer) + secondObj = unSerialized.getObject() + for k, v in secondObj.iteritems(): + self.assert_(obj[k] == v) + secondSer = ISerialized(UnSerialized(secondObj)) + self.assert_(firstData == secondSer.getData()) + self.assert_(firstTD == secondSer.getTypeDescriptor() ) + self.assert_(firstMD == secondSer.getMetadata()) + + def testNDArraySerialized(self): + try: + import numpy + except ImportError: + pass + else: + a = numpy.linspace(0.0, 1.0, 1000) + unSer1 = UnSerialized(a) + ser1 = ISerialized(unSer1) + td = ser1.getTypeDescriptor() + self.assert_(td == 'ndarray') + md = ser1.getMetadata() + self.assert_(md['shape'] == a.shape) + self.assert_(md['dtype'] == a.dtype.str) + buff = ser1.getData() + self.assert_(buff == numpy.getbuffer(a)) + s = Serialized(buff, td, md) + us = IUnSerialized(s) + final = us.getObject() + self.assert_(numpy.getbuffer(a) == numpy.getbuffer(final)) + self.assert_(a.dtype.str == final.dtype.str) + self.assert_(a.shape == final.shape) + + + \ No newline at end of file diff --git a/IPython/kernel/tests/test_pendingdeferred.py b/IPython/kernel/tests/test_pendingdeferred.py new file mode 100644 index 0000000..52e0bc0 --- /dev/null +++ b/IPython/kernel/tests/test_pendingdeferred.py @@ -0,0 +1,218 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""Tests for pendingdeferred.py""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#-------------------------------------------------------------------------------
+
+#-------------------------------------------------------------------------------
+# Imports
+#-------------------------------------------------------------------------------
+
+try:
+    from twisted.internet import defer
+    from twisted.python import failure
+
+    from IPython.testing import tcommon
+    from IPython.testing.tcommon import *
+    from IPython.testing.util import DeferredTestCase
+    import IPython.kernel.pendingdeferred as pd
+    from IPython.kernel import error
+    from IPython.kernel.util import printer
+except ImportError:
+    pass
+else:
+
+    #-------------------------------------------------------------------------------
+    # Setup for inline and standalone doctests
+    #-------------------------------------------------------------------------------
+
+
+    # If you have standalone doctests in a separate file, set their names in the
+    # dt_files variable (as a single string or a list thereof):
+    dt_files = []
+
+    # If you have any modules whose docstrings should be scanned for embedded tests
+    # as examples according to standard doctest practice, set them here (as a
+    # single string or a list thereof):
+    dt_modules = []
+
+    #-------------------------------------------------------------------------------
+    # Regular Unittests
+    #-------------------------------------------------------------------------------
+
+
+    class Foo(object):
+
+        def bar(self, bahz):
+            return defer.succeed('blahblah: %s' % bahz)
+
+    class TwoPhaseFoo(pd.PendingDeferredManager):
+
+        def __init__(self, foo):
+            self.foo = foo
+            pd.PendingDeferredManager.__init__(self)
+
+        @pd.two_phase
+        def bar(self, bahz):
+            return self.foo.bar(bahz)
+
+    class PendingDeferredManagerTest(DeferredTestCase):
+
+        def setUp(self):
+            self.pdm = pd.PendingDeferredManager()
+
+        def tearDown(self):
+            pass
+
+        def testBasic(self):
+            dDict = {}
+            # Create 10 deferreds and save them
+            for i in range(10):
+                d = defer.Deferred()
+                did = 
self.pdm.save_pending_deferred(d)
+                dDict[did] = d
+            # Make sure they are being saved
+            for k in dDict.keys():
+                self.assert_(self.pdm.quick_has_id(k))
+            # Get the pending deferred (block=True), then callback with 'foo' and compare
+            for did in dDict.keys()[0:5]:
+                d = self.pdm.get_pending_deferred(did,block=True)
+                dDict[did].callback('foo')
+                d.addCallback(lambda r: self.assert_(r=='foo'))
+            # Get the pending deferreds with (block=False) and make sure ResultNotCompleted is raised
+            for did in dDict.keys()[5:10]:
+                d = self.pdm.get_pending_deferred(did,block=False)
+                d.addErrback(lambda f: self.assertRaises(error.ResultNotCompleted, f.raiseException))
+            # Now callback the last 5, get them and compare.
+            for did in dDict.keys()[5:10]:
+                dDict[did].callback('foo')
+                d = self.pdm.get_pending_deferred(did,block=False)
+                d.addCallback(lambda r: self.assert_(r=='foo'))
+
+        def test_save_then_delete(self):
+            d = defer.Deferred()
+            did = self.pdm.save_pending_deferred(d)
+            self.assert_(self.pdm.quick_has_id(did))
+            self.pdm.delete_pending_deferred(did)
+            self.assert_(not self.pdm.quick_has_id(did))
+
+        def test_save_get_delete(self):
+            d = defer.Deferred()
+            did = self.pdm.save_pending_deferred(d)
+            d2 = self.pdm.get_pending_deferred(did,True)
+            d2.addErrback(lambda f: self.assertRaises(error.AbortedPendingDeferredError, f.raiseException))
+            self.pdm.delete_pending_deferred(did)
+            return d2
+
+        def test_double_get(self):
+            d = defer.Deferred()
+            did = self.pdm.save_pending_deferred(d)
+            d2 = self.pdm.get_pending_deferred(did,True)
+            d3 = self.pdm.get_pending_deferred(did,True)
+            d3.addErrback(lambda f: self.assertRaises(error.InvalidDeferredID, f.raiseException))
+
+        def test_get_after_callback(self):
+            d = defer.Deferred()
+            did = self.pdm.save_pending_deferred(d)
+            d.callback('foo')
+            d2 = self.pdm.get_pending_deferred(did,True)
+            d2.addCallback(lambda r: self.assertEquals(r,'foo'))
+            self.assert_(not self.pdm.quick_has_id(did))
+
+        def test_get_before_callback(self):
+            
d = defer.Deferred() + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,True) + d.callback('foo') + d2.addCallback(lambda r: self.assertEquals(r,'foo')) + self.assert_(not self.pdm.quick_has_id(did)) + d = defer.Deferred() + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,True) + d2.addCallback(lambda r: self.assertEquals(r,'foo')) + d.callback('foo') + self.assert_(not self.pdm.quick_has_id(did)) + + def test_get_after_errback(self): + class MyError(Exception): + pass + d = defer.Deferred() + did = self.pdm.save_pending_deferred(d) + d.errback(failure.Failure(MyError('foo'))) + d2 = self.pdm.get_pending_deferred(did,True) + d2.addErrback(lambda f: self.assertRaises(MyError, f.raiseException)) + self.assert_(not self.pdm.quick_has_id(did)) + + def test_get_before_errback(self): + class MyError(Exception): + pass + d = defer.Deferred() + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,True) + d.errback(failure.Failure(MyError('foo'))) + d2.addErrback(lambda f: self.assertRaises(MyError, f.raiseException)) + self.assert_(not self.pdm.quick_has_id(did)) + d = defer.Deferred() + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,True) + d2.addErrback(lambda f: self.assertRaises(MyError, f.raiseException)) + d.errback(failure.Failure(MyError('foo'))) + self.assert_(not self.pdm.quick_has_id(did)) + + def test_noresult_noblock(self): + d = defer.Deferred() + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,False) + d2.addErrback(lambda f: self.assertRaises(error.ResultNotCompleted, f.raiseException)) + + def test_with_callbacks(self): + d = defer.Deferred() + d.addCallback(lambda r: r+' foo') + d.addCallback(lambda r: r+' bar') + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,True) + d.callback('bam') + d2.addCallback(lambda r: self.assertEquals(r,'bam foo bar')) + + def 
test_with_errbacks(self): + class MyError(Exception): + pass + d = defer.Deferred() + d.addCallback(lambda r: 'foo') + d.addErrback(lambda f: 'caught error') + did = self.pdm.save_pending_deferred(d) + d2 = self.pdm.get_pending_deferred(did,True) + d.errback(failure.Failure(MyError('bam'))) + d2.addErrback(lambda f: self.assertRaises(MyError, f.raiseException)) + + def test_nested_deferreds(self): + d = defer.Deferred() + d2 = defer.Deferred() + d.addCallback(lambda r: d2) + did = self.pdm.save_pending_deferred(d) + d.callback('foo') + d3 = self.pdm.get_pending_deferred(did,False) + d3.addErrback(lambda f: self.assertRaises(error.ResultNotCompleted, f.raiseException)) + d2.callback('bar') + d3 = self.pdm.get_pending_deferred(did,False) + d3.addCallback(lambda r: self.assertEquals(r,'bar')) + +#------------------------------------------------------------------------------- +# Regular Unittests +#------------------------------------------------------------------------------- + +# This ensures that the code will run either standalone as a script, or that it +# can be picked up by Twisted's `trial` test wrapper to run all the tests. +if tcommon.pexpect is not None: + if __name__ == '__main__': + unittest.main(testLoader=IPDocTestLoader(dt_files,dt_modules)) + else: + testSuite = lambda : makeTestSuite(__name__,dt_files,dt_modules) diff --git a/IPython/kernel/tests/test_task.py b/IPython/kernel/tests/test_task.py new file mode 100644 index 0000000..f957504 --- /dev/null +++ b/IPython/kernel/tests/test_task.py @@ -0,0 +1,50 @@ +# encoding: utf-8 + +"""This file contains unittests for the kernel.task.py module.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + import time + + from twisted.internet import defer + from twisted.trial import unittest + + from IPython.kernel import task, controllerservice as cs, engineservice as es + from IPython.kernel.multiengine import IMultiEngine + from IPython.testing.util import DeferredTestCase + from IPython.kernel.tests.tasktest import ITaskControllerTestCase +except ImportError: + pass +else: + #------------------------------------------------------------------------------- + # Tests + #------------------------------------------------------------------------------- + + class BasicTaskControllerTestCase(DeferredTestCase, ITaskControllerTestCase): + + def setUp(self): + self.controller = cs.ControllerService() + self.controller.startService() + self.multiengine = IMultiEngine(self.controller) + self.tc = task.ITaskController(self.controller) + self.tc.failurePenalty = 0 + self.engines=[] + + def tearDown(self): + self.controller.stopService() + for e in self.engines: + e.stopService() + + diff --git a/IPython/kernel/tests/test_taskfc.py b/IPython/kernel/tests/test_taskfc.py new file mode 100644 index 0000000..1e15317 --- /dev/null +++ b/IPython/kernel/tests/test_taskfc.py @@ -0,0 +1,90 @@ +#!/usr/bin/env python +# encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +try: + import time + + from twisted.internet import defer, reactor + + from IPython.kernel.fcutil import Tub, UnauthenticatedTub + + from IPython.kernel import task as taskmodule + from IPython.kernel import controllerservice as cs + import IPython.kernel.multiengine as me + from IPython.testing.util import DeferredTestCase + from IPython.kernel.multienginefc import IFCSynchronousMultiEngine + from IPython.kernel.taskfc import IFCTaskController + from IPython.kernel.util import printer + from IPython.kernel.tests.tasktest import ITaskControllerTestCase + from IPython.kernel.clientconnector import ClientConnector +except ImportError: + pass +else: + + #------------------------------------------------------------------------------- + # Tests + #------------------------------------------------------------------------------- + + class TaskTest(DeferredTestCase, ITaskControllerTestCase): + + def setUp(self): + + self.engines = [] + + self.controller = cs.ControllerService() + self.controller.startService() + self.imultiengine = me.IMultiEngine(self.controller) + self.itc = taskmodule.ITaskController(self.controller) + self.itc.failurePenalty = 0 + + self.mec_referenceable = IFCSynchronousMultiEngine(self.imultiengine) + self.tc_referenceable = IFCTaskController(self.itc) + + self.controller_tub = Tub() + self.controller_tub.listenOn('tcp:10105:interface=127.0.0.1') + self.controller_tub.setLocation('127.0.0.1:10105') + + mec_furl = self.controller_tub.registerReference(self.mec_referenceable) + tc_furl = self.controller_tub.registerReference(self.tc_referenceable) + self.controller_tub.startService() + + self.client_tub = ClientConnector() + d = self.client_tub.get_multiengine_client(mec_furl) + 
d.addCallback(self.handle_mec_client) + d.addCallback(lambda _: self.client_tub.get_task_client(tc_furl)) + d.addCallback(self.handle_tc_client) + return d + + def handle_mec_client(self, client): + self.multiengine = client + + def handle_tc_client(self, client): + self.tc = client + + def tearDown(self): + dlist = [] + # Shut down the multiengine client + d = self.client_tub.tub.stopService() + dlist.append(d) + # Shut down the engines + for e in self.engines: + e.stopService() + # Shut down the controller + d = self.controller_tub.stopService() + d.addBoth(lambda _: self.controller.stopService()) + dlist.append(d) + return defer.DeferredList(dlist) + diff --git a/IPython/kernel/twistedutil.py b/IPython/kernel/twistedutil.py new file mode 100644 index 0000000..6956d38 --- /dev/null +++ b/IPython/kernel/twistedutil.py @@ -0,0 +1,206 @@ +#!/usr/bin/env python +# encoding: utf-8 + +"""Things directly related to all of twisted.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import threading, Queue, atexit +import twisted + +from twisted.internet import defer, reactor +from twisted.python import log, failure + +#------------------------------------------------------------------------------- +# Classes related to twisted and threads +#------------------------------------------------------------------------------- + + +class ReactorInThread(threading.Thread): + """Run the twisted reactor in a different thread. 
+ + For the process to be able to exit cleanly, do the following: + + rit = ReactorInThread() + rit.setDaemon(True) + rit.start() + + """ + + def run(self): + reactor.run(installSignalHandlers=0) + # self.join() + + def stop(self): + # I don't think this does anything useful. + blockingCallFromThread(reactor.stop) + self.join() + +if(twisted.version.major >= 8): + import twisted.internet.threads + def blockingCallFromThread(f, *a, **kw): + """ + Run a function in the reactor from a thread, and wait for the result + synchronously, i.e. until the callback chain returned by the function get a + result. + + Delegates to twisted.internet.threads.blockingCallFromThread(reactor, f, *a, **kw), + passing twisted.internet.reactor for the first argument. + + @param f: the callable to run in the reactor thread + @type f: any callable. + @param a: the arguments to pass to C{f}. + @param kw: the keyword arguments to pass to C{f}. + + @return: the result of the callback chain. + @raise: any error raised during the callback chain. + """ + return twisted.internet.threads.blockingCallFromThread(reactor, f, *a, **kw) + +else: + def blockingCallFromThread(f, *a, **kw): + """ + Run a function in the reactor from a thread, and wait for the result + synchronously, i.e. until the callback chain returned by the function get a + result. + + @param f: the callable to run in the reactor thread + @type f: any callable. + @param a: the arguments to pass to C{f}. + @param kw: the keyword arguments to pass to C{f}. + + @return: the result of the callback chain. + @raise: any error raised during the callback chain. 
+ """ + from twisted.internet import reactor + queue = Queue.Queue() + def _callFromThread(): + result = defer.maybeDeferred(f, *a, **kw) + result.addBoth(queue.put) + + reactor.callFromThread(_callFromThread) + result = queue.get() + if isinstance(result, failure.Failure): + # This makes it easier for the debugger to get access to the instance + try: + result.raiseException() + except Exception, e: + raise e + return result + + + +#------------------------------------------------------------------------------- +# Things for managing deferreds +#------------------------------------------------------------------------------- + + +def parseResults(results): + """Pull out results/Failures from a DeferredList.""" + return [x[1] for x in results] + +def gatherBoth(dlist, fireOnOneCallback=0, + fireOnOneErrback=0, + consumeErrors=0, + logErrors=0): + """This is like gatherBoth, but sets consumeErrors=1.""" + d = DeferredList(dlist, fireOnOneCallback, fireOnOneErrback, + consumeErrors, logErrors) + if not fireOnOneCallback: + d.addCallback(parseResults) + return d + +SUCCESS = True +FAILURE = False + +class DeferredList(defer.Deferred): + """I combine a group of deferreds into one callback. + + I track a list of L{Deferred}s for their callbacks, and make a single + callback when they have all completed, a list of (success, result) + tuples, 'success' being a boolean. + + Note that you can still use a L{Deferred} after putting it in a + DeferredList. For example, you can suppress 'Unhandled error in Deferred' + messages by adding errbacks to the Deferreds *after* putting them in the + DeferredList, as a DeferredList won't swallow the errors. 
(Although a more + convenient way to do this is simply to set the consumeErrors flag) + + Note: This is a modified version of the twisted.internet.defer.DeferredList + """ + + fireOnOneCallback = 0 + fireOnOneErrback = 0 + + def __init__(self, deferredList, fireOnOneCallback=0, fireOnOneErrback=0, + consumeErrors=0, logErrors=0): + """Initialize a DeferredList. + + @type deferredList: C{list} of L{Deferred}s + @param deferredList: The list of deferreds to track. + @param fireOnOneCallback: (keyword param) a flag indicating that + only one callback needs to be fired for me to call + my callback + @param fireOnOneErrback: (keyword param) a flag indicating that + only one errback needs to be fired for me to call + my errback + @param consumeErrors: (keyword param) a flag indicating that any errors + raised in the original deferreds should be + consumed by this DeferredList. This is useful to + prevent spurious warnings being logged. + """ + self.resultList = [None] * len(deferredList) + defer.Deferred.__init__(self) + if len(deferredList) == 0 and not fireOnOneCallback: + self.callback(self.resultList) + + # These flags need to be set *before* attaching callbacks to the + # deferreds, because the callbacks use these flags, and will run + # synchronously if any of the deferreds are already fired. + self.fireOnOneCallback = fireOnOneCallback + self.fireOnOneErrback = fireOnOneErrback + self.consumeErrors = consumeErrors + self.logErrors = logErrors + self.finishedCount = 0 + + index = 0 + for deferred in deferredList: + deferred.addCallbacks(self._cbDeferred, self._cbDeferred, + callbackArgs=(index,SUCCESS), + errbackArgs=(index,FAILURE)) + index = index + 1 + + def _cbDeferred(self, result, index, succeeded): + """(internal) Callback for when one of my deferreds fires. 
+ """ + self.resultList[index] = (succeeded, result) + + self.finishedCount += 1 + if not self.called: + if succeeded == SUCCESS and self.fireOnOneCallback: + self.callback((result, index)) + elif succeeded == FAILURE and self.fireOnOneErrback: + # We have modified this to fire the errback chain with the actual + # Failure instance the originally occured rather than twisted's + # FirstError which wraps the failure + self.errback(result) + elif self.finishedCount == len(self.resultList): + self.callback(self.resultList) + + if succeeded == FAILURE and self.logErrors: + log.err(result) + if succeeded == FAILURE and self.consumeErrors: + result = None + + return result diff --git a/IPython/kernel/util.py b/IPython/kernel/util.py new file mode 100644 index 0000000..5a5b8eb --- /dev/null +++ b/IPython/kernel/util.py @@ -0,0 +1,102 @@ +# encoding: utf-8 + +"""General utilities for kernel related things.""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os, types + + +#------------------------------------------------------------------------------- +# Code +#------------------------------------------------------------------------------- + +def tarModule(mod): + """Makes a tarball (as a string) of a locally imported module. + + This method looks at the __file__ attribute of an imported module + and makes a tarball of the top level of the module. It then + reads the tarball into a binary string. 
+ + The method returns the tarball's name and the binary string + representing the tarball. + + Notes: + + - It will handle both single module files, as well as packages. + - The byte code files (\*.pyc) are not deleted. + - It has not been tested with modules containing extension code, but + it should work in most cases. + - There are cross platform issues. + + """ + + if not isinstance(mod, types.ModuleType): + raise TypeError, "Pass an imported module to push_module" + module_dir, module_file = os.path.split(mod.__file__) + + # Figure out what the module is called and where it is + print "Locating the module..." + if "__init__.py" in module_file: # package + module_name = module_dir.split("/")[-1] + module_dir = "/".join(module_dir.split("/")[:-1]) + module_file = module_name + else: # Simple module + module_name = module_file.split(".")[0] + module_dir = module_dir + print "Module (%s) found in:\n%s" % (module_name, module_dir) + + # Make a tarball of the module in the cwd + if module_dir: + os.system('tar -cf %s.tar -C %s %s' % \ + (module_name, module_dir, module_file)) + else: # must be the cwd + os.system('tar -cf %s.tar %s' % \ + (module_name, module_file)) + + # Read the tarball into a binary string + tarball_name = module_name + ".tar" + tar_file = open(tarball_name,'rb') + fileString = tar_file.read() + tar_file.close() + + # Remove the local copy of the tarball + #os.system("rm %s" % tarball_name) + + return tarball_name, fileString + +#from the Python Cookbook: + +def curry(f, *curryArgs, **curryKWargs): + """Curry the function f with curryArgs and curryKWargs.""" + + def curried(*args, **kwargs): + dikt = dict(kwargs) + dikt.update(curryKWargs) + return f(*(curryArgs+args), **dikt) + + return curried + +#useful callbacks + +def catcher(r): + pass + +def printer(r, msg=''): + print "%s\n%r" % (msg, r) + return r + + + + diff --git a/IPython/testing/__init__.py b/IPython/testing/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ 
b/IPython/testing/__init__.py diff --git a/IPython/testing/ipdoctest.py b/IPython/testing/ipdoctest.py new file mode 100755 index 0000000..ff91f92 --- /dev/null +++ b/IPython/testing/ipdoctest.py @@ -0,0 +1,800 @@ +#!/usr/bin/env python +"""IPython-enhanced doctest module with unittest integration. + +This module is heavily based on the standard library's doctest module, but +enhances it with IPython support. This enables docstrings to contain +unmodified IPython input and output pasted from real IPython sessions. + +It should be possible to use this module as a drop-in replacement for doctest +whenever you wish to use IPython input. + +Since the module absorbs all normal doctest functionality, you can use a mix of +both plain Python and IPython examples in any given module, though not in the +same docstring. + +See a simple example at the bottom of this code which serves as self-test and +demonstration code. Simply run this file (use -v for details) to run the +tests. + +This module also contains routines to ease the integration of doctests with +regular unittest-based testing. In particular, see the DocTestLoader class and +the makeTestSuite utility function. + + +Limitations: + + - When generating examples for use as doctests, make sure that you have + pretty-printing OFF. This can be done either by starting ipython with the + flag '--nopprint', by setting pprint to 0 in your ipythonrc file, or by + interactively disabling it with %Pprint. This is required so that IPython + output matches that of normal Python, which is used by doctest for internal + execution. + + - Do not rely on specific prompt numbers for results (such as using + '_34==True', for example). For IPython tests run via an external process + the prompt numbers may be different, and IPython tests run as normal python + code won't even have these special _NN variables set at all. + + - IPython functions that produce output as a side-effect of calling a system + process (e.g. 
'ls') can be doc-tested, but they must be handled in an + external IPython process. Such doctests must be tagged with: + + # ipdoctest: EXTERNAL + + so that the testing machinery handles them differently. Since these are run + via pexpect in an external process, they can't deal with exceptions or other + fancy featurs of regular doctests. You must limit such tests to simple + matching of the output. For this reason, I recommend you limit these kinds + of doctests to features that truly require a separate process, and use the + normal IPython ones (which have all the features of normal doctests) for + everything else. See the examples at the bottom of this file for a + comparison of what can be done with both types. +""" + +# Standard library imports +import __builtin__ +import doctest +import inspect +import os +import re +import sys +import unittest + +from doctest import * + +# Our own imports +from IPython.tools import utils + +########################################################################### +# +# We must start our own ipython object and heavily muck with it so that all the +# modifications IPython makes to system behavior don't send the doctest +# machinery into a fit. This code should be considered a gross hack, but it +# gets the job done. + +import IPython + +# Hack to restore __main__, which ipython modifies upon startup +_main = sys.modules.get('__main__') +ipython = IPython.Shell.IPShell(['--classic','--noterm_title']).IP +sys.modules['__main__'] = _main + +# Deactivate the various python system hooks added by ipython for +# interactive convenience so we don't confuse the doctest system +sys.displayhook = sys.__displayhook__ +sys.excepthook = sys.__excepthook__ + +# So that ipython magics and aliases can be doctested +__builtin__._ip = IPython.ipapi.get() + +# for debugging only!!! 
+#from IPython.Shell import IPShellEmbed;ipshell=IPShellEmbed(['--noterm_title']) # dbg + + +# runner +from IPython.irunner import IPythonRunner +iprunner = IPythonRunner(echo=False) + +########################################################################### + +# A simple subclassing of the original with a different class name, so we can +# distinguish and treat differently IPython examples from pure python ones. +class IPExample(doctest.Example): pass + +class IPExternalExample(doctest.Example): + """Doctest examples to be run in an external process.""" + + def __init__(self, source, want, exc_msg=None, lineno=0, indent=0, + options=None): + # Parent constructor + doctest.Example.__init__(self,source,want,exc_msg,lineno,indent,options) + + # An EXTRA newline is needed to prevent pexpect hangs + self.source += '\n' + +class IPDocTestParser(doctest.DocTestParser): + """ + A class used to parse strings containing doctest examples. + + Note: This is a version modified to properly recognize IPython input and + convert any IPython examples into valid Python ones. + """ + # This regular expression is used to find doctest examples in a + # string. It defines three groups: `source` is the source code + # (including leading indentation and prompts); `indent` is the + # indentation of the first (PS1) line of the source code; and + # `want` is the expected output (including leading indentation). + + # Classic Python prompts or default IPython ones + _PS1_PY = r'>>>' + _PS2_PY = r'\.\.\.' + + _PS1_IP = r'In\ \[\d+\]:' + _PS2_IP = r'\ \ \ \.\.\.+:' + + _RE_TPL = r''' + # Source consists of a PS1 line followed by zero or more PS2 lines. + (?P + (?:^(?P [ ]*) (?P %s) .*) # PS1 line + (?:\n [ ]* (?P %s) .*)*) # PS2 lines + \n? # a newline + # Want consists of any non-blank lines that do not start with PS1. + (?P (?:(?![ ]*$) # Not a blank line + (?![ ]*%s) # Not a line starting with PS1 + (?![ ]*%s) # Not a line starting with PS2 + .*$\n? 
# But any other line + )*) + ''' + + _EXAMPLE_RE_PY = re.compile( _RE_TPL % (_PS1_PY,_PS2_PY,_PS1_PY,_PS2_PY), + re.MULTILINE | re.VERBOSE) + + _EXAMPLE_RE_IP = re.compile( _RE_TPL % (_PS1_IP,_PS2_IP,_PS1_IP,_PS2_IP), + re.MULTILINE | re.VERBOSE) + + def ip2py(self,source): + """Convert input IPython source into valid Python.""" + out = [] + newline = out.append + for line in source.splitlines(): + newline(ipython.prefilter(line,True)) + newline('') # ensure a closing newline, needed by doctest + return '\n'.join(out) + + def parse(self, string, name=''): + """ + Divide the given string into examples and intervening text, + and return them as a list of alternating Examples and strings. + Line numbers for the Examples are 0-based. The optional + argument `name` is a name identifying this string, and is only + used for error messages. + """ + string = string.expandtabs() + # If all lines begin with the same indentation, then strip it. + min_indent = self._min_indent(string) + if min_indent > 0: + string = '\n'.join([l[min_indent:] for l in string.split('\n')]) + + output = [] + charno, lineno = 0, 0 + + # Whether to convert the input from ipython to python syntax + ip2py = False + # Find all doctest examples in the string. First, try them as Python + # examples, then as IPython ones + terms = list(self._EXAMPLE_RE_PY.finditer(string)) + if terms: + # Normal Python example + Example = doctest.Example + else: + # It's an ipython example. Note that IPExamples are run + # in-process, so their syntax must be turned into valid python. + # IPExternalExamples are run out-of-process (via pexpect) so they + # don't need any filtering (a real ipython will be executing them). 
+ terms = list(self._EXAMPLE_RE_IP.finditer(string)) + if re.search(r'#\s*ipdoctest:\s*EXTERNAL',string): + #print '-'*70 # dbg + #print 'IPExternalExample, Source:\n',string # dbg + #print '-'*70 # dbg + Example = IPExternalExample + else: + #print '-'*70 # dbg + #print 'IPExample, Source:\n',string # dbg + #print '-'*70 # dbg + Example = IPExample + ip2py = True + + for m in terms: + # Add the pre-example text to `output`. + output.append(string[charno:m.start()]) + # Update lineno (lines before this example) + lineno += string.count('\n', charno, m.start()) + # Extract info from the regexp match. + (source, options, want, exc_msg) = \ + self._parse_example(m, name, lineno,ip2py) + if Example is IPExternalExample: + options[doctest.NORMALIZE_WHITESPACE] = True + # Create an Example, and add it to the list. + if not self._IS_BLANK_OR_COMMENT(source): + output.append(Example(source, want, exc_msg, + lineno=lineno, + indent=min_indent+len(m.group('indent')), + options=options)) + # Update lineno (lines inside this example) + lineno += string.count('\n', m.start(), m.end()) + # Update charno. + charno = m.end() + # Add any remaining post-example text to `output`. + output.append(string[charno:]) + + return output + + def _parse_example(self, m, name, lineno,ip2py=False): + """ + Given a regular expression match from `_EXAMPLE_RE` (`m`), + return a pair `(source, want)`, where `source` is the matched + example's source code (with prompts and indentation stripped); + and `want` is the example's expected output (with indentation + stripped). + + `name` is the string's name, and `lineno` is the line number + where the example starts; both are used for error messages. + + Optional: + `ip2py`: if true, filter the input via IPython to convert the syntax + into valid python. + """ + + # Get the example's indentation level. + indent = len(m.group('indent')) + + # Divide source into lines; check that they're properly + # indented; and then strip their indentation & prompts. 
+ source_lines = m.group('source').split('\n') + + # We're using variable-length input prompts + ps1 = m.group('ps1') + ps2 = m.group('ps2') + ps1_len = len(ps1) + + self._check_prompt_blank(source_lines, indent, name, lineno,ps1_len) + if ps2: + self._check_prefix(source_lines[1:], ' '*indent + ps2, name, lineno) + + source = '\n'.join([sl[indent+ps1_len+1:] for sl in source_lines]) + + if ip2py: + # Convert source input from IPython into valid Python syntax + source = self.ip2py(source) + + # Divide want into lines; check that it's properly indented; and + # then strip the indentation. Spaces before the last newline should + # be preserved, so plain rstrip() isn't good enough. + want = m.group('want') + want_lines = want.split('\n') + if len(want_lines) > 1 and re.match(r' *$', want_lines[-1]): + del want_lines[-1] # forget final newline & spaces after it + self._check_prefix(want_lines, ' '*indent, name, + lineno + len(source_lines)) + + # Remove ipython output prompt that might be present in the first line + want_lines[0] = re.sub(r'Out\[\d+\]: \s*?\n?','',want_lines[0]) + + want = '\n'.join([wl[indent:] for wl in want_lines]) + + # If `want` contains a traceback message, then extract it. + m = self._EXCEPTION_RE.match(want) + if m: + exc_msg = m.group('msg') + else: + exc_msg = None + + # Extract options from the source. + options = self._find_options(source, name, lineno) + + return source, options, want, exc_msg + + def _check_prompt_blank(self, lines, indent, name, lineno, ps1_len): + """ + Given the lines of a source string (including prompts and + leading indentation), check to make sure that every prompt is + followed by a space character. If any line is not followed by + a space character, then raise ValueError. + + Note: IPython-modified version which takes the input prompt length as a + parameter, so that prompts of variable length can be dealt with. 
+ """ + space_idx = indent+ps1_len + min_len = space_idx+1 + for i, line in enumerate(lines): + if len(line) >= min_len and line[space_idx] != ' ': + raise ValueError('line %r of the docstring for %s ' + 'lacks blank after %s: %r' % + (lineno+i+1, name, + line[indent:space_idx], line)) + + +SKIP = register_optionflag('SKIP') + +class IPDocTestRunner(doctest.DocTestRunner): + """Modified DocTestRunner which can also run IPython tests. + + This runner is capable of handling IPython doctests that require + out-of-process output capture (such as system calls via !cmd or aliases). + Note however that because these tests are run in a separate process, many + of doctest's fancier capabilities (such as detailed exception analysis) are + not available. So try to limit such tests to simple cases of matching + actual output. + """ + + #///////////////////////////////////////////////////////////////// + # DocTest Running + #///////////////////////////////////////////////////////////////// + + def _run_iptest(self, test, out): + """ + Run the examples in `test`. Write the outcome of each example with one + of the `DocTestRunner.report_*` methods, using the writer function + `out`. Return a tuple `(f, t)`, where `t` is the number of examples + tried, and `f` is the number of examples that failed. The examples are + run in the namespace `test.globs`. + + IPython note: this is a modified version of the original __run() + private method to handle out-of-process examples. + """ + + if out is None: + out = sys.stdout.write + + # Keep track of the number of failures and tries. + failures = tries = 0 + + # Save the option flags (since option directives can be used + # to modify them). + original_optionflags = self.optionflags + + SUCCESS, FAILURE, BOOM = range(3) # `outcome` state + + check = self._checker.check_output + + # Process each example. 
+        for examplenum, example in enumerate(test.examples):
+
+            # If REPORT_ONLY_FIRST_FAILURE is set, then suppress
+            # reporting after the first failure.
+            quiet = (self.optionflags & REPORT_ONLY_FIRST_FAILURE and
+                     failures > 0)
+
+            # Merge in the example's options.
+            self.optionflags = original_optionflags
+            if example.options:
+                for (optionflag, val) in example.options.items():
+                    if val:
+                        self.optionflags |= optionflag
+                    else:
+                        self.optionflags &= ~optionflag
+
+            # If 'SKIP' is set, then skip this example.
+            if self.optionflags & SKIP:
+                continue
+
+            # Record that we started this example.
+            tries += 1
+            if not quiet:
+                self.report_start(out, test, example)
+
+            # Run the example in the given context (globs), and record
+            # any exception that gets raised.  (But don't intercept
+            # keyboard interrupts.)
+            try:
+                # Don't blink!  This is where the user's code gets run.
+                got = ''
+                # The code is run in an external process
+                got = iprunner.run_source(example.source,get_output=True)
+            except KeyboardInterrupt:
+                raise
+            except:
+                self.debugger.set_continue() # ==== Example Finished ====
+
+            outcome = FAILURE   # guilty until proved innocent or insane
+
+            if check(example.want, got, self.optionflags):
+                outcome = SUCCESS
+
+            # Report the outcome.
+            if outcome is SUCCESS:
+                if not quiet:
+                    self.report_success(out, test, example, got)
+            elif outcome is FAILURE:
+                if not quiet:
+                    self.report_failure(out, test, example, got)
+                failures += 1
+            elif outcome is BOOM:
+                if not quiet:
+                    self.report_unexpected_exception(out, test, example,
+                                                     exc_info)
+                failures += 1
+            else:
+                assert False, ("unknown outcome", outcome)
+
+        # Restore the option flags (in case they were modified)
+        self.optionflags = original_optionflags
+
+        # Record and return the number of failures and tries.
+
+        # Hack to access a parent private method by working around Python's
+        # name mangling (which is fortunately simple).
+ doctest.DocTestRunner._DocTestRunner__record_outcome(self,test, + failures, tries) + return failures, tries + + def run(self, test, compileflags=None, out=None, clear_globs=True): + """Run examples in `test`. + + This method will defer to the parent for normal Python examples, but it + will run IPython ones via pexpect. + """ + if not test.examples: + return + + if isinstance(test.examples[0],IPExternalExample): + self._run_iptest(test,out) + else: + DocTestRunner.run(self,test,compileflags,out,clear_globs) + + +class IPDebugRunner(IPDocTestRunner,doctest.DebugRunner): + """IPython-modified DebugRunner, see the original class for details.""" + + def run(self, test, compileflags=None, out=None, clear_globs=True): + r = IPDocTestRunner.run(self, test, compileflags, out, False) + if clear_globs: + test.globs.clear() + return r + + +class IPDocTestLoader(unittest.TestLoader): + """A test loader with IPython-enhanced doctest support. + + Instances of this loader will automatically add doctests found in a module + to the test suite returned by the loadTestsFromModule method. In + addition, at initialization time a string of doctests can be given to the + loader, enabling it to add doctests to a module which didn't have them in + its docstring, coming from an external source.""" + + + def __init__(self,dt_files=None,dt_modules=None,test_finder=None): + """Initialize the test loader. + + :Keywords: + + dt_files : list (None) + List of names of files to be executed as doctests. + + dt_modules : list (None) + List of module names to be scanned for doctests in their + docstrings. + + test_finder : instance (None) + Instance of a testfinder (see doctest for details). 
+ """ + + if dt_files is None: dt_files = [] + if dt_modules is None: dt_modules = [] + self.dt_files = utils.list_strings(dt_files) + self.dt_modules = utils.list_strings(dt_modules) + if test_finder is None: + test_finder = doctest.DocTestFinder(parser=IPDocTestParser()) + self.test_finder = test_finder + + def loadTestsFromModule(self, module): + """Return a suite of all tests cases contained in the given module. + + If the loader was initialized with a doctests argument, then this + string is assigned as the module's docstring.""" + + # Start by loading any tests in the called module itself + suite = super(self.__class__,self).loadTestsFromModule(module) + + # Now, load also tests referenced at construction time as companion + # doctests that reside in standalone files + for fname in self.dt_files: + #print 'mod:',module # dbg + #print 'fname:',fname # dbg + #suite.addTest(doctest.DocFileSuite(fname)) + suite.addTest(doctest.DocFileSuite(fname,module_relative=False)) + # Add docstring tests from module, if given at construction time + for mod in self.dt_modules: + suite.addTest(doctest.DocTestSuite(mod, + test_finder=self.test_finder)) + + #ipshell() # dbg + return suite + +def my_import(name): + """Module importer - taken from the python documentation. + + This function allows importing names with dots in them.""" + + mod = __import__(name) + components = name.split('.') + for comp in components[1:]: + mod = getattr(mod, comp) + return mod + +def makeTestSuite(module_name,dt_files=None,dt_modules=None,idt=True): + """Make a TestSuite object for a given module, specified by name. + + This extracts all the doctests associated with a module using a + DocTestLoader object. + + :Parameters: + + - module_name: a string containing the name of a module with unittests. + + :Keywords: + + dt_files : list of strings + List of names of plain text files to be treated as doctests. 
+ + dt_modules : list of strings + List of names of modules to be scanned for doctests in docstrings. + + idt : bool (True) + If True, return integrated doctests. This means that each filename + listed in dt_files is turned into a *single* unittest, suitable for + running via unittest's runner or Twisted's Trial runner. If false, the + dt_files parameter is returned unmodified, so that other test runners + (such as oilrun) can run the doctests with finer granularity. + """ + + mod = my_import(module_name) + if idt: + suite = IPDocTestLoader(dt_files,dt_modules).loadTestsFromModule(mod) + else: + suite = IPDocTestLoader(None,dt_modules).loadTestsFromModule(mod) + + if idt: + return suite + else: + return suite,dt_files + +# Copied from doctest in py2.5 and modified for our purposes (since they don't +# parametrize what we need) + +# For backward compatibility, a global instance of a DocTestRunner +# class, updated by testmod. +master = None + +def testmod(m=None, name=None, globs=None, verbose=None, + report=True, optionflags=0, extraglobs=None, + raise_on_error=False, exclude_empty=False): + """m=None, name=None, globs=None, verbose=None, report=True, + optionflags=0, extraglobs=None, raise_on_error=False, + exclude_empty=False + + Note: IPython-modified version which loads test finder and runners that + recognize IPython syntax in doctests. + + Test examples in docstrings in functions and classes reachable + from module m (or the current module if m is not supplied), starting + with m.__doc__. + + Also test examples reachable from dict m.__test__ if it exists and is + not None. m.__test__ maps names to functions, classes and strings; + function and class docstrings are tested even if the name is private; + strings are tested directly, as if they were docstrings. + + Return (#failures, #tests). + + See doctest.__doc__ for an overview. + + Optional keyword arg "name" gives the name of the module; by default + use m.__name__. 
+ + Optional keyword arg "globs" gives a dict to be used as the globals + when executing examples; by default, use m.__dict__. A copy of this + dict is actually used for each docstring, so that each docstring's + examples start with a clean slate. + + Optional keyword arg "extraglobs" gives a dictionary that should be + merged into the globals that are used to execute examples. By + default, no extra globals are used. This is new in 2.4. + + Optional keyword arg "verbose" prints lots of stuff if true, prints + only failures if false; by default, it's true iff "-v" is in sys.argv. + + Optional keyword arg "report" prints a summary at the end when true, + else prints nothing at the end. In verbose mode, the summary is + detailed, else very brief (in fact, empty if all tests passed). + + Optional keyword arg "optionflags" or's together module constants, + and defaults to 0. This is new in 2.3. Possible values (see the + docs for details): + + DONT_ACCEPT_TRUE_FOR_1 + DONT_ACCEPT_BLANKLINE + NORMALIZE_WHITESPACE + ELLIPSIS + SKIP + IGNORE_EXCEPTION_DETAIL + REPORT_UDIFF + REPORT_CDIFF + REPORT_NDIFF + REPORT_ONLY_FIRST_FAILURE + + Optional keyword arg "raise_on_error" raises an exception on the + first unexpected exception or failure. This allows failures to be + post-mortem debugged. + + Advanced tomfoolery: testmod runs methods of a local instance of + class doctest.Tester, then merges the results into (or creates) + global Tester instance doctest.master. Methods of doctest.master + can be called directly too, if you want to do something unusual. + Passing report=0 to testmod is especially useful then, to delay + displaying a summary. Invoke doctest.master.summarize(verbose) + when you're done fiddling. + """ + global master + + # If no module was given, then use __main__. 
+ if m is None: + # DWA - m will still be None if this wasn't invoked from the command + # line, in which case the following TypeError is about as good an error + # as we should expect + m = sys.modules.get('__main__') + + # Check that we were actually given a module. + if not inspect.ismodule(m): + raise TypeError("testmod: module required; %r" % (m,)) + + # If no name was given, then use the module's name. + if name is None: + name = m.__name__ + + #---------------------------------------------------------------------- + # fperez - make IPython finder and runner: + # Find, parse, and run all tests in the given module. + finder = DocTestFinder(exclude_empty=exclude_empty, + parser=IPDocTestParser()) + + if raise_on_error: + runner = IPDebugRunner(verbose=verbose, optionflags=optionflags) + else: + runner = IPDocTestRunner(verbose=verbose, optionflags=optionflags, + #checker=IPOutputChecker() # dbg + ) + + # /fperez - end of ipython changes + #---------------------------------------------------------------------- + + for test in finder.find(m, name, globs=globs, extraglobs=extraglobs): + runner.run(test) + + if report: + runner.summarize() + + if master is None: + master = runner + else: + master.merge(runner) + + return runner.failures, runner.tries + + +# Simple testing and example code +if __name__ == "__main__": + + def ipfunc(): + """ + Some ipython tests... 
+ + In [1]: import os + + In [2]: cd / + / + + In [3]: 2+3 + Out[3]: 5 + + In [26]: for i in range(3): + ....: print i, + ....: print i+1, + ....: + 0 1 1 2 2 3 + + + Examples that access the operating system work: + + In [19]: cd /tmp + /tmp + + In [20]: mkdir foo_ipython + + In [21]: cd foo_ipython + /tmp/foo_ipython + + In [23]: !touch bar baz + + # We unfortunately can't just call 'ls' because its output is not + # seen by doctest, since it happens in a separate process + + In [24]: os.listdir('.') + Out[24]: ['bar', 'baz'] + + In [25]: cd /tmp + /tmp + + In [26]: rm -rf foo_ipython + + + It's OK to use '_' for the last result, but do NOT try to use IPython's + numbered history of _NN outputs, since those won't exist under the + doctest environment: + + In [7]: 3+4 + Out[7]: 7 + + In [8]: _+3 + Out[8]: 10 + """ + + def ipfunc_external(): + """ + Tests that must be run in an external process + + + # ipdoctest: EXTERNAL + + In [11]: for i in range(10): + ....: print i, + ....: print i+1, + ....: + 0 1 1 2 2 3 3 4 4 5 5 6 6 7 7 8 8 9 9 10 + + + In [1]: import os + + In [1]: print "hello" + hello + + In [19]: cd /tmp + /tmp + + In [20]: mkdir foo_ipython2 + + In [21]: cd foo_ipython2 + /tmp/foo_ipython2 + + In [23]: !touch bar baz + + In [24]: ls + bar baz + + In [24]: !ls + bar baz + + In [25]: cd /tmp + /tmp + + In [26]: rm -rf foo_ipython2 + """ + + def pyfunc(): + """ + Some pure python tests... + + >>> import os + + >>> 2+3 + 5 + + >>> for i in range(3): + ... print i, + ... print i+1, + ... + 0 1 1 2 2 3 + """ + + # Call the global testmod() just like you would with normal doctest + testmod() diff --git a/IPython/testing/mkdoctests.py b/IPython/testing/mkdoctests.py new file mode 100755 index 0000000..142f1cc --- /dev/null +++ b/IPython/testing/mkdoctests.py @@ -0,0 +1,244 @@ +#!/usr/bin/env python +"""Utility for making a doctest file out of Python or IPython input. 
+ + %prog [options] input_file [output_file] + +This script is a convenient generator of doctest files that uses IPython's +irunner script to execute valid Python or IPython input in a separate process, +capture all of the output, and write it to an output file. + +It can be used in one of two ways: + +1. With a plain Python or IPython input file (denoted by extensions '.py' or + '.ipy'. In this case, the output is an auto-generated reST file with a + basic header, and the captured Python input and output contained in an + indented code block. + + If no output filename is given, the input name is used, with the extension + replaced by '.txt'. + +2. With an input template file. Template files are simply plain text files + with special directives of the form + + %run filename + + to include the named file at that point. + + If no output filename is given and the input filename is of the form + 'base.tpl.txt', the output will be automatically named 'base.txt'. +""" + +# Standard library imports + +import optparse +import os +import re +import sys +import tempfile + +# IPython-specific libraries +from IPython import irunner +from IPython.genutils import fatal + +class IndentOut(object): + """A simple output stream that indents all output by a fixed amount. + + Instances of this class trap output to a given stream and first reformat it + to indent every input line.""" + + def __init__(self,out=sys.stdout,indent=4): + """Create an indented writer. + + :Keywords: + + - `out` : stream (sys.stdout) + Output stream to actually write to after indenting. + + - `indent` : int + Number of spaces to indent every input line by. 
+ """ + + self.indent_text = ' '*indent + self.indent = re.compile('^',re.MULTILINE).sub + self.out = out + self._write = out.write + self.buffer = [] + self._closed = False + + def write(self,data): + """Write a string to the output stream.""" + + if self._closed: + raise ValueError('I/O operation on closed file') + self.buffer.append(data) + + def flush(self): + if self.buffer: + data = ''.join(self.buffer) + self.buffer[:] = [] + self._write(self.indent(self.indent_text,data)) + + def close(self): + self.flush() + self._closed = True + +class RunnerFactory(object): + """Code runner factory. + + This class provides an IPython code runner, but enforces that only one + runner is every instantiated. The runner is created based on the extension + of the first file to run, and it raises an exception if a runner is later + requested for a different extension type. + + This ensures that we don't generate example files for doctest with a mix of + python and ipython syntax. + """ + + def __init__(self,out=sys.stdout): + """Instantiate a code runner.""" + + self.out = out + self.runner = None + self.runnerClass = None + + def _makeRunner(self,runnerClass): + self.runnerClass = runnerClass + self.runner = runnerClass(out=self.out) + return self.runner + + def __call__(self,fname): + """Return a runner for the given filename.""" + + if fname.endswith('.py'): + runnerClass = irunner.PythonRunner + elif fname.endswith('.ipy'): + runnerClass = irunner.IPythonRunner + else: + raise ValueError('Unknown file type for Runner: %r' % fname) + + if self.runner is None: + return self._makeRunner(runnerClass) + else: + if runnerClass==self.runnerClass: + return self.runner + else: + e='A runner of type %r can not run file %r' % \ + (self.runnerClass,fname) + raise ValueError(e) + +TPL = """ +========================= + Auto-generated doctests +========================= + +This file was auto-generated by IPython in its entirety. 
If you need finer +control over the contents, simply make a manual template. See the +mkdoctests.py script for details. + +%%run %s +""" + +def main(): + """Run as a script.""" + + # Parse options and arguments. + parser = optparse.OptionParser(usage=__doc__) + newopt = parser.add_option + newopt('-f','--force',action='store_true',dest='force',default=False, + help='Force overwriting of the output file.') + newopt('-s','--stdout',action='store_true',dest='stdout',default=False, + help='Use stdout instead of a file for output.') + + opts,args = parser.parse_args() + if len(args) < 1: + parser.error("incorrect number of arguments") + + # Input filename + fname = args[0] + + # We auto-generate the output file based on a trivial template to make it + # really easy to create simple doctests. + + auto_gen_output = False + try: + outfname = args[1] + except IndexError: + outfname = None + + if fname.endswith('.tpl.txt') and outfname is None: + outfname = fname.replace('.tpl.txt','.txt') + else: + bname, ext = os.path.splitext(fname) + if ext in ['.py','.ipy']: + auto_gen_output = True + if outfname is None: + outfname = bname+'.txt' + + # Open input file + + # In auto-gen mode, we actually change the name of the input file to be our + # auto-generated template + if auto_gen_output: + infile = tempfile.TemporaryFile() + infile.write(TPL % fname) + infile.flush() + infile.seek(0) + else: + infile = open(fname) + + # Now open the output file. If opts.stdout was given, this overrides any + # explicit choice of output filename and just directs all output to + # stdout. + if opts.stdout: + outfile = sys.stdout + else: + # Argument processing finished, start main code + if os.path.isfile(outfname) and not opts.force: + fatal("Output file %r exists, use --force (-f) to overwrite." 
+ % outfname) + outfile = open(outfname,'w') + + + # all output from included files will be indented + indentOut = IndentOut(outfile,4) + getRunner = RunnerFactory(indentOut) + + # Marker in reST for transition lines + rst_transition = '\n'+'-'*76+'\n\n' + + # local shorthand for loop + write = outfile.write + + # Process input, simply writing back out all normal lines and executing the + # files in lines marked as '%run filename'. + for line in infile: + if line.startswith('%run '): + # We don't support files with spaces in their names. + incfname = line.split()[1] + + # We make the output of the included file appear bracketed between + # clear reST transition marks, and indent it so that if anyone + # makes an HTML or PDF out of the file, all doctest input and + # output appears in proper literal blocks. + write(rst_transition) + write('Begin included file %s::\n\n' % incfname) + + # I deliberately do NOT trap any exceptions here, so that if + # there's any problem, the user running this at the command line + # finds out immediately by the code blowing up, rather than ending + # up silently with an incomplete or incorrect file. + getRunner(incfname).run_file(incfname) + + write('\nEnd included file %s\n' % incfname) + write(rst_transition) + else: + # The rest of the input file is just written out + write(line) + infile.close() + + # Don't close sys.stdout!!! + if outfile is not sys.stdout: + outfile.close() + +if __name__ == '__main__': + main() diff --git a/IPython/testing/parametric.py b/IPython/testing/parametric.py new file mode 100644 index 0000000..cd2bf81 --- /dev/null +++ b/IPython/testing/parametric.py @@ -0,0 +1,55 @@ +"""Parametric testing on top of twisted.trial.unittest. + +""" + +__all__ = ['parametric','Parametric'] + +from twisted.trial.unittest import TestCase + +def partial(f, *partial_args, **partial_kwargs): + """Generate a partial class method. 
+ + """ + def partial_func(self, *args, **kwargs): + dikt = dict(kwargs) + dikt.update(partial_kwargs) + return f(self, *(partial_args+args), **dikt) + + return partial_func + +def parametric(f): + """Mark f as a parametric test. + + """ + f._parametric = True + return classmethod(f) + +def Parametric(cls): + """Register parametric tests with a class. + + """ + # Walk over all tests marked with @parametric + test_generators = [getattr(cls,f) for f in dir(cls) + if f.startswith('test')] + test_generators = [m for m in test_generators if hasattr(m,'_parametric')] + for test_gen in test_generators: + test_name = test_gen.func_name + + # Insert a new test for each parameter + for n,test_and_params in enumerate(test_gen()): + test_method = test_and_params[0] + test_params = test_and_params[1:] + + # Here we use partial (defined above), which returns a + # class method of type ``types.FunctionType``, unlike + # functools.partial which returns a function of type + # ``functools.partial``. + partial_func = partial(test_method,*test_params) + # rename the test to look like a testcase + partial_func.__name__ = 'test_' + partial_func.__name__ + + # insert the new function into the class as a test + setattr(cls, test_name + '_%s' % n, partial_func) + + # rename test generator so it isn't called again by nose + test_gen.im_func.func_name = '__done_' + test_name diff --git a/IPython/testing/tcommon.py b/IPython/testing/tcommon.py new file mode 100644 index 0000000..73f7362 --- /dev/null +++ b/IPython/testing/tcommon.py @@ -0,0 +1,36 @@ +"""Common utilities for testing IPython. + +This file is meant to be used as + +from IPython.testing.tcommon import * + +by any test code. + +While a bit ugly, this helps us keep all testing facilities in one place, and +start coding standalone test scripts easily, which can then be pasted into the +larger test suites without any modifications required. 
+""" + +# Required modules and packages + +# Standard Python lib +import cPickle as pickle +import doctest +import math +import os +import sys +import unittest + +from pprint import pformat, pprint + +# From the IPython test lib +import tutils +from tutils import fullPath + +try: + import pexpect +except ImportError: + pexpect = None +else: + from IPython.testing.ipdoctest import IPDocTestLoader,makeTestSuite + diff --git a/IPython/testing/testTEMPLATE.py b/IPython/testing/testTEMPLATE.py new file mode 100755 index 0000000..2484a13 --- /dev/null +++ b/IPython/testing/testTEMPLATE.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# encoding: utf-8 +"""Simple template for unit tests. + +This file should be renamed to + +test_FEATURE.py + +so that it is recognized by the overall test driver (Twisted's 'trial'), which +looks for all test_*.py files in the current directory to extract tests from +them. +""" +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2005 Fernando Perez +# Brian E Granger +# Benjamin Ragan-Kelley +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython.testing import tcommon +from IPython.testing.tcommon import * + +#------------------------------------------------------------------------------- +# Setup for inline and standalone doctests +#------------------------------------------------------------------------------- + + +# If you have standalone doctests in a separate file, set their names in the +# dt_files variable (as a single string or a list thereof). 
The mkPath call
+# forms an absolute path based on the current file, it is not needed if you
+# provide the full paths.
+dt_files = fullPath(__file__,[])
+
+
+# If you have any modules whose docstrings should be scanned for embedded tests
+# as examples according to standard doctest practice, set them here (as a
+# single string or a list thereof):
+dt_modules = []
+
+#-------------------------------------------------------------------------------
+# Regular Unittests
+#-------------------------------------------------------------------------------
+
+class FooTestCase(unittest.TestCase):
+    def test_foo(self):
+        pass
+
+#-------------------------------------------------------------------------------
+# Regular Unittests
+#-------------------------------------------------------------------------------
+
+# This ensures that the code will run either standalone as a script, or that it
+# can be picked up by Twisted's `trial` test wrapper to run all the tests.
+if tcommon.pexpect is not None:
+    if __name__ == '__main__':
+        unittest.main(testLoader=IPDocTestLoader(dt_files,dt_modules))
+    else:
+        testSuite = lambda : makeTestSuite(__name__,dt_files,dt_modules)
diff --git a/IPython/testing/tests/__init__.py b/IPython/testing/tests/__init__.py
new file mode 100644
index 0000000..f751f68
--- /dev/null
+++ b/IPython/testing/tests/__init__.py
@@ -0,0 +1,10 @@
+# encoding: utf-8
+__docformat__ = "restructuredtext en"
+#-------------------------------------------------------------------------------
+# Copyright (C) 2005  Fernando Perez
+#                     Brian E Granger
+#                     Benjamin Ragan-Kelley
+#
+# Distributed under the terms of the BSD License.  The full license is in
+# the file COPYING, distributed as part of this software.
+#------------------------------------------------------------------------------- diff --git a/IPython/testing/tests/test_testutils.py b/IPython/testing/tests/test_testutils.py new file mode 100755 index 0000000..683661b --- /dev/null +++ b/IPython/testing/tests/test_testutils.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python +# encoding: utf-8 +"""Simple template for unit tests. + +This file should be renamed to + +test_FEATURE.py + +so that it is recognized by the overall test driver (Twisted's 'trial'), which +looks for all test_*.py files in the current directory to extract tests from +them. +""" +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2005 Fernando Perez +# Brian E Granger +# Benjamin Ragan-Kelley +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from IPython.testing import tcommon +from IPython.testing.tcommon import * + +#------------------------------------------------------------------------------- +# Setup for inline and standalone doctests +#------------------------------------------------------------------------------- + + +# If you have standalone doctests in a separate file, set their names in the +# dt_files variable (as a single string or a list thereof). The mkPath call +# forms an absolute path based on the current file, it is not needed if you +# provide the full pahts. 
+dt_files = fullPath(__file__,[]) + + +# If you have any modules whose docstrings should be scanned for embedded tests +# as examples accorging to standard doctest practice, set them here (as a +# single string or a list thereof): +dt_modules = ['IPython.testing.tutils'] + +#------------------------------------------------------------------------------- +# Regular Unittests +#------------------------------------------------------------------------------- + +## class FooTestCase(unittest.TestCase): +## def test_foo(self): +## pass + +#------------------------------------------------------------------------------- +# Regular Unittests +#------------------------------------------------------------------------------- + +# This ensures that the code will run either standalone as a script, or that it +# can be picked up by Twisted's `trial` test wrapper to run all the tests. +if tcommon.pexpect is not None: + if __name__ == '__main__': + unittest.main(testLoader=IPDocTestLoader(dt_files,dt_modules)) + else: + testSuite = lambda : makeTestSuite(__name__,dt_files,dt_modules) diff --git a/IPython/testing/tstTEMPLATE_doctest.py b/IPython/testing/tstTEMPLATE_doctest.py new file mode 100644 index 0000000..a01b7f5 --- /dev/null +++ b/IPython/testing/tstTEMPLATE_doctest.py @@ -0,0 +1,16 @@ +"""Run this file with + + irunner --python filename + +to generate valid doctest input. 
+ +NOTE: make sure to ALWAYS have a blank line before comments, otherwise doctest +gets confused.""" + +#--------------------------------------------------------------------------- + +# Setup - all imports are done in tcommon +import tcommon; reload(tcommon) # for interactive use +from IPython.testing.tcommon import * + +# Doctest code begins here diff --git a/IPython/testing/tstTEMPLATE_doctest.txt b/IPython/testing/tstTEMPLATE_doctest.txt new file mode 100644 index 0000000..9f7d7e5 --- /dev/null +++ b/IPython/testing/tstTEMPLATE_doctest.txt @@ -0,0 +1,24 @@ +Doctests for the ``XXX`` module +===================================== + +The way doctest loads these, the entire document is applied as a single test +rather than multiple individual ones, unfortunately. + + +Auto-generated tests +-------------------- + +The tests below are generated from the companion file +test_toeplitz_doctest.py, which is run via IPython's irunner script to create +valid doctest input automatically. + +# Setup - all imports are done in tcommon +>>> from IPython.testing.tcommon import * + +# Rest of doctest goes here... + + +Manually generated tests +------------------------ + +These are one-off tests written by hand, copied from an interactive prompt. diff --git a/IPython/testing/tutils.py b/IPython/testing/tutils.py new file mode 100644 index 0000000..6bb14b4 --- /dev/null +++ b/IPython/testing/tutils.py @@ -0,0 +1,72 @@ +"""Utilities for testing code. +""" + +# Required modules and packages + +# Standard Python lib +import os +import sys + +# From this project +from IPython.tools import utils + +# path to our own installation, so we can find source files under this. 
+TEST_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Global flag, used by vprint +VERBOSE = '-v' in sys.argv or '--verbose' in sys.argv + +########################################################################## +# Code begins + +# Some utility functions +def vprint(*args): + """Print-like function which relies on a global VERBOSE flag.""" + if not VERBOSE: + return + + write = sys.stdout.write + for item in args: + write(str(item)) + write('\n') + sys.stdout.flush() + +def test_path(path): + """Return a path as a subdir of the test package. + + This finds the correct path of the test package on disk, and prepends it + to the input path.""" + + return os.path.join(TEST_PATH,path) + +def fullPath(startPath,files): + """Make full paths for all the listed files, based on startPath. + + Only the base part of startPath is kept, since this routine is typically + used with a script's __file__ variable as startPath. The base of startPath + is then prepended to all the listed files, forming the output list. + + :Parameters: + startPath : string + Initial path to use as the base for the results. This path is split + using os.path.split() and only its first component is kept. + + files : string or list + One or more files. + + :Examples: + + >>> fullPath('/foo/bar.py',['a.txt','b.txt']) + ['/foo/a.txt', '/foo/b.txt'] + + >>> fullPath('/foo',['a.txt','b.txt']) + ['/a.txt', '/b.txt'] + + If a single file is given, the output is still a list: + >>> fullPath('/foo','a.txt') + ['/a.txt'] + """ + + files = utils.list_strings(files) + base = os.path.split(startPath)[0] + return [ os.path.join(base,f) for f in files ] diff --git a/IPython/testing/util.py b/IPython/testing/util.py new file mode 100644 index 0000000..7036650 --- /dev/null +++ b/IPython/testing/util.py @@ -0,0 +1,64 @@ +# encoding: utf-8 +"""This file contains utility classes for performing tests with Deferreds. 
+""" +__docformat__ = "restructuredtext en" +#------------------------------------------------------------------------------- +# Copyright (C) 2005 Fernando Perez +# Brian E Granger +# Benjamin Ragan-Kelley +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +from twisted.trial import unittest +from twisted.internet import defer + +class DeferredTestCase(unittest.TestCase): + + def assertDeferredEquals(self, deferred, expectedResult, + chainDeferred=None): + """Calls assertEquals on the result of the deferred and expectedResult. + + chainDeferred can be used to pass in previous Deferred objects that + have tests being run on them. This chaining of Deferred's in tests + is needed to insure that all Deferred's are cleaned up at the end of + a test. + """ + + if chainDeferred is None: + chainDeferred = defer.succeed(None) + + def gotResult(actualResult): + self.assertEquals(actualResult, expectedResult) + + deferred.addCallback(gotResult) + + return chainDeferred.addCallback(lambda _: deferred) + + def assertDeferredRaises(self, deferred, expectedException, + chainDeferred=None): + """Calls assertRaises on the Failure of the deferred and expectedException. + + chainDeferred can be used to pass in previous Deferred objects that + have tests being run on them. This chaining of Deferred's in tests + is needed to insure that all Deferred's are cleaned up at the end of + a test. 
+ """ + + if chainDeferred is None: + chainDeferred = defer.succeed(None) + + def gotFailure(f): + #f.printTraceback() + self.assertRaises(expectedException, f.raiseException) + #return f + + deferred.addBoth(gotFailure) + + return chainDeferred.addCallback(lambda _: deferred) + diff --git a/IPython/tools/__init__.py b/IPython/tools/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/IPython/tools/__init__.py diff --git a/IPython/tools/growl.py b/IPython/tools/growl.py new file mode 100644 index 0000000..cc4613c --- /dev/null +++ b/IPython/tools/growl.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# encoding: utf-8 + + +class IPythonGrowlError(Exception): + pass + +class Notifier(object): + + def __init__(self, app_name): + try: + import Growl + except ImportError: + self.g_notifier = None + else: + self.g_notifier = Growl.GrowlNotifier(app_name, ['kernel', 'core']) + self.g_notifier.register() + + def _notify(self, title, msg): + if self.g_notifier is not None: + self.g_notifier.notify('kernel', title, msg) + + def notify(self, title, msg): + self._notify(title, msg) + + def notify_deferred(self, r, msg): + title = "Deferred Result" + msg = msg + '\n' + repr(r) + self._notify(title, msg) + return r + +_notifier = None + +def notify(title, msg): + pass + +def notify_deferred(r, msg): + return r + +def start(app_name): + global _notifier, notify, notify_deferred + if _notifier is not None: + raise IPythonGrowlError("this process is already registered with Growl") + else: + _notifier = Notifier(app_name) + notify = _notifier.notify + notify_deferred = _notifier.notify_deferred + + diff --git a/IPython/tools/tests/__init__.py b/IPython/tools/tests/__init__.py new file mode 100644 index 0000000..f751f68 --- /dev/null +++ b/IPython/tools/tests/__init__.py @@ -0,0 +1,10 @@ +# encoding: utf-8 +__docformat__ = "restructuredtext en" +#------------------------------------------------------------------------------- +# Copyright (C) 2005 Fernando Perez +# 
Brian E Granger +# Benjamin Ragan-Kelley +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- diff --git a/IPython/tools/tests/test_tools_utils.py b/IPython/tools/tests/test_tools_utils.py new file mode 100755 index 0000000..61ec558 --- /dev/null +++ b/IPython/tools/tests/test_tools_utils.py @@ -0,0 +1,36 @@ +#!/usr/bin/env python +"""Testing script for the tools.utils module. +""" + +# Module imports +from IPython.testing import tcommon +from IPython.testing.tcommon import * + +# If you have standalone doctests in a separate file, set their names in the +# dt_files variable (as a single string or a list thereof). The mkPath call +# forms an absolute path based on the current file, it is not needed if you +# provide the full pahts. +dt_files = fullPath(__file__,['tst_tools_utils_doctest.txt', + 'tst_tools_utils_doctest2.txt']) + +# If you have any modules whose docstrings should be scanned for embedded tests +# as examples accorging to standard doctest practice, set them here (as a +# single string or a list thereof): +dt_modules = 'IPython.tools.utils' + +########################################################################## +### Regular unittest test classes go here + +## class utilsTestCase(unittest.TestCase): +## def test_foo(self): +## pass + +########################################################################## +### Main +# This ensures that the code will run either standalone as a script, or that it +# can be picked up by Twisted's `trial` test wrapper to run all the tests. 
+if tcommon.pexpect is not None: + if __name__ == '__main__': + unittest.main(testLoader=IPDocTestLoader(dt_files,dt_modules)) + else: + testSuite = lambda : makeTestSuite(__name__,dt_files,dt_modules) diff --git a/IPython/tools/tests/tst_tools_utils_doctest.py b/IPython/tools/tests/tst_tools_utils_doctest.py new file mode 100644 index 0000000..bb0abe5 --- /dev/null +++ b/IPython/tools/tests/tst_tools_utils_doctest.py @@ -0,0 +1,12 @@ +# Setup - all imports are done in tcommon +from IPython.testing import tcommon +from IPython.testing.tcommon import * + +# Doctest code begins here +from IPython.tools import utils + +for i in range(10): + print i, + print i+1 + +print 'simple loop is over' diff --git a/IPython/tools/tests/tst_tools_utils_doctest.tpl.txt b/IPython/tools/tests/tst_tools_utils_doctest.tpl.txt new file mode 100644 index 0000000..929d9c2 --- /dev/null +++ b/IPython/tools/tests/tst_tools_utils_doctest.tpl.txt @@ -0,0 +1,18 @@ +========================================= + Doctests for the ``tools.utils`` module +========================================= + +The way doctest loads these, the entire document is applied as a single test +rather than multiple individual ones, unfortunately. + + +Auto-generated tests +==================== + +%run tst_tools_utils_doctest.py + + +Manually generated tests +======================== + +These are one-off tests written by hand, copied from an interactive prompt. diff --git a/IPython/tools/tests/tst_tools_utils_doctest.txt b/IPython/tools/tests/tst_tools_utils_doctest.txt new file mode 100644 index 0000000..8873850 --- /dev/null +++ b/IPython/tools/tests/tst_tools_utils_doctest.txt @@ -0,0 +1,42 @@ + +========================= + Auto-generated doctests +========================= + +This file was auto-generated by IPython in its entirety. If you need finer +control over the contents, simply make a manual template. See the +mkdoctests.py script for details. 
+ + +---------------------------------------------------------------------------- + +Begin included file tst_tools_utils_doctest.py:: + + # Setup - all imports are done in tcommon + >>> from IPython.testing import tcommon + >>> from IPython.testing.tcommon import * + + # Doctest code begins here + >>> from IPython.tools import utils + + >>> for i in range(10): + ... print i, + ... print i+1 + ... + 0 1 + 1 2 + 2 3 + 3 4 + 4 5 + 5 6 + 6 7 + 7 8 + 8 9 + 9 10 + >>> print 'simple loop is over' + simple loop is over + +End included file tst_tools_utils_doctest.py + +---------------------------------------------------------------------------- + diff --git a/IPython/tools/tests/tst_tools_utils_doctest2.py b/IPython/tools/tests/tst_tools_utils_doctest2.py new file mode 100644 index 0000000..4e6f3b2 --- /dev/null +++ b/IPython/tools/tests/tst_tools_utils_doctest2.py @@ -0,0 +1,13 @@ +# Setup - all imports are done in tcommon +from IPython.testing import tcommon +from IPython.testing.tcommon import * + +# Doctest code begins here +from IPython.tools import utils + +# Some other tests for utils + +utils.marquee('Testing marquee') + +utils.marquee('Another test',30,'.') + diff --git a/IPython/tools/tests/tst_tools_utils_doctest2.tpl.txt b/IPython/tools/tests/tst_tools_utils_doctest2.tpl.txt new file mode 100644 index 0000000..d09f4c4 --- /dev/null +++ b/IPython/tools/tests/tst_tools_utils_doctest2.tpl.txt @@ -0,0 +1,18 @@ +========================================= + Doctests for the ``tools.utils`` module +========================================= + +The way doctest loads these, the entire document is applied as a single test +rather than multiple individual ones, unfortunately. + + +Auto-generated tests +==================== + +%run tst_tools_utils_doctest2.py + + +Manually generated tests +======================== + +These are one-off tests written by hand, copied from an interactive prompt. 
diff --git a/IPython/tools/tests/tst_tools_utils_doctest2.txt b/IPython/tools/tests/tst_tools_utils_doctest2.txt new file mode 100644 index 0000000..1ccea08 --- /dev/null +++ b/IPython/tools/tests/tst_tools_utils_doctest2.txt @@ -0,0 +1,42 @@ +========================================= + Doctests for the ``tools.utils`` module +========================================= + +The way doctest loads these, the entire document is applied as a single test +rather than multiple individual ones, unfortunately. + + +Auto-generated tests +==================== + + +---------------------------------------------------------------------------- + +Begin included file tst_tools_utils_doctest2.py:: + + # Setup - all imports are done in tcommon + >>> from IPython.testing import tcommon + >>> from IPython.testing.tcommon import * + + # Doctest code begins here + >>> from IPython.tools import utils + + # Some other tests for utils + + >>> utils.marquee('Testing marquee') + '****************************** Testing marquee ******************************' + + >>> utils.marquee('Another test',30,'.') + '........ Another test ........' + + +End included file tst_tools_utils_doctest2.py + +---------------------------------------------------------------------------- + + + +Manually generated tests +======================== + +These are one-off tests written by hand, copied from an interactive prompt. diff --git a/IPython/tools/utils.py b/IPython/tools/utils.py new file mode 100644 index 0000000..f34cf5a --- /dev/null +++ b/IPython/tools/utils.py @@ -0,0 +1,128 @@ +# encoding: utf-8 +"""Generic utilities for use by IPython's various subsystems. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2006 Fernando Perez +# Brian E Granger +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. 
+#------------------------------------------------------------------------------- + +#--------------------------------------------------------------------------- +# Stdlib imports +#--------------------------------------------------------------------------- + +import os +import sys + +#--------------------------------------------------------------------------- +# Other IPython utilities +#--------------------------------------------------------------------------- + + +#--------------------------------------------------------------------------- +# Normal code begins +#--------------------------------------------------------------------------- + +def extractVars(*names,**kw): + """Extract a set of variables by name from another frame. + + :Parameters: + - `*names`: strings + One or more variable names which will be extracted from the caller's + frame. + + :Keywords: + - `depth`: integer (0) + How many frames in the stack to walk when looking for your variables. + + + Examples: + + In [2]: def func(x): + ...: y = 1 + ...: print extractVars('x','y') + ...: + + In [3]: func('hello') + {'y': 1, 'x': 'hello'} + """ + + depth = kw.get('depth',0) + + callerNS = sys._getframe(depth+1).f_locals + return dict((k,callerNS[k]) for k in names) + + +def extractVarsAbove(*names): + """Extract a set of variables by name from another frame. + + Similar to extractVars(), but with a specified depth of 1, so that names + are exctracted exactly from above the caller. 
+ + This is simply a convenience function so that the very common case (for us) + of skipping exactly 1 frame doesn't have to construct a special dict for + keyword passing.""" + + callerNS = sys._getframe(2).f_locals + return dict((k,callerNS[k]) for k in names) + +def shexp(s): + """Expand $VARS and ~names in a string, like a shell + + :Examples: + + In [2]: os.environ['FOO']='test' + + In [3]: shexp('variable FOO is $FOO') + Out[3]: 'variable FOO is test' + """ + return os.path.expandvars(os.path.expanduser(s)) + + +def list_strings(arg): + """Always return a list of strings, given a string or list of strings + as input. + + :Examples: + + In [7]: list_strings('A single string') + Out[7]: ['A single string'] + + In [8]: list_strings(['A single string in a list']) + Out[8]: ['A single string in a list'] + + In [9]: list_strings(['A','list','of','strings']) + Out[9]: ['A', 'list', 'of', 'strings'] + """ + + if isinstance(arg,basestring): return [arg] + else: return arg + +def marquee(txt='',width=78,mark='*'): + """Return the input string centered in a 'marquee'. + + :Examples: + + In [16]: marquee('A test',40) + Out[16]: '**************** A test ****************' + + In [17]: marquee('A test',40,'-') + Out[17]: '---------------- A test ----------------' + + In [18]: marquee('A test',40,' ') + Out[18]: ' A test ' + + """ + if not txt: + return (mark*width)[:width] + nmark = (width-len(txt)-2)/len(mark)/2 + if nmark < 0: nmark =0 + marks = mark*nmark + return '%s %s %s' % (marks,txt,marks) + + diff --git a/MANIFEST.in b/MANIFEST.in index 4ff6ecb..43955dd 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -8,6 +8,11 @@ graft setupext graft IPython/UserConfig +graft IPython/kernel +graft IPython/config +graft IPython/testing +graft IPython/tools + graft doc exclude doc/\#* exclude doc/*.1 diff --git a/setup.py b/setup.py index a98f793..5826432 100755 --- a/setup.py +++ b/setup.py @@ -6,12 +6,16 @@ Under Posix environments it works like a typical setup.py script. 
Under Windows, the command sdist is not supported, since IPython requires utilities which are not available under Windows.""" -#***************************************************************************** -# Copyright (C) 2001-2005 Fernando Perez +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team # # Distributed under the terms of the BSD License. The full license is in # the file COPYING, distributed as part of this software. -#***************************************************************************** +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- # Stdlib imports import os @@ -24,35 +28,24 @@ from glob import glob if os.path.exists('MANIFEST'): os.remove('MANIFEST') from distutils.core import setup -from setupext import install_data_ext # Local imports from IPython.genutils import target_update -# A few handy globals +from setupbase import ( + setup_args, + find_packages, + find_package_data, + find_scripts, + find_data_files, + check_for_dependencies +) + isfile = os.path.isfile -pjoin = os.path.join - -############################################################################## -# Utility functions -def oscmd(s): - print ">", s - os.system(s) - -# A little utility we'll need below, since glob() does NOT allow you to do -# exclusion on multiple endings! 
-def file_doesnt_endwith(test,endings): - """Return true if test is a file and its name does NOT end with any - of the strings listed in endings.""" - if not isfile(test): - return False - for e in endings: - if test.endswith(e): - return False - return True - -############################################################################### -# Main code begins + +#------------------------------------------------------------------------------- +# Handle OS specific things +#------------------------------------------------------------------------------- if os.name == 'posix': os_name = 'posix' @@ -69,18 +62,17 @@ if os_name == 'windows' and 'sdist' in sys.argv: print 'The sdist command is not available under Windows. Exiting.' sys.exit(1) +#------------------------------------------------------------------------------- +# Things related to the IPython documentation +#------------------------------------------------------------------------------- + # update the manuals when building a source dist if len(sys.argv) >= 2 and sys.argv[1] in ('sdist','bdist_rpm'): import textwrap # List of things to be updated. Each entry is a triplet of args for # target_update() - to_update = [ # The do_sphinx scripts builds html and pdf, so just one - # target is enough to cover all manual generation - ('doc/manual/ipython.pdf', - ['IPython/Release.py','doc/source/ipython.rst'], - "cd doc && python do_sphinx.py" ), - + to_update = [ # FIXME - Disabled for now: we need to redo an automatic way # of generating the magic info inside the rst. 
#('doc/magic.tex', @@ -96,91 +88,82 @@ if len(sys.argv) >= 2 and sys.argv[1] in ('sdist','bdist_rpm'): "cd doc && gzip -9c pycolor.1 > pycolor.1.gz"), ] + try: + import sphinx + except ImportError: + pass + else: + # The do_sphinx scripts builds html and pdf, so just one + # target is enough to cover all manual generation + to_update.append( + ('doc/manual/ipython.pdf', + ['IPython/Release.py','doc/source/ipython.rst'], + "cd doc && python do_sphinx.py") + ) [ target_update(*t) for t in to_update ] -# Release.py contains version, authors, license, url, keywords, etc. -execfile(pjoin('IPython','Release.py')) - -# I can't find how to make distutils create a nested dir. structure, so -# in the meantime do it manually. Butt ugly. -# Note that http://www.redbrick.dcu.ie/~noel/distutils.html, ex. 2/3, contain -# information on how to do this more cleanly once python 2.4 can be assumed. -# Thanks to Noel for the tip. -docdirbase = 'share/doc/ipython' -manpagebase = 'share/man/man1' - -# We only need to exclude from this things NOT already excluded in the -# MANIFEST.in file. -exclude = ('.sh','.1.gz') -docfiles = filter(lambda f:file_doesnt_endwith(f,exclude),glob('doc/*')) -examfiles = filter(isfile, glob('doc/examples/*.py')) -manfiles = filter(isfile, glob('doc/manual/*')) -manstatic = filter(isfile, glob('doc/manual/_static/*')) -manpages = filter(isfile, glob('doc/*.1.gz')) - -cfgfiles = filter(isfile, glob('IPython/UserConfig/*')) -scriptfiles = filter(isfile, ['scripts/ipython','scripts/pycolor', - 'scripts/irunner']) - -igridhelpfiles = filter(isfile, glob('IPython/Extensions/igrid_help.*')) - -# Script to be run by the windows binary installer after the default setup -# routine, to add shortcuts and similar windows-only things. Windows -# post-install scripts MUST reside in the scripts/ dir, otherwise distutils -# doesn't find them. 
-if 'bdist_wininst' in sys.argv: - if len(sys.argv) > 2 and ('sdist' in sys.argv or 'bdist_rpm' in sys.argv): - print >> sys.stderr,"ERROR: bdist_wininst must be run alone. Exiting." - sys.exit(1) - scriptfiles.append('scripts/ipython_win_post_install.py') - -datafiles = [('data', docdirbase, docfiles), - ('data', pjoin(docdirbase, 'examples'),examfiles), - ('data', pjoin(docdirbase, 'manual'),manfiles), - ('data', pjoin(docdirbase, 'manual/_static'),manstatic), - ('data', manpagebase, manpages), - ('data',pjoin(docdirbase, 'extensions'),igridhelpfiles), - ] +#--------------------------------------------------------------------------- +# Find all the packages, package data, scripts and data_files +#--------------------------------------------------------------------------- + +packages = find_packages() +package_data = find_package_data() +scripts = find_scripts() +data_files = find_data_files() + +#--------------------------------------------------------------------------- +# Handle dependencies and setuptools specific things +#--------------------------------------------------------------------------- + +# This dict is used for passing extra arguments that are setuptools +# specific to setup +setuptools_extra_args = {} if 'setuptools' in sys.modules: - # setuptools config for egg building - egg_extra_kwds = { - 'entry_points': { - 'console_scripts': [ + setuptools_extra_args['zip_safe'] = False + setuptools_extra_args['entry_points'] = { + 'console_scripts': [ 'ipython = IPython.ipapi:launch_new_instance', - 'pycolor = IPython.PyColorize:main' - ]} - } - scriptfiles = [] + 'pycolor = IPython.PyColorize:main', + 'ipcontroller = IPython.kernel.scripts.ipcontroller:main', + 'ipengine = IPython.kernel.scripts.ipengine:main', + 'ipcluster = IPython.kernel.scripts.ipcluster:main' + ] + } + setup_args["extras_require"] = dict( + kernel = [ + "zope.interface>=3.4.1", + "Twisted>=8.0.1", + "foolscap>=0.2.6" + ], + doc=['Sphinx>=0.3','pygments'], + test='nose>=0.10.1', + 
security=["pyOpenSSL>=0.6"] + ) + # Allow setuptools to handle the scripts + scripts = [] # eggs will lack docs, examples - datafiles = [] + data_files = [] else: - # Normal, non-setuptools install - egg_extra_kwds = {} # package_data of setuptools was introduced to distutils in 2.4 + cfgfiles = filter(isfile, glob('IPython/UserConfig/*')) if sys.version_info < (2,4): - datafiles.append(('lib', 'IPython/UserConfig', cfgfiles)) - -# Call the setup() routine which does most of the work -setup(name = name, - version = version, - description = description, - long_description = long_description, - author = authors['Fernando'][0], - author_email = authors['Fernando'][1], - url = url, - download_url = download_url, - license = license, - platforms = platforms, - keywords = keywords, - packages = ['IPython', 'IPython.Extensions', 'IPython.external', - 'IPython.gui', 'IPython.gui.wx', - 'IPython.UserConfig'], - scripts = scriptfiles, - package_data = {'IPython.UserConfig' : ['*'] }, - - cmdclass = {'install_data': install_data_ext}, - data_files = datafiles, - # extra params needed for eggs - **egg_extra_kwds - ) + data_files.append(('lib', 'IPython/UserConfig', cfgfiles)) + # If we are running without setuptools, call this function which will + # check for dependencies an inform the user what is needed. This is + # just to make life easy for users. 
+ check_for_dependencies() + + +#--------------------------------------------------------------------------- +# Do the actual setup now +#--------------------------------------------------------------------------- + +setup_args['packages'] = packages +setup_args['package_data'] = package_data +setup_args['scripts'] = scripts +setup_args['data_files'] = data_files +setup_args.update(setuptools_extra_args) + +if __name__ == '__main__': + setup(**setup_args) diff --git a/setupbase.py b/setupbase.py new file mode 100644 index 0000000..bb90de9 --- /dev/null +++ b/setupbase.py @@ -0,0 +1,230 @@ +# encoding: utf-8 + +""" +This module defines the things that are used in setup.py for building IPython + +This includes: + + * The basic arguments to setup + * Functions for finding things like packages, package data, etc. + * A function for checking dependencies. +""" + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import os, sys + +from glob import glob + +from setupext import install_data_ext + +#------------------------------------------------------------------------------- +# Useful globals and utility functions +#------------------------------------------------------------------------------- + +# A few handy globals +isfile = os.path.isfile +pjoin = os.path.join + +def oscmd(s): + print ">", s + os.system(s) + +# A little utility we'll need below, since glob() does NOT allow you to do +# exclusion on multiple endings! 
def file_doesnt_endwith(test,endings):
    """Return true if test is a file and its name does NOT end with any
    of the strings listed in endings."""
    if not isfile(test):
        return False
    for e in endings:
        if test.endswith(e):
            return False
    return True

#---------------------------------------------------------------------------
# Basic project information
#---------------------------------------------------------------------------

# Release.py contains version, authors, license, url, keywords, etc.
execfile(pjoin('IPython','Release.py'))

# Create a dict with the basic information
# This dict is eventually passed to setup after additional keys are added.
setup_args = dict(
      name             = name,
      version          = version,
      description      = description,
      long_description = long_description,
      author           = author,
      author_email     = author_email,
      url              = url,
      download_url     = download_url,
      license          = license,
      platforms        = platforms,
      keywords         = keywords,
      cmdclass         = {'install_data': install_data_ext},
      )


#---------------------------------------------------------------------------
# Find packages
#---------------------------------------------------------------------------

def add_package(packages, pname, config=False, tests=False, scripts=False, others=None):
    """Append 'IPython.<pname>' to packages, plus selected subpackages.

    :Parameters:
        packages : list
            The list of dotted package names, mutated in place.
        pname : str
            The subpackage of IPython to add.
        config, tests, scripts : bool
            When true, also add the corresponding standard subpackage.
        others : list or None
            Extra subpackage names of pname to add verbatim.
    """
    packages.append('.'.join(['IPython',pname]))
    if config:
        packages.append('.'.join(['IPython',pname,'config']))
    if tests:
        packages.append('.'.join(['IPython',pname,'tests']))
    if scripts:
        packages.append('.'.join(['IPython',pname,'scripts']))
    if others is not None:
        for o in others:
            packages.append('.'.join(['IPython',pname,o]))

def find_packages():
    """
    Find all of IPython's packages.
    """
    packages = ['IPython']
    add_package(packages, 'config', tests=True)
    add_package(packages, 'Extensions')
    add_package(packages, 'external')
    add_package(packages, 'gui')
    add_package(packages, 'gui.wx')
    add_package(packages, 'kernel', config=True, tests=True, scripts=True)
    add_package(packages, 'kernel.core', config=True, tests=True)
    add_package(packages, 'testing', tests=True)
    add_package(packages, 'tools', tests=True)
    add_package(packages, 'UserConfig')
    return packages

#---------------------------------------------------------------------------
# Find package data
#---------------------------------------------------------------------------

def find_package_data():
    """
    Find IPython's package_data.
    """
    # This is not enough for these things to appear in an sdist.
    # We need to muck with the MANIFEST to get this to work
    package_data = {'IPython.UserConfig' : ['*'] }
    return package_data


#---------------------------------------------------------------------------
# Find data files
#---------------------------------------------------------------------------

def find_data_files():
    """
    Find IPython's data_files.

    Returns a list of ('data', target_dir, [files...]) tuples suitable for
    the data_files argument of setup().
    """

    # I can't find how to make distutils create a nested dir. structure, so
    # in the meantime do it manually. Butt ugly.
    # Note that http://www.redbrick.dcu.ie/~noel/distutils.html, ex. 2/3, contain
    # information on how to do this more cleanly once python 2.4 can be assumed.
    # Thanks to Noel for the tip.
    docdirbase  = 'share/doc/ipython'
    manpagebase = 'share/man/man1'

    # We only need to exclude from this things NOT already excluded in the
    # MANIFEST.in file.
    exclude = ('.sh','.1.gz')
    docfiles = filter(lambda f:file_doesnt_endwith(f,exclude),glob('doc/*'))
    examfiles = filter(isfile, glob('doc/examples/*.py'))
    manfiles = filter(isfile, glob('doc/manual/*'))
    manstatic = filter(isfile, glob('doc/manual/_static/*'))
    manpages = filter(isfile, glob('doc/*.1.gz'))
    scriptfiles = filter(isfile, ['scripts/ipython','scripts/pycolor',
                                  'scripts/irunner'])
    igridhelpfiles = filter(isfile, glob('IPython/Extensions/igrid_help.*'))

    data_files = [('data', docdirbase, docfiles),
                  ('data', pjoin(docdirbase, 'examples'),examfiles),
                  ('data', pjoin(docdirbase, 'manual'),manfiles),
                  ('data', pjoin(docdirbase, 'manual/_static'),manstatic),
                  ('data', manpagebase, manpages),
                  ('data',pjoin(docdirbase, 'extensions'),igridhelpfiles),
                  ]
    return data_files

#---------------------------------------------------------------------------
# Find scripts
#---------------------------------------------------------------------------

def find_scripts():
    """
    Find IPython's scripts.
    """
    scripts = []
    # NOTE: the package directory is 'IPython' (capital I).  The previous
    # lowercase 'ipython/kernel/...' paths would not exist on case-sensitive
    # filesystems, so the kernel scripts were silently skipped there.
    scripts.append('IPython/kernel/scripts/ipengine')
    scripts.append('IPython/kernel/scripts/ipcontroller')
    scripts.append('IPython/kernel/scripts/ipcluster')
    scripts.append('scripts/ipython')
    scripts.append('scripts/pycolor')
    scripts.append('scripts/irunner')

    # Script to be run by the windows binary installer after the default setup
    # routine, to add shortcuts and similar windows-only things.  Windows
    # post-install scripts MUST reside in the scripts/ dir, otherwise distutils
    # doesn't find them.
    if 'bdist_wininst' in sys.argv:
        if len(sys.argv) > 2 and ('sdist' in sys.argv or 'bdist_rpm' in sys.argv):
            print >> sys.stderr,"ERROR: bdist_wininst must be run alone. Exiting."
            sys.exit(1)
        scripts.append('scripts/ipython_win_post_install.py')

    return scripts

#---------------------------------------------------------------------------
# Check for dependencies
#---------------------------------------------------------------------------

def check_for_dependencies():
    """Check for IPython's dependencies.

    This function should NOT be called if running under setuptools!
    """
    from setupext.setupext import (
        print_line, print_raw, print_status, print_message,
        check_for_zopeinterface, check_for_twisted,
        check_for_foolscap, check_for_pyopenssl,
        check_for_sphinx, check_for_pygments,
        check_for_nose, check_for_pexpect
    )
    print_line()
    print_raw("BUILDING IPYTHON")
    print_status('python', sys.version)
    print_status('platform', sys.platform)
    if sys.platform == 'win32':
        print_status('Windows version', sys.getwindowsversion())

    print_raw("")
    print_raw("OPTIONAL DEPENDENCIES")

    check_for_zopeinterface()
    check_for_twisted()
    check_for_foolscap()
    check_for_pyopenssl()
    check_for_sphinx()
    check_for_pygments()
    check_for_nose()
    check_for_pexpect()
encoding: utf-8 + +__docformat__ = "restructuredtext en" + +#------------------------------------------------------------------------------- +# Copyright (C) 2008 The IPython Development Team +# +# Distributed under the terms of the BSD License. The full license is in +# the file COPYING, distributed as part of this software. +#------------------------------------------------------------------------------- + +#------------------------------------------------------------------------------- +# Imports +#------------------------------------------------------------------------------- + +import sys, os +from textwrap import fill + +display_status=True + +if display_status: + def print_line(char='='): + print char * 76 + + def print_status(package, status): + initial_indent = "%22s: " % package + indent = ' ' * 24 + print fill(str(status), width=76, + initial_indent=initial_indent, + subsequent_indent=indent) + + def print_message(message): + indent = ' ' * 24 + "* " + print fill(str(message), width=76, + initial_indent=indent, + subsequent_indent=indent) + + def print_raw(section): + print section +else: + def print_line(*args, **kwargs): + pass + print_status = print_message = print_raw = print_line + +#------------------------------------------------------------------------------- +# Tests for specific packages +#------------------------------------------------------------------------------- + +def check_for_ipython(): + try: + import IPython + except ImportError: + print_status("IPython", "Not found") + return False + else: + print_status("IPython", IPython.__version__) + return True + +def check_for_zopeinterface(): + try: + import zope.interface + except ImportError: + print_status("zope.Interface", "Not found (required for parallel computing capabilities)") + return False + else: + print_status("Zope.Interface","yes") + return True + +def check_for_twisted(): + try: + import twisted + except ImportError: + print_status("Twisted", "Not found (required for parallel 
computing capabilities)") + return False + else: + major = twisted.version.major + minor = twisted.version.minor + micro = twisted.version.micro + print_status("Twisted", twisted.version.short()) + if not ((major==2 and minor>=5 and micro>=0) or \ + major>=8): + print_message("WARNING: IPython requires Twisted 2.5.0 or greater, you have version %s"%twisted.version.short()) + print_message("Twisted is required for parallel computing capabilities") + return False + else: + return True + +def check_for_foolscap(): + try: + import foolscap + except ImportError: + print_status('Foolscap', "Not found (required for parallel computing capabilities)") + return False + else: + print_status('Foolscap', foolscap.__version__) + return True + +def check_for_pyopenssl(): + try: + import OpenSSL + except ImportError: + print_status('OpenSSL', "Not found (required if you want security in the parallel computing capabilities)") + return False + else: + print_status('OpenSSL', OpenSSL.__version__) + return True + +def check_for_sphinx(): + try: + import sphinx + except ImportError: + print_status('sphinx', "Not found (required for building documentation)") + return False + else: + print_status('sphinx', sphinx.__version__) + return True + +def check_for_pygments(): + try: + import pygments + except ImportError: + print_status('pygments', "Not found (required for syntax highlighting documentation)") + return False + else: + print_status('pygments', pygments.__version__) + return True + +def check_for_nose(): + try: + import nose + except ImportError: + print_status('nose', "Not found (required for running the test suite)") + return False + else: + print_status('nose', nose.__version__) + return True + +def check_for_pexpect(): + try: + import pexpect + except ImportError: + print_status("pexpect", "no (required for running standalone doctests)") + return False + else: + print_status("pexpect", pexpect.__version__) + return True + +def check_for_httplib2(): + try: + import httplib2 + 
except ImportError: + print_status("httplib2", "no (required for blocking http clients)") + return False + else: + print_status("httplib2","yes") + return True + +def check_for_sqlalchemy(): + try: + import sqlalchemy + except ImportError: + print_status("sqlalchemy", "no (required for the ipython1 notebook)") + return False + else: + print_status("sqlalchemy","yes") + return True + +def check_for_simplejson(): + try: + import simplejson + except ImportError: + print_status("simplejson", "no (required for the ipython1 notebook)") + return False + else: + print_status("simplejson","yes") + return True + + \ No newline at end of file