##// END OF EJS Templates
Merge pull request #2462 from takluyver/extensions-loaded...
Bradley M. Froehle -
r8658:ddeb9bb3 merge
parent child Browse files
Show More
@@ -0,0 +1,73
1 import os.path
2
3 import nose.tools as nt
4
5 import IPython.testing.tools as tt
6 from IPython.utils.syspathcontext import prepended_to_syspath
7 from IPython.utils.tempdir import TemporaryDirectory
8
9 ext1_content = """
10 def load_ipython_extension(ip):
11 print("Running ext1 load")
12
13 def unload_ipython_extension(ip):
14 print("Running ext1 unload")
15 """
16
17 ext2_content = """
18 def load_ipython_extension(ip):
19 print("Running ext2 load")
20 """
21
22 def test_extension_loading():
23 em = get_ipython().extension_manager
24 with TemporaryDirectory() as td:
25 ext1 = os.path.join(td, 'ext1.py')
26 with open(ext1, 'w') as f:
27 f.write(ext1_content)
28
29 ext2 = os.path.join(td, 'ext2.py')
30 with open(ext2, 'w') as f:
31 f.write(ext2_content)
32
33 with prepended_to_syspath(td):
34 assert 'ext1' not in em.loaded
35 assert 'ext2' not in em.loaded
36
37 # Load extension
38 with tt.AssertPrints("Running ext1 load"):
39 assert em.load_extension('ext1') is None
40 assert 'ext1' in em.loaded
41
42 # Should refuse to load it again
43 with tt.AssertNotPrints("Running ext1 load"):
44 assert em.load_extension('ext1') == 'already loaded'
45
46 # Reload
47 with tt.AssertPrints("Running ext1 unload"):
48 with tt.AssertPrints("Running ext1 load", suppress=False):
49 em.reload_extension('ext1')
50
51 # Unload
52 with tt.AssertPrints("Running ext1 unload"):
53 assert em.unload_extension('ext1') is None
54
55 # Can't unload again
56 with tt.AssertNotPrints("Running ext1 unload"):
57 assert em.unload_extension('ext1') == 'not loaded'
58 assert em.unload_extension('ext2') == 'not loaded'
59
60 # Load extension 2
61 with tt.AssertPrints("Running ext2 load"):
62 assert em.load_extension('ext2') is None
63
64 # Can't unload this
65 assert em.unload_extension('ext2') == 'no unload function'
66
67 # But can reload it
68 with tt.AssertPrints("Running ext2 load"):
69 em.reload_extension('ext2')
70
71 def test_non_extension():
72 em = get_ipython().extension_manager
73 nt.assert_equal(em.load_extension('sys'), "no load function")
@@ -1,157 +1,184
1 # encoding: utf-8
1 # encoding: utf-8
2 """A class for managing IPython extensions.
2 """A class for managing IPython extensions.
3
3
4 Authors:
4 Authors:
5
5
6 * Brian Granger
6 * Brian Granger
7 """
7 """
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2010-2011 The IPython Development Team
10 # Copyright (C) 2010-2011 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 import os
20 import os
21 from shutil import copyfile
21 from shutil import copyfile
22 import sys
22 import sys
23 from urllib import urlretrieve
23 from urllib import urlretrieve
24 from urlparse import urlparse
24 from urlparse import urlparse
25
25
26 from IPython.core.error import UsageError
26 from IPython.config.configurable import Configurable
27 from IPython.config.configurable import Configurable
27 from IPython.utils.traitlets import Instance
28 from IPython.utils.traitlets import Instance
29 from IPython.utils.py3compat import PY3
30 if PY3:
31 from imp import reload
28
32
29 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
30 # Main class
34 # Main class
31 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
32
36
33 class ExtensionManager(Configurable):
37 class ExtensionManager(Configurable):
34 """A class to manage IPython extensions.
38 """A class to manage IPython extensions.
35
39
36 An IPython extension is an importable Python module that has
40 An IPython extension is an importable Python module that has
37 a function with the signature::
41 a function with the signature::
38
42
39 def load_ipython_extension(ipython):
43 def load_ipython_extension(ipython):
40 # Do things with ipython
44 # Do things with ipython
41
45
42 This function is called after your extension is imported and the
46 This function is called after your extension is imported and the
43 currently active :class:`InteractiveShell` instance is passed as
47 currently active :class:`InteractiveShell` instance is passed as
44 the only argument. You can do anything you want with IPython at
48 the only argument. You can do anything you want with IPython at
45 that point, including defining new magic and aliases, adding new
49 that point, including defining new magic and aliases, adding new
46 components, etc.
50 components, etc.
47
51
48 The :func:`load_ipython_extension` will be called again is you
52 You can also optionaly define an :func:`unload_ipython_extension(ipython)`
49 load or reload the extension again. It is up to the extension
53 function, which will be called if the user unloads or reloads the extension.
50 author to add code to manage that.
54 The extension manager will only call :func:`load_ipython_extension` again
55 if the extension is reloaded.
51
56
52 You can put your extension modules anywhere you want, as long as
57 You can put your extension modules anywhere you want, as long as
53 they can be imported by Python's standard import mechanism. However,
58 they can be imported by Python's standard import mechanism. However,
54 to make it easy to write extensions, you can also put your extensions
59 to make it easy to write extensions, you can also put your extensions
55 in ``os.path.join(self.ipython_dir, 'extensions')``. This directory
60 in ``os.path.join(self.ipython_dir, 'extensions')``. This directory
56 is added to ``sys.path`` automatically.
61 is added to ``sys.path`` automatically.
57 """
62 """
58
63
59 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
64 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
60
65
61 def __init__(self, shell=None, config=None):
66 def __init__(self, shell=None, config=None):
62 super(ExtensionManager, self).__init__(shell=shell, config=config)
67 super(ExtensionManager, self).__init__(shell=shell, config=config)
63 self.shell.on_trait_change(
68 self.shell.on_trait_change(
64 self._on_ipython_dir_changed, 'ipython_dir'
69 self._on_ipython_dir_changed, 'ipython_dir'
65 )
70 )
71 self.loaded = set()
66
72
67 def __del__(self):
73 def __del__(self):
68 self.shell.on_trait_change(
74 self.shell.on_trait_change(
69 self._on_ipython_dir_changed, 'ipython_dir', remove=True
75 self._on_ipython_dir_changed, 'ipython_dir', remove=True
70 )
76 )
71
77
72 @property
78 @property
73 def ipython_extension_dir(self):
79 def ipython_extension_dir(self):
74 return os.path.join(self.shell.ipython_dir, u'extensions')
80 return os.path.join(self.shell.ipython_dir, u'extensions')
75
81
76 def _on_ipython_dir_changed(self):
82 def _on_ipython_dir_changed(self):
77 if not os.path.isdir(self.ipython_extension_dir):
83 if not os.path.isdir(self.ipython_extension_dir):
78 os.makedirs(self.ipython_extension_dir, mode = 0o777)
84 os.makedirs(self.ipython_extension_dir, mode = 0o777)
79
85
80 def load_extension(self, module_str):
86 def load_extension(self, module_str):
81 """Load an IPython extension by its module name.
87 """Load an IPython extension by its module name.
82
88
83 If :func:`load_ipython_extension` returns anything, this function
89 Returns the string "already loaded" if the extension is already loaded,
84 will return that object.
90 "no load function" if the module doesn't have a load_ipython_extension
91 function, or None if it succeeded.
85 """
92 """
93 if module_str in self.loaded:
94 return "already loaded"
95
86 from IPython.utils.syspathcontext import prepended_to_syspath
96 from IPython.utils.syspathcontext import prepended_to_syspath
87
97
88 if module_str not in sys.modules:
98 if module_str not in sys.modules:
89 with prepended_to_syspath(self.ipython_extension_dir):
99 with prepended_to_syspath(self.ipython_extension_dir):
90 __import__(module_str)
100 __import__(module_str)
91 mod = sys.modules[module_str]
101 mod = sys.modules[module_str]
92 return self._call_load_ipython_extension(mod)
102 if self._call_load_ipython_extension(mod):
103 self.loaded.add(module_str)
104 else:
105 return "no load function"
93
106
94 def unload_extension(self, module_str):
107 def unload_extension(self, module_str):
95 """Unload an IPython extension by its module name.
108 """Unload an IPython extension by its module name.
96
109
97 This function looks up the extension's name in ``sys.modules`` and
110 This function looks up the extension's name in ``sys.modules`` and
98 simply calls ``mod.unload_ipython_extension(self)``.
111 simply calls ``mod.unload_ipython_extension(self)``.
112
113 Returns the string "no unload function" if the extension doesn't define
114 a function to unload itself, "not loaded" if the extension isn't loaded,
115 otherwise None.
99 """
116 """
117 if module_str not in self.loaded:
118 return "not loaded"
119
100 if module_str in sys.modules:
120 if module_str in sys.modules:
101 mod = sys.modules[module_str]
121 mod = sys.modules[module_str]
102 self._call_unload_ipython_extension(mod)
122 if self._call_unload_ipython_extension(mod):
123 self.loaded.discard(module_str)
124 else:
125 return "no unload function"
103
126
104 def reload_extension(self, module_str):
127 def reload_extension(self, module_str):
105 """Reload an IPython extension by calling reload.
128 """Reload an IPython extension by calling reload.
106
129
107 If the module has not been loaded before,
130 If the module has not been loaded before,
108 :meth:`InteractiveShell.load_extension` is called. Otherwise
131 :meth:`InteractiveShell.load_extension` is called. Otherwise
109 :func:`reload` is called and then the :func:`load_ipython_extension`
132 :func:`reload` is called and then the :func:`load_ipython_extension`
110 function of the module, if it exists is called.
133 function of the module, if it exists is called.
111 """
134 """
112 from IPython.utils.syspathcontext import prepended_to_syspath
135 from IPython.utils.syspathcontext import prepended_to_syspath
113
136
114 with prepended_to_syspath(self.ipython_extension_dir):
137 if (module_str in self.loaded) and (module_str in sys.modules):
115 if module_str in sys.modules:
138 self.unload_extension(module_str)
116 mod = sys.modules[module_str]
139 mod = sys.modules[module_str]
140 with prepended_to_syspath(self.ipython_extension_dir):
117 reload(mod)
141 reload(mod)
118 self._call_load_ipython_extension(mod)
142 if self._call_load_ipython_extension(mod):
119 else:
143 self.loaded.add(module_str)
120 self.load_extension(module_str)
144 else:
145 self.load_extension(module_str)
121
146
122 def _call_load_ipython_extension(self, mod):
147 def _call_load_ipython_extension(self, mod):
123 if hasattr(mod, 'load_ipython_extension'):
148 if hasattr(mod, 'load_ipython_extension'):
124 return mod.load_ipython_extension(self.shell)
149 mod.load_ipython_extension(self.shell)
150 return True
125
151
126 def _call_unload_ipython_extension(self, mod):
152 def _call_unload_ipython_extension(self, mod):
127 if hasattr(mod, 'unload_ipython_extension'):
153 if hasattr(mod, 'unload_ipython_extension'):
128 return mod.unload_ipython_extension(self.shell)
154 mod.unload_ipython_extension(self.shell)
155 return True
129
156
130 def install_extension(self, url, filename=None):
157 def install_extension(self, url, filename=None):
131 """Download and install an IPython extension.
158 """Download and install an IPython extension.
132
159
133 If filename is given, the file will be so named (inside the extension
160 If filename is given, the file will be so named (inside the extension
134 directory). Otherwise, the name from the URL will be used. The file must
161 directory). Otherwise, the name from the URL will be used. The file must
135 have a .py or .zip extension; otherwise, a ValueError will be raised.
162 have a .py or .zip extension; otherwise, a ValueError will be raised.
136
163
137 Returns the full path to the installed file.
164 Returns the full path to the installed file.
138 """
165 """
139 # Ensure the extension directory exists
166 # Ensure the extension directory exists
140 if not os.path.isdir(self.ipython_extension_dir):
167 if not os.path.isdir(self.ipython_extension_dir):
141 os.makedirs(self.ipython_extension_dir, mode = 0o777)
168 os.makedirs(self.ipython_extension_dir, mode = 0o777)
142
169
143 if os.path.isfile(url):
170 if os.path.isfile(url):
144 src_filename = os.path.basename(url)
171 src_filename = os.path.basename(url)
145 copy = copyfile
172 copy = copyfile
146 else:
173 else:
147 src_filename = urlparse(url).path.split('/')[-1]
174 src_filename = urlparse(url).path.split('/')[-1]
148 copy = urlretrieve
175 copy = urlretrieve
149
176
150 if filename is None:
177 if filename is None:
151 filename = src_filename
178 filename = src_filename
152 if os.path.splitext(filename)[1] not in ('.py', '.zip'):
179 if os.path.splitext(filename)[1] not in ('.py', '.zip'):
153 raise ValueError("The file must have a .py or .zip extension", filename)
180 raise ValueError("The file must have a .py or .zip extension", filename)
154
181
155 filename = os.path.join(self.ipython_extension_dir, filename)
182 filename = os.path.join(self.ipython_extension_dir, filename)
156 copy(url, filename)
183 copy(url, filename)
157 return filename
184 return filename
@@ -1,76 +1,92
1 """Implementation of magic functions for the extension machinery.
1 """Implementation of magic functions for the extension machinery.
2 """
2 """
3 #-----------------------------------------------------------------------------
3 #-----------------------------------------------------------------------------
4 # Copyright (c) 2012 The IPython Development Team.
4 # Copyright (c) 2012 The IPython Development Team.
5 #
5 #
6 # Distributed under the terms of the Modified BSD License.
6 # Distributed under the terms of the Modified BSD License.
7 #
7 #
8 # The full license is in the file COPYING.txt, distributed with this software.
8 # The full license is in the file COPYING.txt, distributed with this software.
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10
10
11 #-----------------------------------------------------------------------------
11 #-----------------------------------------------------------------------------
12 # Imports
12 # Imports
13 #-----------------------------------------------------------------------------
13 #-----------------------------------------------------------------------------
14
14
15 # Stdlib
15 # Stdlib
16 import os
16 import os
17
17
18 # Our own packages
18 # Our own packages
19 from IPython.core.error import UsageError
19 from IPython.core.error import UsageError
20 from IPython.core.magic import Magics, magics_class, line_magic
20 from IPython.core.magic import Magics, magics_class, line_magic
21
21
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23 # Magic implementation classes
23 # Magic implementation classes
24 #-----------------------------------------------------------------------------
24 #-----------------------------------------------------------------------------
25
25
26 @magics_class
26 @magics_class
27 class ExtensionMagics(Magics):
27 class ExtensionMagics(Magics):
28 """Magics to manage the IPython extensions system."""
28 """Magics to manage the IPython extensions system."""
29
29
30 @line_magic
30 @line_magic
31 def install_ext(self, parameter_s=''):
31 def install_ext(self, parameter_s=''):
32 """Download and install an extension from a URL, e.g.::
32 """Download and install an extension from a URL, e.g.::
33
33
34 %install_ext https://bitbucket.org/birkenfeld/ipython-physics/raw/d1310a2ab15d/physics.py
34 %install_ext https://bitbucket.org/birkenfeld/ipython-physics/raw/d1310a2ab15d/physics.py
35
35
36 The URL should point to an importable Python module - either a .py file
36 The URL should point to an importable Python module - either a .py file
37 or a .zip file.
37 or a .zip file.
38
38
39 Parameters:
39 Parameters:
40
40
41 -n filename : Specify a name for the file, rather than taking it from
41 -n filename : Specify a name for the file, rather than taking it from
42 the URL.
42 the URL.
43 """
43 """
44 opts, args = self.parse_options(parameter_s, 'n:')
44 opts, args = self.parse_options(parameter_s, 'n:')
45 try:
45 try:
46 filename = self.shell.extension_manager.install_extension(args,
46 filename = self.shell.extension_manager.install_extension(args,
47 opts.get('n'))
47 opts.get('n'))
48 except ValueError as e:
48 except ValueError as e:
49 print e
49 print e
50 return
50 return
51
51
52 filename = os.path.basename(filename)
52 filename = os.path.basename(filename)
53 print "Installed %s. To use it, type:" % filename
53 print "Installed %s. To use it, type:" % filename
54 print " %%load_ext %s" % os.path.splitext(filename)[0]
54 print " %%load_ext %s" % os.path.splitext(filename)[0]
55
55
56
56
57 @line_magic
57 @line_magic
58 def load_ext(self, module_str):
58 def load_ext(self, module_str):
59 """Load an IPython extension by its module name."""
59 """Load an IPython extension by its module name."""
60 if not module_str:
60 if not module_str:
61 raise UsageError('Missing module name.')
61 raise UsageError('Missing module name.')
62 return self.shell.extension_manager.load_extension(module_str)
62 res = self.shell.extension_manager.load_extension(module_str)
63
64 if res == 'already loaded':
65 print "The %s extension is already loaded. To reload it, use:" % module_str
66 print " %reload_ext", module_str
67 elif res == 'no load function':
68 print "The %s module is not an IPython extension." % module_str
63
69
64 @line_magic
70 @line_magic
65 def unload_ext(self, module_str):
71 def unload_ext(self, module_str):
66 """Unload an IPython extension by its module name."""
72 """Unload an IPython extension by its module name.
73
74 Not all extensions can be unloaded, only those which define an
75 ``unload_ipython_extension`` function.
76 """
67 if not module_str:
77 if not module_str:
68 raise UsageError('Missing module name.')
78 raise UsageError('Missing module name.')
69 self.shell.extension_manager.unload_extension(module_str)
79
80 res = self.shell.extension_manager.unload_extension(module_str)
81
82 if res == 'no unload function':
83 print "The %s extension doesn't define how to unload it." % module_str
84 elif res == "not loaded":
85 print "The %s extension is not loaded." % module_str
70
86
71 @line_magic
87 @line_magic
72 def reload_ext(self, module_str):
88 def reload_ext(self, module_str):
73 """Reload an IPython extension by its module name."""
89 """Reload an IPython extension by its module name."""
74 if not module_str:
90 if not module_str:
75 raise UsageError('Missing module name.')
91 raise UsageError('Missing module name.')
76 self.shell.extension_manager.reload_extension(module_str)
92 self.shell.extension_manager.reload_extension(module_str)
@@ -1,527 +1,521
1 """IPython extension to reload modules before executing user code.
1 """IPython extension to reload modules before executing user code.
2
2
3 ``autoreload`` reloads modules automatically before entering the execution of
3 ``autoreload`` reloads modules automatically before entering the execution of
4 code typed at the IPython prompt.
4 code typed at the IPython prompt.
5
5
6 This makes for example the following workflow possible:
6 This makes for example the following workflow possible:
7
7
8 .. sourcecode:: ipython
8 .. sourcecode:: ipython
9
9
10 In [1]: %load_ext autoreload
10 In [1]: %load_ext autoreload
11
11
12 In [2]: %autoreload 2
12 In [2]: %autoreload 2
13
13
14 In [3]: from foo import some_function
14 In [3]: from foo import some_function
15
15
16 In [4]: some_function()
16 In [4]: some_function()
17 Out[4]: 42
17 Out[4]: 42
18
18
19 In [5]: # open foo.py in an editor and change some_function to return 43
19 In [5]: # open foo.py in an editor and change some_function to return 43
20
20
21 In [6]: some_function()
21 In [6]: some_function()
22 Out[6]: 43
22 Out[6]: 43
23
23
24 The module was reloaded without reloading it explicitly, and the object
24 The module was reloaded without reloading it explicitly, and the object
25 imported with ``from foo import ...`` was also updated.
25 imported with ``from foo import ...`` was also updated.
26
26
27 Usage
27 Usage
28 =====
28 =====
29
29
30 The following magic commands are provided:
30 The following magic commands are provided:
31
31
32 ``%autoreload``
32 ``%autoreload``
33
33
34 Reload all modules (except those excluded by ``%aimport``)
34 Reload all modules (except those excluded by ``%aimport``)
35 automatically now.
35 automatically now.
36
36
37 ``%autoreload 0``
37 ``%autoreload 0``
38
38
39 Disable automatic reloading.
39 Disable automatic reloading.
40
40
41 ``%autoreload 1``
41 ``%autoreload 1``
42
42
43 Reload all modules imported with ``%aimport`` every time before
43 Reload all modules imported with ``%aimport`` every time before
44 executing the Python code typed.
44 executing the Python code typed.
45
45
46 ``%autoreload 2``
46 ``%autoreload 2``
47
47
48 Reload all modules (except those excluded by ``%aimport``) every
48 Reload all modules (except those excluded by ``%aimport``) every
49 time before executing the Python code typed.
49 time before executing the Python code typed.
50
50
51 ``%aimport``
51 ``%aimport``
52
52
53 List modules which are to be automatically imported or not to be imported.
53 List modules which are to be automatically imported or not to be imported.
54
54
55 ``%aimport foo``
55 ``%aimport foo``
56
56
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
57 Import module 'foo' and mark it to be autoreloaded for ``%autoreload 1``
58
58
59 ``%aimport -foo``
59 ``%aimport -foo``
60
60
61 Mark module 'foo' to not be autoreloaded.
61 Mark module 'foo' to not be autoreloaded.
62
62
63 Caveats
63 Caveats
64 =======
64 =======
65
65
66 Reloading Python modules in a reliable way is in general difficult,
66 Reloading Python modules in a reliable way is in general difficult,
67 and unexpected things may occur. ``%autoreload`` tries to work around
67 and unexpected things may occur. ``%autoreload`` tries to work around
68 common pitfalls by replacing function code objects and parts of
68 common pitfalls by replacing function code objects and parts of
69 classes previously in the module with new versions. This makes the
69 classes previously in the module with new versions. This makes the
70 following things to work:
70 following things to work:
71
71
72 - Functions and classes imported via 'from xxx import foo' are upgraded
72 - Functions and classes imported via 'from xxx import foo' are upgraded
73 to new versions when 'xxx' is reloaded.
73 to new versions when 'xxx' is reloaded.
74
74
75 - Methods and properties of classes are upgraded on reload, so that
75 - Methods and properties of classes are upgraded on reload, so that
76 calling 'c.foo()' on an object 'c' created before the reload causes
76 calling 'c.foo()' on an object 'c' created before the reload causes
77 the new code for 'foo' to be executed.
77 the new code for 'foo' to be executed.
78
78
79 Some of the known remaining caveats are:
79 Some of the known remaining caveats are:
80
80
81 - Replacing code objects does not always succeed: changing a @property
81 - Replacing code objects does not always succeed: changing a @property
82 in a class to an ordinary method or a method to a member variable
82 in a class to an ordinary method or a method to a member variable
83 can cause problems (but in old objects only).
83 can cause problems (but in old objects only).
84
84
85 - Functions that are removed (eg. via monkey-patching) from a module
85 - Functions that are removed (eg. via monkey-patching) from a module
86 before it is reloaded are not upgraded.
86 before it is reloaded are not upgraded.
87
87
88 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
88 - C extension modules cannot be reloaded, and so cannot be autoreloaded.
89 """
89 """
90 from __future__ import print_function
90 from __future__ import print_function
91
91
92 skip_doctest = True
92 skip_doctest = True
93
93
94 #-----------------------------------------------------------------------------
94 #-----------------------------------------------------------------------------
95 # Copyright (C) 2000 Thomas Heller
95 # Copyright (C) 2000 Thomas Heller
96 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
96 # Copyright (C) 2008 Pauli Virtanen <pav@iki.fi>
97 # Copyright (C) 2012 The IPython Development Team
97 # Copyright (C) 2012 The IPython Development Team
98 #
98 #
99 # Distributed under the terms of the BSD License. The full license is in
99 # Distributed under the terms of the BSD License. The full license is in
100 # the file COPYING, distributed as part of this software.
100 # the file COPYING, distributed as part of this software.
101 #-----------------------------------------------------------------------------
101 #-----------------------------------------------------------------------------
102 #
102 #
103 # This IPython module is written by Pauli Virtanen, based on the autoreload
103 # This IPython module is written by Pauli Virtanen, based on the autoreload
104 # code by Thomas Heller.
104 # code by Thomas Heller.
105
105
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107 # Imports
107 # Imports
108 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
109
109
110 import imp
110 import imp
111 import os
111 import os
112 import sys
112 import sys
113 import traceback
113 import traceback
114 import types
114 import types
115 import weakref
115 import weakref
116
116
117 try:
117 try:
118 # Reload is not defined by default in Python3.
118 # Reload is not defined by default in Python3.
119 reload
119 reload
120 except NameError:
120 except NameError:
121 from imp import reload
121 from imp import reload
122
122
123 from IPython.utils import pyfile
123 from IPython.utils import pyfile
124 from IPython.utils.py3compat import PY3
124 from IPython.utils.py3compat import PY3
125
125
126 #------------------------------------------------------------------------------
126 #------------------------------------------------------------------------------
127 # Autoreload functionality
127 # Autoreload functionality
128 #------------------------------------------------------------------------------
128 #------------------------------------------------------------------------------
129
129
130 def _get_compiled_ext():
130 def _get_compiled_ext():
131 """Official way to get the extension of compiled files (.pyc or .pyo)"""
131 """Official way to get the extension of compiled files (.pyc or .pyo)"""
132 for ext, mode, typ in imp.get_suffixes():
132 for ext, mode, typ in imp.get_suffixes():
133 if typ == imp.PY_COMPILED:
133 if typ == imp.PY_COMPILED:
134 return ext
134 return ext
135
135
136
136
137 PY_COMPILED_EXT = _get_compiled_ext()
137 PY_COMPILED_EXT = _get_compiled_ext()
138
138
139
139
140 class ModuleReloader(object):
140 class ModuleReloader(object):
141 enabled = False
141 enabled = False
142 """Whether this reloader is enabled"""
142 """Whether this reloader is enabled"""
143
143
144 failed = {}
144 failed = {}
145 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
145 """Modules that failed to reload: {module: mtime-on-failed-reload, ...}"""
146
146
147 modules = {}
147 modules = {}
148 """Modules specially marked as autoreloadable."""
148 """Modules specially marked as autoreloadable."""
149
149
150 skip_modules = {}
150 skip_modules = {}
151 """Modules specially marked as not autoreloadable."""
151 """Modules specially marked as not autoreloadable."""
152
152
153 check_all = True
153 check_all = True
154 """Autoreload all modules, not just those listed in 'modules'"""
154 """Autoreload all modules, not just those listed in 'modules'"""
155
155
156 old_objects = {}
156 old_objects = {}
157 """(module-name, name) -> weakref, for replacing old code objects"""
157 """(module-name, name) -> weakref, for replacing old code objects"""
158
158
159 def mark_module_skipped(self, module_name):
159 def mark_module_skipped(self, module_name):
160 """Skip reloading the named module in the future"""
160 """Skip reloading the named module in the future"""
161 try:
161 try:
162 del self.modules[module_name]
162 del self.modules[module_name]
163 except KeyError:
163 except KeyError:
164 pass
164 pass
165 self.skip_modules[module_name] = True
165 self.skip_modules[module_name] = True
166
166
167 def mark_module_reloadable(self, module_name):
167 def mark_module_reloadable(self, module_name):
168 """Reload the named module in the future (if it is imported)"""
168 """Reload the named module in the future (if it is imported)"""
169 try:
169 try:
170 del self.skip_modules[module_name]
170 del self.skip_modules[module_name]
171 except KeyError:
171 except KeyError:
172 pass
172 pass
173 self.modules[module_name] = True
173 self.modules[module_name] = True
174
174
175 def aimport_module(self, module_name):
175 def aimport_module(self, module_name):
176 """Import a module, and mark it reloadable
176 """Import a module, and mark it reloadable
177
177
178 Returns
178 Returns
179 -------
179 -------
180 top_module : module
180 top_module : module
181 The imported module if it is top-level, or the top-level
181 The imported module if it is top-level, or the top-level
182 top_name : module
182 top_name : module
183 Name of top_module
183 Name of top_module
184
184
185 """
185 """
186 self.mark_module_reloadable(module_name)
186 self.mark_module_reloadable(module_name)
187
187
188 __import__(module_name)
188 __import__(module_name)
189 top_name = module_name.split('.')[0]
189 top_name = module_name.split('.')[0]
190 top_module = sys.modules[top_name]
190 top_module = sys.modules[top_name]
191 return top_module, top_name
191 return top_module, top_name
192
192
193 def check(self, check_all=False):
193 def check(self, check_all=False):
194 """Check whether some modules need to be reloaded."""
194 """Check whether some modules need to be reloaded."""
195
195
196 if not self.enabled and not check_all:
196 if not self.enabled and not check_all:
197 return
197 return
198
198
199 if check_all or self.check_all:
199 if check_all or self.check_all:
200 modules = sys.modules.keys()
200 modules = sys.modules.keys()
201 else:
201 else:
202 modules = self.modules.keys()
202 modules = self.modules.keys()
203
203
204 for modname in modules:
204 for modname in modules:
205 m = sys.modules.get(modname, None)
205 m = sys.modules.get(modname, None)
206
206
207 if modname in self.skip_modules:
207 if modname in self.skip_modules:
208 continue
208 continue
209
209
210 if not hasattr(m, '__file__'):
210 if not hasattr(m, '__file__'):
211 continue
211 continue
212
212
213 if m.__name__ == '__main__':
213 if m.__name__ == '__main__':
214 # we cannot reload(__main__)
214 # we cannot reload(__main__)
215 continue
215 continue
216
216
217 filename = m.__file__
217 filename = m.__file__
218 path, ext = os.path.splitext(filename)
218 path, ext = os.path.splitext(filename)
219
219
220 if ext.lower() == '.py':
220 if ext.lower() == '.py':
221 ext = PY_COMPILED_EXT
221 ext = PY_COMPILED_EXT
222 pyc_filename = pyfile.cache_from_source(filename)
222 pyc_filename = pyfile.cache_from_source(filename)
223 py_filename = filename
223 py_filename = filename
224 else:
224 else:
225 pyc_filename = filename
225 pyc_filename = filename
226 try:
226 try:
227 py_filename = pyfile.source_from_cache(filename)
227 py_filename = pyfile.source_from_cache(filename)
228 except ValueError:
228 except ValueError:
229 continue
229 continue
230
230
231 try:
231 try:
232 pymtime = os.stat(py_filename).st_mtime
232 pymtime = os.stat(py_filename).st_mtime
233 if pymtime <= os.stat(pyc_filename).st_mtime:
233 if pymtime <= os.stat(pyc_filename).st_mtime:
234 continue
234 continue
235 if self.failed.get(py_filename, None) == pymtime:
235 if self.failed.get(py_filename, None) == pymtime:
236 continue
236 continue
237 except OSError:
237 except OSError:
238 continue
238 continue
239
239
240 try:
240 try:
241 superreload(m, reload, self.old_objects)
241 superreload(m, reload, self.old_objects)
242 if py_filename in self.failed:
242 if py_filename in self.failed:
243 del self.failed[py_filename]
243 del self.failed[py_filename]
244 except:
244 except:
245 print("[autoreload of %s failed: %s]" % (
245 print("[autoreload of %s failed: %s]" % (
246 modname, traceback.format_exc(1)), file=sys.stderr)
246 modname, traceback.format_exc(1)), file=sys.stderr)
247 self.failed[py_filename] = pymtime
247 self.failed[py_filename] = pymtime
248
248
249 #------------------------------------------------------------------------------
249 #------------------------------------------------------------------------------
250 # superreload
250 # superreload
251 #------------------------------------------------------------------------------
251 #------------------------------------------------------------------------------
252
252
253 if PY3:
253 if PY3:
254 func_attrs = ['__code__', '__defaults__', '__doc__',
254 func_attrs = ['__code__', '__defaults__', '__doc__',
255 '__closure__', '__globals__', '__dict__']
255 '__closure__', '__globals__', '__dict__']
256 else:
256 else:
257 func_attrs = ['func_code', 'func_defaults', 'func_doc',
257 func_attrs = ['func_code', 'func_defaults', 'func_doc',
258 'func_closure', 'func_globals', 'func_dict']
258 'func_closure', 'func_globals', 'func_dict']
259
259
260
260
261 def update_function(old, new):
261 def update_function(old, new):
262 """Upgrade the code object of a function"""
262 """Upgrade the code object of a function"""
263 for name in func_attrs:
263 for name in func_attrs:
264 try:
264 try:
265 setattr(old, name, getattr(new, name))
265 setattr(old, name, getattr(new, name))
266 except (AttributeError, TypeError):
266 except (AttributeError, TypeError):
267 pass
267 pass
268
268
269
269
270 def update_class(old, new):
270 def update_class(old, new):
271 """Replace stuff in the __dict__ of a class, and upgrade
271 """Replace stuff in the __dict__ of a class, and upgrade
272 method code objects"""
272 method code objects"""
273 for key in old.__dict__.keys():
273 for key in old.__dict__.keys():
274 old_obj = getattr(old, key)
274 old_obj = getattr(old, key)
275
275
276 try:
276 try:
277 new_obj = getattr(new, key)
277 new_obj = getattr(new, key)
278 except AttributeError:
278 except AttributeError:
279 # obsolete attribute: remove it
279 # obsolete attribute: remove it
280 try:
280 try:
281 delattr(old, key)
281 delattr(old, key)
282 except (AttributeError, TypeError):
282 except (AttributeError, TypeError):
283 pass
283 pass
284 continue
284 continue
285
285
286 if update_generic(old_obj, new_obj): continue
286 if update_generic(old_obj, new_obj): continue
287
287
288 try:
288 try:
289 setattr(old, key, getattr(new, key))
289 setattr(old, key, getattr(new, key))
290 except (AttributeError, TypeError):
290 except (AttributeError, TypeError):
291 pass # skip non-writable attributes
291 pass # skip non-writable attributes
292
292
293
293
294 def update_property(old, new):
294 def update_property(old, new):
295 """Replace get/set/del functions of a property"""
295 """Replace get/set/del functions of a property"""
296 update_generic(old.fdel, new.fdel)
296 update_generic(old.fdel, new.fdel)
297 update_generic(old.fget, new.fget)
297 update_generic(old.fget, new.fget)
298 update_generic(old.fset, new.fset)
298 update_generic(old.fset, new.fset)
299
299
300
300
301 def isinstance2(a, b, typ):
301 def isinstance2(a, b, typ):
302 return isinstance(a, typ) and isinstance(b, typ)
302 return isinstance(a, typ) and isinstance(b, typ)
303
303
304
304
305 UPDATE_RULES = [
305 UPDATE_RULES = [
306 (lambda a, b: isinstance2(a, b, type),
306 (lambda a, b: isinstance2(a, b, type),
307 update_class),
307 update_class),
308 (lambda a, b: isinstance2(a, b, types.FunctionType),
308 (lambda a, b: isinstance2(a, b, types.FunctionType),
309 update_function),
309 update_function),
310 (lambda a, b: isinstance2(a, b, property),
310 (lambda a, b: isinstance2(a, b, property),
311 update_property),
311 update_property),
312 ]
312 ]
313
313
314
314
315 if PY3:
315 if PY3:
316 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
316 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.MethodType),
317 lambda a, b: update_function(a.__func__, b.__func__)),
317 lambda a, b: update_function(a.__func__, b.__func__)),
318 ])
318 ])
319 else:
319 else:
320 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.ClassType),
320 UPDATE_RULES.extend([(lambda a, b: isinstance2(a, b, types.ClassType),
321 update_class),
321 update_class),
322 (lambda a, b: isinstance2(a, b, types.MethodType),
322 (lambda a, b: isinstance2(a, b, types.MethodType),
323 lambda a, b: update_function(a.im_func, b.im_func)),
323 lambda a, b: update_function(a.im_func, b.im_func)),
324 ])
324 ])
325
325
326
326
327 def update_generic(a, b):
327 def update_generic(a, b):
328 for type_check, update in UPDATE_RULES:
328 for type_check, update in UPDATE_RULES:
329 if type_check(a, b):
329 if type_check(a, b):
330 update(a, b)
330 update(a, b)
331 return True
331 return True
332 return False
332 return False
333
333
334
334
335 class StrongRef(object):
335 class StrongRef(object):
336 def __init__(self, obj):
336 def __init__(self, obj):
337 self.obj = obj
337 self.obj = obj
338 def __call__(self):
338 def __call__(self):
339 return self.obj
339 return self.obj
340
340
341
341
342 def superreload(module, reload=reload, old_objects={}):
342 def superreload(module, reload=reload, old_objects={}):
343 """Enhanced version of the builtin reload function.
343 """Enhanced version of the builtin reload function.
344
344
345 superreload remembers objects previously in the module, and
345 superreload remembers objects previously in the module, and
346
346
347 - upgrades the class dictionary of every old class in the module
347 - upgrades the class dictionary of every old class in the module
348 - upgrades the code object of every old function and method
348 - upgrades the code object of every old function and method
349 - clears the module's namespace before reloading
349 - clears the module's namespace before reloading
350
350
351 """
351 """
352
352
353 # collect old objects in the module
353 # collect old objects in the module
354 for name, obj in module.__dict__.items():
354 for name, obj in module.__dict__.items():
355 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
355 if not hasattr(obj, '__module__') or obj.__module__ != module.__name__:
356 continue
356 continue
357 key = (module.__name__, name)
357 key = (module.__name__, name)
358 try:
358 try:
359 old_objects.setdefault(key, []).append(weakref.ref(obj))
359 old_objects.setdefault(key, []).append(weakref.ref(obj))
360 except TypeError:
360 except TypeError:
361 # weakref doesn't work for all types;
361 # weakref doesn't work for all types;
362 # create strong references for 'important' cases
362 # create strong references for 'important' cases
363 if not PY3 and isinstance(obj, types.ClassType):
363 if not PY3 and isinstance(obj, types.ClassType):
364 old_objects.setdefault(key, []).append(StrongRef(obj))
364 old_objects.setdefault(key, []).append(StrongRef(obj))
365
365
366 # reload module
366 # reload module
367 try:
367 try:
368 # clear namespace first from old cruft
368 # clear namespace first from old cruft
369 old_dict = module.__dict__.copy()
369 old_dict = module.__dict__.copy()
370 old_name = module.__name__
370 old_name = module.__name__
371 module.__dict__.clear()
371 module.__dict__.clear()
372 module.__dict__['__name__'] = old_name
372 module.__dict__['__name__'] = old_name
373 module.__dict__['__loader__'] = old_dict['__loader__']
373 module.__dict__['__loader__'] = old_dict['__loader__']
374 except (TypeError, AttributeError, KeyError):
374 except (TypeError, AttributeError, KeyError):
375 pass
375 pass
376
376
377 try:
377 try:
378 module = reload(module)
378 module = reload(module)
379 except:
379 except:
380 # restore module dictionary on failed reload
380 # restore module dictionary on failed reload
381 module.__dict__.update(old_dict)
381 module.__dict__.update(old_dict)
382 raise
382 raise
383
383
384 # iterate over all objects and update functions & classes
384 # iterate over all objects and update functions & classes
385 for name, new_obj in module.__dict__.items():
385 for name, new_obj in module.__dict__.items():
386 key = (module.__name__, name)
386 key = (module.__name__, name)
387 if key not in old_objects: continue
387 if key not in old_objects: continue
388
388
389 new_refs = []
389 new_refs = []
390 for old_ref in old_objects[key]:
390 for old_ref in old_objects[key]:
391 old_obj = old_ref()
391 old_obj = old_ref()
392 if old_obj is None: continue
392 if old_obj is None: continue
393 new_refs.append(old_ref)
393 new_refs.append(old_ref)
394 update_generic(old_obj, new_obj)
394 update_generic(old_obj, new_obj)
395
395
396 if new_refs:
396 if new_refs:
397 old_objects[key] = new_refs
397 old_objects[key] = new_refs
398 else:
398 else:
399 del old_objects[key]
399 del old_objects[key]
400
400
401 return module
401 return module
402
402
403 #------------------------------------------------------------------------------
403 #------------------------------------------------------------------------------
404 # IPython connectivity
404 # IPython connectivity
405 #------------------------------------------------------------------------------
405 #------------------------------------------------------------------------------
406
406
407 from IPython.core.hooks import TryNext
407 from IPython.core.hooks import TryNext
408 from IPython.core.magic import Magics, magics_class, line_magic
408 from IPython.core.magic import Magics, magics_class, line_magic
409
409
410 @magics_class
410 @magics_class
411 class AutoreloadMagics(Magics):
411 class AutoreloadMagics(Magics):
412 def __init__(self, *a, **kw):
412 def __init__(self, *a, **kw):
413 super(AutoreloadMagics, self).__init__(*a, **kw)
413 super(AutoreloadMagics, self).__init__(*a, **kw)
414 self._reloader = ModuleReloader()
414 self._reloader = ModuleReloader()
415 self._reloader.check_all = False
415 self._reloader.check_all = False
416
416
417 @line_magic
417 @line_magic
418 def autoreload(self, parameter_s=''):
418 def autoreload(self, parameter_s=''):
419 r"""%autoreload => Reload modules automatically
419 r"""%autoreload => Reload modules automatically
420
420
421 %autoreload
421 %autoreload
422 Reload all modules (except those excluded by %aimport) automatically
422 Reload all modules (except those excluded by %aimport) automatically
423 now.
423 now.
424
424
425 %autoreload 0
425 %autoreload 0
426 Disable automatic reloading.
426 Disable automatic reloading.
427
427
428 %autoreload 1
428 %autoreload 1
429 Reload all modules imported with %aimport every time before executing
429 Reload all modules imported with %aimport every time before executing
430 the Python code typed.
430 the Python code typed.
431
431
432 %autoreload 2
432 %autoreload 2
433 Reload all modules (except those excluded by %aimport) every time
433 Reload all modules (except those excluded by %aimport) every time
434 before executing the Python code typed.
434 before executing the Python code typed.
435
435
436 Reloading Python modules in a reliable way is in general
436 Reloading Python modules in a reliable way is in general
437 difficult, and unexpected things may occur. %autoreload tries to
437 difficult, and unexpected things may occur. %autoreload tries to
438 work around common pitfalls by replacing function code objects and
438 work around common pitfalls by replacing function code objects and
439 parts of classes previously in the module with new versions. This
439 parts of classes previously in the module with new versions. This
440 makes the following things to work:
440 makes the following things to work:
441
441
442 - Functions and classes imported via 'from xxx import foo' are upgraded
442 - Functions and classes imported via 'from xxx import foo' are upgraded
443 to new versions when 'xxx' is reloaded.
443 to new versions when 'xxx' is reloaded.
444
444
445 - Methods and properties of classes are upgraded on reload, so that
445 - Methods and properties of classes are upgraded on reload, so that
446 calling 'c.foo()' on an object 'c' created before the reload causes
446 calling 'c.foo()' on an object 'c' created before the reload causes
447 the new code for 'foo' to be executed.
447 the new code for 'foo' to be executed.
448
448
449 Some of the known remaining caveats are:
449 Some of the known remaining caveats are:
450
450
451 - Replacing code objects does not always succeed: changing a @property
451 - Replacing code objects does not always succeed: changing a @property
452 in a class to an ordinary method or a method to a member variable
452 in a class to an ordinary method or a method to a member variable
453 can cause problems (but in old objects only).
453 can cause problems (but in old objects only).
454
454
455 - Functions that are removed (eg. via monkey-patching) from a module
455 - Functions that are removed (eg. via monkey-patching) from a module
456 before it is reloaded are not upgraded.
456 before it is reloaded are not upgraded.
457
457
458 - C extension modules cannot be reloaded, and so cannot be
458 - C extension modules cannot be reloaded, and so cannot be
459 autoreloaded.
459 autoreloaded.
460
460
461 """
461 """
462 if parameter_s == '':
462 if parameter_s == '':
463 self._reloader.check(True)
463 self._reloader.check(True)
464 elif parameter_s == '0':
464 elif parameter_s == '0':
465 self._reloader.enabled = False
465 self._reloader.enabled = False
466 elif parameter_s == '1':
466 elif parameter_s == '1':
467 self._reloader.check_all = False
467 self._reloader.check_all = False
468 self._reloader.enabled = True
468 self._reloader.enabled = True
469 elif parameter_s == '2':
469 elif parameter_s == '2':
470 self._reloader.check_all = True
470 self._reloader.check_all = True
471 self._reloader.enabled = True
471 self._reloader.enabled = True
472
472
473 @line_magic
473 @line_magic
474 def aimport(self, parameter_s='', stream=None):
474 def aimport(self, parameter_s='', stream=None):
475 """%aimport => Import modules for automatic reloading.
475 """%aimport => Import modules for automatic reloading.
476
476
477 %aimport
477 %aimport
478 List modules to automatically import and not to import.
478 List modules to automatically import and not to import.
479
479
480 %aimport foo
480 %aimport foo
481 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
481 Import module 'foo' and mark it to be autoreloaded for %autoreload 1
482
482
483 %aimport -foo
483 %aimport -foo
484 Mark module 'foo' to not be autoreloaded for %autoreload 1
484 Mark module 'foo' to not be autoreloaded for %autoreload 1
485 """
485 """
486 modname = parameter_s
486 modname = parameter_s
487 if not modname:
487 if not modname:
488 to_reload = self._reloader.modules.keys()
488 to_reload = self._reloader.modules.keys()
489 to_reload.sort()
489 to_reload.sort()
490 to_skip = self._reloader.skip_modules.keys()
490 to_skip = self._reloader.skip_modules.keys()
491 to_skip.sort()
491 to_skip.sort()
492 if stream is None:
492 if stream is None:
493 stream = sys.stdout
493 stream = sys.stdout
494 if self._reloader.check_all:
494 if self._reloader.check_all:
495 stream.write("Modules to reload:\nall-except-skipped\n")
495 stream.write("Modules to reload:\nall-except-skipped\n")
496 else:
496 else:
497 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
497 stream.write("Modules to reload:\n%s\n" % ' '.join(to_reload))
498 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
498 stream.write("\nModules to skip:\n%s\n" % ' '.join(to_skip))
499 elif modname.startswith('-'):
499 elif modname.startswith('-'):
500 modname = modname[1:]
500 modname = modname[1:]
501 self._reloader.mark_module_skipped(modname)
501 self._reloader.mark_module_skipped(modname)
502 else:
502 else:
503 top_module, top_name = self._reloader.aimport_module(modname)
503 top_module, top_name = self._reloader.aimport_module(modname)
504
504
505 # Inject module to user namespace
505 # Inject module to user namespace
506 self.shell.push({top_name: top_module})
506 self.shell.push({top_name: top_module})
507
507
508 def pre_run_code_hook(self, ip):
508 def pre_run_code_hook(self, ip):
509 if not self._reloader.enabled:
509 if not self._reloader.enabled:
510 raise TryNext
510 raise TryNext
511 try:
511 try:
512 self._reloader.check()
512 self._reloader.check()
513 except:
513 except:
514 pass
514 pass
515
515
516
516
517 _loaded = False
518
519
520 def load_ipython_extension(ip):
517 def load_ipython_extension(ip):
521 """Load the extension in IPython."""
518 """Load the extension in IPython."""
522 global _loaded
519 auto_reload = AutoreloadMagics(ip)
523 if not _loaded:
520 ip.register_magics(auto_reload)
524 auto_reload = AutoreloadMagics(ip)
521 ip.set_hook('pre_run_code_hook', auto_reload.pre_run_code_hook)
525 ip.register_magics(auto_reload)
526 ip.set_hook('pre_run_code_hook', auto_reload.pre_run_code_hook)
527 _loaded = True
@@ -1,283 +1,279
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 Cython related magics.
3 Cython related magics.
4
4
5 Author:
5 Author:
6 * Brian Granger
6 * Brian Granger
7
7
8 Parts of this code were taken from Cython.inline.
8 Parts of this code were taken from Cython.inline.
9 """
9 """
10 #-----------------------------------------------------------------------------
10 #-----------------------------------------------------------------------------
11 # Copyright (C) 2010-2011, IPython Development Team.
11 # Copyright (C) 2010-2011, IPython Development Team.
12 #
12 #
13 # Distributed under the terms of the Modified BSD License.
13 # Distributed under the terms of the Modified BSD License.
14 #
14 #
15 # The full license is in the file COPYING.txt, distributed with this software.
15 # The full license is in the file COPYING.txt, distributed with this software.
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17
17
18 from __future__ import print_function
18 from __future__ import print_function
19
19
20 import imp
20 import imp
21 import io
21 import io
22 import os
22 import os
23 import re
23 import re
24 import sys
24 import sys
25 import time
25 import time
26
26
27 try:
27 try:
28 import hashlib
28 import hashlib
29 except ImportError:
29 except ImportError:
30 import md5 as hashlib
30 import md5 as hashlib
31
31
32 from distutils.core import Distribution, Extension
32 from distutils.core import Distribution, Extension
33 from distutils.command.build_ext import build_ext
33 from distutils.command.build_ext import build_ext
34
34
35 from IPython.core import display
35 from IPython.core import display
36 from IPython.core import magic_arguments
36 from IPython.core import magic_arguments
37 from IPython.core.magic import Magics, magics_class, cell_magic
37 from IPython.core.magic import Magics, magics_class, cell_magic
38 from IPython.testing.skipdoctest import skip_doctest
38 from IPython.testing.skipdoctest import skip_doctest
39 from IPython.utils import py3compat
39 from IPython.utils import py3compat
40
40
41 import Cython
41 import Cython
42 from Cython.Compiler.Errors import CompileError
42 from Cython.Compiler.Errors import CompileError
43 from Cython.Build.Dependencies import cythonize
43 from Cython.Build.Dependencies import cythonize
44
44
45
45
46 @magics_class
46 @magics_class
47 class CythonMagics(Magics):
47 class CythonMagics(Magics):
48
48
49 def __init__(self, shell):
49 def __init__(self, shell):
50 super(CythonMagics,self).__init__(shell)
50 super(CythonMagics,self).__init__(shell)
51 self._reloads = {}
51 self._reloads = {}
52 self._code_cache = {}
52 self._code_cache = {}
53
53
54 def _import_all(self, module):
54 def _import_all(self, module):
55 for k,v in module.__dict__.items():
55 for k,v in module.__dict__.items():
56 if not k.startswith('__'):
56 if not k.startswith('__'):
57 self.shell.push({k:v})
57 self.shell.push({k:v})
58
58
59 @cell_magic
59 @cell_magic
60 def cython_inline(self, line, cell):
60 def cython_inline(self, line, cell):
61 """Compile and run a Cython code cell using Cython.inline.
61 """Compile and run a Cython code cell using Cython.inline.
62
62
63 This magic simply passes the body of the cell to Cython.inline
63 This magic simply passes the body of the cell to Cython.inline
64 and returns the result. If the variables `a` and `b` are defined
64 and returns the result. If the variables `a` and `b` are defined
65 in the user's namespace, here is a simple example that returns
65 in the user's namespace, here is a simple example that returns
66 their sum::
66 their sum::
67
67
68 %%cython_inline
68 %%cython_inline
69 return a+b
69 return a+b
70
70
71 For most purposes, we recommend the usage of the `%%cython` magic.
71 For most purposes, we recommend the usage of the `%%cython` magic.
72 """
72 """
73 locs = self.shell.user_global_ns
73 locs = self.shell.user_global_ns
74 globs = self.shell.user_ns
74 globs = self.shell.user_ns
75 return Cython.inline(cell, locals=locs, globals=globs)
75 return Cython.inline(cell, locals=locs, globals=globs)
76
76
77 @cell_magic
77 @cell_magic
78 def cython_pyximport(self, line, cell):
78 def cython_pyximport(self, line, cell):
79 """Compile and import a Cython code cell using pyximport.
79 """Compile and import a Cython code cell using pyximport.
80
80
81 The contents of the cell are written to a `.pyx` file in the current
81 The contents of the cell are written to a `.pyx` file in the current
82 working directory, which is then imported using `pyximport`. This
82 working directory, which is then imported using `pyximport`. This
83 magic requires a module name to be passed::
83 magic requires a module name to be passed::
84
84
85 %%cython_pyximport modulename
85 %%cython_pyximport modulename
86 def f(x):
86 def f(x):
87 return 2.0*x
87 return 2.0*x
88
88
89 The compiled module is then imported and all of its symbols are
89 The compiled module is then imported and all of its symbols are
90 injected into the user's namespace. For most purposes, we recommend
90 injected into the user's namespace. For most purposes, we recommend
91 the usage of the `%%cython` magic.
91 the usage of the `%%cython` magic.
92 """
92 """
93 module_name = line.strip()
93 module_name = line.strip()
94 if not module_name:
94 if not module_name:
95 raise ValueError('module name must be given')
95 raise ValueError('module name must be given')
96 fname = module_name + '.pyx'
96 fname = module_name + '.pyx'
97 with io.open(fname, 'w', encoding='utf-8') as f:
97 with io.open(fname, 'w', encoding='utf-8') as f:
98 f.write(cell)
98 f.write(cell)
99 if 'pyximport' not in sys.modules:
99 if 'pyximport' not in sys.modules:
100 import pyximport
100 import pyximport
101 pyximport.install(reload_support=True)
101 pyximport.install(reload_support=True)
102 if module_name in self._reloads:
102 if module_name in self._reloads:
103 module = self._reloads[module_name]
103 module = self._reloads[module_name]
104 reload(module)
104 reload(module)
105 else:
105 else:
106 __import__(module_name)
106 __import__(module_name)
107 module = sys.modules[module_name]
107 module = sys.modules[module_name]
108 self._reloads[module_name] = module
108 self._reloads[module_name] = module
109 self._import_all(module)
109 self._import_all(module)
110
110
111 @magic_arguments.magic_arguments()
111 @magic_arguments.magic_arguments()
112 @magic_arguments.argument(
112 @magic_arguments.argument(
113 '-c', '--compile-args', action='append', default=[],
113 '-c', '--compile-args', action='append', default=[],
114 help="Extra flags to pass to compiler via the `extra_compile_args` "
114 help="Extra flags to pass to compiler via the `extra_compile_args` "
115 "Extension flag (can be specified multiple times)."
115 "Extension flag (can be specified multiple times)."
116 )
116 )
117 @magic_arguments.argument(
117 @magic_arguments.argument(
118 '-la', '--link-args', action='append', default=[],
118 '-la', '--link-args', action='append', default=[],
119 help="Extra flags to pass to linker via the `extra_link_args` "
119 help="Extra flags to pass to linker via the `extra_link_args` "
120 "Extension flag (can be specified multiple times)."
120 "Extension flag (can be specified multiple times)."
121 )
121 )
122 @magic_arguments.argument(
122 @magic_arguments.argument(
123 '-l', '--lib', action='append', default=[],
123 '-l', '--lib', action='append', default=[],
124 help="Add a library to link the extension against (can be specified "
124 help="Add a library to link the extension against (can be specified "
125 "multiple times)."
125 "multiple times)."
126 )
126 )
127 @magic_arguments.argument(
127 @magic_arguments.argument(
128 '-L', dest='library_dirs', metavar='dir', action='append', default=[],
128 '-L', dest='library_dirs', metavar='dir', action='append', default=[],
129 help="Add a path to the list of libary directories (can be specified "
129 help="Add a path to the list of libary directories (can be specified "
130 "multiple times)."
130 "multiple times)."
131 )
131 )
132 @magic_arguments.argument(
132 @magic_arguments.argument(
133 '-I', '--include', action='append', default=[],
133 '-I', '--include', action='append', default=[],
134 help="Add a path to the list of include directories (can be specified "
134 help="Add a path to the list of include directories (can be specified "
135 "multiple times)."
135 "multiple times)."
136 )
136 )
137 @magic_arguments.argument(
137 @magic_arguments.argument(
138 '-+', '--cplus', action='store_true', default=False,
138 '-+', '--cplus', action='store_true', default=False,
139 help="Output a C++ rather than C file."
139 help="Output a C++ rather than C file."
140 )
140 )
141 @magic_arguments.argument(
141 @magic_arguments.argument(
142 '-f', '--force', action='store_true', default=False,
142 '-f', '--force', action='store_true', default=False,
143 help="Force the compilation of a new module, even if the source has been "
143 help="Force the compilation of a new module, even if the source has been "
144 "previously compiled."
144 "previously compiled."
145 )
145 )
146 @magic_arguments.argument(
146 @magic_arguments.argument(
147 '-a', '--annotate', action='store_true', default=False,
147 '-a', '--annotate', action='store_true', default=False,
148 help="Produce a colorized HTML version of the source."
148 help="Produce a colorized HTML version of the source."
149 )
149 )
150 @cell_magic
150 @cell_magic
151 def cython(self, line, cell):
151 def cython(self, line, cell):
152 """Compile and import everything from a Cython code cell.
152 """Compile and import everything from a Cython code cell.
153
153
154 The contents of the cell are written to a `.pyx` file in the
154 The contents of the cell are written to a `.pyx` file in the
155 directory `IPYTHONDIR/cython` using a filename with the hash of the
155 directory `IPYTHONDIR/cython` using a filename with the hash of the
156 code. This file is then cythonized and compiled. The resulting module
156 code. This file is then cythonized and compiled. The resulting module
157 is imported and all of its symbols are injected into the user's
157 is imported and all of its symbols are injected into the user's
158 namespace. The usage is similar to that of `%%cython_pyximport` but
158 namespace. The usage is similar to that of `%%cython_pyximport` but
159 you don't have to pass a module name::
159 you don't have to pass a module name::
160
160
161 %%cython
161 %%cython
162 def f(x):
162 def f(x):
163 return 2.0*x
163 return 2.0*x
164 """
164 """
165 args = magic_arguments.parse_argstring(self.cython, line)
165 args = magic_arguments.parse_argstring(self.cython, line)
166 code = cell if cell.endswith('\n') else cell+'\n'
166 code = cell if cell.endswith('\n') else cell+'\n'
167 lib_dir = os.path.join(self.shell.ipython_dir, 'cython')
167 lib_dir = os.path.join(self.shell.ipython_dir, 'cython')
168 quiet = True
168 quiet = True
169 key = code, sys.version_info, sys.executable, Cython.__version__
169 key = code, sys.version_info, sys.executable, Cython.__version__
170
170
171 if not os.path.exists(lib_dir):
171 if not os.path.exists(lib_dir):
172 os.makedirs(lib_dir)
172 os.makedirs(lib_dir)
173
173
174 if args.force:
174 if args.force:
175 # Force a new module name by adding the current time to the
175 # Force a new module name by adding the current time to the
176 # key which is hashed to determine the module name.
176 # key which is hashed to determine the module name.
177 key += time.time(),
177 key += time.time(),
178
178
179 module_name = "_cython_magic_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
179 module_name = "_cython_magic_" + hashlib.md5(str(key).encode('utf-8')).hexdigest()
180 module_path = os.path.join(lib_dir, module_name + self.so_ext)
180 module_path = os.path.join(lib_dir, module_name + self.so_ext)
181
181
182 have_module = os.path.isfile(module_path)
182 have_module = os.path.isfile(module_path)
183 need_cythonize = not have_module
183 need_cythonize = not have_module
184
184
185 if args.annotate:
185 if args.annotate:
186 html_file = os.path.join(lib_dir, module_name + '.html')
186 html_file = os.path.join(lib_dir, module_name + '.html')
187 if not os.path.isfile(html_file):
187 if not os.path.isfile(html_file):
188 need_cythonize = True
188 need_cythonize = True
189
189
190 if need_cythonize:
190 if need_cythonize:
191 c_include_dirs = args.include
191 c_include_dirs = args.include
192 if 'numpy' in code:
192 if 'numpy' in code:
193 import numpy
193 import numpy
194 c_include_dirs.append(numpy.get_include())
194 c_include_dirs.append(numpy.get_include())
195 pyx_file = os.path.join(lib_dir, module_name + '.pyx')
195 pyx_file = os.path.join(lib_dir, module_name + '.pyx')
196 pyx_file = py3compat.cast_bytes_py2(pyx_file, encoding=sys.getfilesystemencoding())
196 pyx_file = py3compat.cast_bytes_py2(pyx_file, encoding=sys.getfilesystemencoding())
197 with io.open(pyx_file, 'w', encoding='utf-8') as f:
197 with io.open(pyx_file, 'w', encoding='utf-8') as f:
198 f.write(code)
198 f.write(code)
199 extension = Extension(
199 extension = Extension(
200 name = module_name,
200 name = module_name,
201 sources = [pyx_file],
201 sources = [pyx_file],
202 include_dirs = c_include_dirs,
202 include_dirs = c_include_dirs,
203 library_dirs = args.library_dirs,
203 library_dirs = args.library_dirs,
204 extra_compile_args = args.compile_args,
204 extra_compile_args = args.compile_args,
205 extra_link_args = args.link_args,
205 extra_link_args = args.link_args,
206 libraries = args.lib,
206 libraries = args.lib,
207 language = 'c++' if args.cplus else 'c',
207 language = 'c++' if args.cplus else 'c',
208 )
208 )
209 build_extension = self._get_build_extension()
209 build_extension = self._get_build_extension()
210 try:
210 try:
211 opts = dict(
211 opts = dict(
212 quiet=quiet,
212 quiet=quiet,
213 annotate = args.annotate,
213 annotate = args.annotate,
214 force = True,
214 force = True,
215 )
215 )
216 build_extension.extensions = cythonize([extension], **opts)
216 build_extension.extensions = cythonize([extension], **opts)
217 except CompileError:
217 except CompileError:
218 return
218 return
219
219
220 if not have_module:
220 if not have_module:
221 build_extension.build_temp = os.path.dirname(pyx_file)
221 build_extension.build_temp = os.path.dirname(pyx_file)
222 build_extension.build_lib = lib_dir
222 build_extension.build_lib = lib_dir
223 build_extension.run()
223 build_extension.run()
224 self._code_cache[key] = module_name
224 self._code_cache[key] = module_name
225
225
226 module = imp.load_dynamic(module_name, module_path)
226 module = imp.load_dynamic(module_name, module_path)
227 self._import_all(module)
227 self._import_all(module)
228
228
229 if args.annotate:
229 if args.annotate:
230 try:
230 try:
231 with io.open(html_file, encoding='utf-8') as f:
231 with io.open(html_file, encoding='utf-8') as f:
232 annotated_html = f.read()
232 annotated_html = f.read()
233 except IOError as e:
233 except IOError as e:
234 # File could not be opened. Most likely the user has a version
234 # File could not be opened. Most likely the user has a version
235 # of Cython before 0.15.1 (when `cythonize` learned the
235 # of Cython before 0.15.1 (when `cythonize` learned the
236 # `force` keyword argument) and has already compiled this
236 # `force` keyword argument) and has already compiled this
237 # exact source without annotation.
237 # exact source without annotation.
238 print('Cython completed successfully but the annotated '
238 print('Cython completed successfully but the annotated '
239 'source could not be read.', file=sys.stderr)
239 'source could not be read.', file=sys.stderr)
240 print(e, file=sys.stderr)
240 print(e, file=sys.stderr)
241 else:
241 else:
242 return display.HTML(self.clean_annotated_html(annotated_html))
242 return display.HTML(self.clean_annotated_html(annotated_html))
243
243
244 @property
244 @property
245 def so_ext(self):
245 def so_ext(self):
246 """The extension suffix for compiled modules."""
246 """The extension suffix for compiled modules."""
247 try:
247 try:
248 return self._so_ext
248 return self._so_ext
249 except AttributeError:
249 except AttributeError:
250 self._so_ext = self._get_build_extension().get_ext_filename('')
250 self._so_ext = self._get_build_extension().get_ext_filename('')
251 return self._so_ext
251 return self._so_ext
252
252
253 def _get_build_extension(self):
253 def _get_build_extension(self):
254 dist = Distribution()
254 dist = Distribution()
255 config_files = dist.find_config_files()
255 config_files = dist.find_config_files()
256 try:
256 try:
257 config_files.remove('setup.cfg')
257 config_files.remove('setup.cfg')
258 except ValueError:
258 except ValueError:
259 pass
259 pass
260 dist.parse_config_files(config_files)
260 dist.parse_config_files(config_files)
261 build_extension = build_ext(dist)
261 build_extension = build_ext(dist)
262 build_extension.finalize_options()
262 build_extension.finalize_options()
263 return build_extension
263 return build_extension
264
264
265 @staticmethod
265 @staticmethod
266 def clean_annotated_html(html):
266 def clean_annotated_html(html):
267 """Clean up the annotated HTML source.
267 """Clean up the annotated HTML source.
268
268
269 Strips the link to the generated C or C++ file, which we do not
269 Strips the link to the generated C or C++ file, which we do not
270 present to the user.
270 present to the user.
271 """
271 """
272 r = re.compile('<p>Raw output: <a href="(.*)">(.*)</a>')
272 r = re.compile('<p>Raw output: <a href="(.*)">(.*)</a>')
273 html = '\n'.join(l for l in html.splitlines() if not r.match(l))
273 html = '\n'.join(l for l in html.splitlines() if not r.match(l))
274 return html
274 return html
275
275
276 _loaded = False
277
276
278 def load_ipython_extension(ip):
277 def load_ipython_extension(ip):
279 """Load the extension in IPython."""
278 """Load the extension in IPython."""
280 global _loaded
279 ip.register_magics(CythonMagics)
281 if not _loaded:
282 ip.register_magics(CythonMagics)
283 _loaded = True
@@ -1,371 +1,367
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 ===========
3 ===========
4 octavemagic
4 octavemagic
5 ===========
5 ===========
6
6
7 Magics for interacting with Octave via oct2py.
7 Magics for interacting with Octave via oct2py.
8
8
9 .. note::
9 .. note::
10
10
11 The ``oct2py`` module needs to be installed separately and
11 The ``oct2py`` module needs to be installed separately and
12 can be obtained using ``easy_install`` or ``pip``.
12 can be obtained using ``easy_install`` or ``pip``.
13
13
14 Usage
14 Usage
15 =====
15 =====
16
16
17 ``%octave``
17 ``%octave``
18
18
19 {OCTAVE_DOC}
19 {OCTAVE_DOC}
20
20
21 ``%octave_push``
21 ``%octave_push``
22
22
23 {OCTAVE_PUSH_DOC}
23 {OCTAVE_PUSH_DOC}
24
24
25 ``%octave_pull``
25 ``%octave_pull``
26
26
27 {OCTAVE_PULL_DOC}
27 {OCTAVE_PULL_DOC}
28
28
29 """
29 """
30
30
31 #-----------------------------------------------------------------------------
31 #-----------------------------------------------------------------------------
32 # Copyright (C) 2012 The IPython Development Team
32 # Copyright (C) 2012 The IPython Development Team
33 #
33 #
34 # Distributed under the terms of the BSD License. The full license is in
34 # Distributed under the terms of the BSD License. The full license is in
35 # the file COPYING, distributed as part of this software.
35 # the file COPYING, distributed as part of this software.
36 #-----------------------------------------------------------------------------
36 #-----------------------------------------------------------------------------
37
37
38 import tempfile
38 import tempfile
39 from glob import glob
39 from glob import glob
40 from shutil import rmtree
40 from shutil import rmtree
41
41
42 import numpy as np
42 import numpy as np
43 import oct2py
43 import oct2py
44 from xml.dom import minidom
44 from xml.dom import minidom
45
45
46 from IPython.core.displaypub import publish_display_data
46 from IPython.core.displaypub import publish_display_data
47 from IPython.core.magic import (Magics, magics_class, line_magic,
47 from IPython.core.magic import (Magics, magics_class, line_magic,
48 line_cell_magic, needs_local_scope)
48 line_cell_magic, needs_local_scope)
49 from IPython.testing.skipdoctest import skip_doctest
49 from IPython.testing.skipdoctest import skip_doctest
50 from IPython.core.magic_arguments import (
50 from IPython.core.magic_arguments import (
51 argument, magic_arguments, parse_argstring
51 argument, magic_arguments, parse_argstring
52 )
52 )
53 from IPython.utils.py3compat import unicode_to_str
53 from IPython.utils.py3compat import unicode_to_str
54
54
55 class OctaveMagicError(oct2py.Oct2PyError):
55 class OctaveMagicError(oct2py.Oct2PyError):
56 pass
56 pass
57
57
58 _mimetypes = {'png' : 'image/png',
58 _mimetypes = {'png' : 'image/png',
59 'svg' : 'image/svg+xml',
59 'svg' : 'image/svg+xml',
60 'jpg' : 'image/jpeg',
60 'jpg' : 'image/jpeg',
61 'jpeg': 'image/jpeg'}
61 'jpeg': 'image/jpeg'}
62
62
63 @magics_class
63 @magics_class
64 class OctaveMagics(Magics):
64 class OctaveMagics(Magics):
65 """A set of magics useful for interactive work with Octave via oct2py.
65 """A set of magics useful for interactive work with Octave via oct2py.
66 """
66 """
67 def __init__(self, shell):
67 def __init__(self, shell):
68 """
68 """
69 Parameters
69 Parameters
70 ----------
70 ----------
71 shell : IPython shell
71 shell : IPython shell
72
72
73 """
73 """
74 super(OctaveMagics, self).__init__(shell)
74 super(OctaveMagics, self).__init__(shell)
75 self._oct = oct2py.Oct2Py()
75 self._oct = oct2py.Oct2Py()
76 self._plot_format = 'png'
76 self._plot_format = 'png'
77
77
78 # Allow publish_display_data to be overridden for
78 # Allow publish_display_data to be overridden for
79 # testing purposes.
79 # testing purposes.
80 self._publish_display_data = publish_display_data
80 self._publish_display_data = publish_display_data
81
81
82
82
83 def _fix_gnuplot_svg_size(self, image, size=None):
83 def _fix_gnuplot_svg_size(self, image, size=None):
84 """
84 """
85 GnuPlot SVGs do not have height/width attributes. Set
85 GnuPlot SVGs do not have height/width attributes. Set
86 these to be the same as the viewBox, so that the browser
86 these to be the same as the viewBox, so that the browser
87 scales the image correctly.
87 scales the image correctly.
88
88
89 Parameters
89 Parameters
90 ----------
90 ----------
91 image : str
91 image : str
92 SVG data.
92 SVG data.
93 size : tuple of int
93 size : tuple of int
94 Image width, height.
94 Image width, height.
95
95
96 """
96 """
97 (svg,) = minidom.parseString(image).getElementsByTagName('svg')
97 (svg,) = minidom.parseString(image).getElementsByTagName('svg')
98 viewbox = svg.getAttribute('viewBox').split(' ')
98 viewbox = svg.getAttribute('viewBox').split(' ')
99
99
100 if size is not None:
100 if size is not None:
101 width, height = size
101 width, height = size
102 else:
102 else:
103 width, height = viewbox[2:]
103 width, height = viewbox[2:]
104
104
105 svg.setAttribute('width', '%dpx' % width)
105 svg.setAttribute('width', '%dpx' % width)
106 svg.setAttribute('height', '%dpx' % height)
106 svg.setAttribute('height', '%dpx' % height)
107 return svg.toxml()
107 return svg.toxml()
108
108
109
109
110 @skip_doctest
110 @skip_doctest
111 @line_magic
111 @line_magic
112 def octave_push(self, line):
112 def octave_push(self, line):
113 '''
113 '''
114 Line-level magic that pushes a variable to Octave.
114 Line-level magic that pushes a variable to Octave.
115
115
116 `line` should be made up of whitespace separated variable names in the
116 `line` should be made up of whitespace separated variable names in the
117 IPython namespace::
117 IPython namespace::
118
118
119 In [7]: import numpy as np
119 In [7]: import numpy as np
120
120
121 In [8]: X = np.arange(5)
121 In [8]: X = np.arange(5)
122
122
123 In [9]: X.mean()
123 In [9]: X.mean()
124 Out[9]: 2.0
124 Out[9]: 2.0
125
125
126 In [10]: %octave_push X
126 In [10]: %octave_push X
127
127
128 In [11]: %octave mean(X)
128 In [11]: %octave mean(X)
129 Out[11]: 2.0
129 Out[11]: 2.0
130
130
131 '''
131 '''
132 inputs = line.split(' ')
132 inputs = line.split(' ')
133 for input in inputs:
133 for input in inputs:
134 input = unicode_to_str(input)
134 input = unicode_to_str(input)
135 self._oct.put(input, self.shell.user_ns[input])
135 self._oct.put(input, self.shell.user_ns[input])
136
136
137
137
138 @skip_doctest
138 @skip_doctest
139 @line_magic
139 @line_magic
140 def octave_pull(self, line):
140 def octave_pull(self, line):
141 '''
141 '''
142 Line-level magic that pulls a variable from Octave.
142 Line-level magic that pulls a variable from Octave.
143
143
144 In [18]: _ = %octave x = [1 2; 3 4]; y = 'hello'
144 In [18]: _ = %octave x = [1 2; 3 4]; y = 'hello'
145
145
146 In [19]: %octave_pull x y
146 In [19]: %octave_pull x y
147
147
148 In [20]: x
148 In [20]: x
149 Out[20]:
149 Out[20]:
150 array([[ 1., 2.],
150 array([[ 1., 2.],
151 [ 3., 4.]])
151 [ 3., 4.]])
152
152
153 In [21]: y
153 In [21]: y
154 Out[21]: 'hello'
154 Out[21]: 'hello'
155
155
156 '''
156 '''
157 outputs = line.split(' ')
157 outputs = line.split(' ')
158 for output in outputs:
158 for output in outputs:
159 output = unicode_to_str(output)
159 output = unicode_to_str(output)
160 self.shell.push({output: self._oct.get(output)})
160 self.shell.push({output: self._oct.get(output)})
161
161
162
162
163 @skip_doctest
163 @skip_doctest
164 @magic_arguments()
164 @magic_arguments()
165 @argument(
165 @argument(
166 '-i', '--input', action='append',
166 '-i', '--input', action='append',
167 help='Names of input variables to be pushed to Octave. Multiple names '
167 help='Names of input variables to be pushed to Octave. Multiple names '
168 'can be passed, separated by commas with no whitespace.'
168 'can be passed, separated by commas with no whitespace.'
169 )
169 )
170 @argument(
170 @argument(
171 '-o', '--output', action='append',
171 '-o', '--output', action='append',
172 help='Names of variables to be pulled from Octave after executing cell '
172 help='Names of variables to be pulled from Octave after executing cell '
173 'body. Multiple names can be passed, separated by commas with no '
173 'body. Multiple names can be passed, separated by commas with no '
174 'whitespace.'
174 'whitespace.'
175 )
175 )
176 @argument(
176 @argument(
177 '-s', '--size', action='store',
177 '-s', '--size', action='store',
178 help='Pixel size of plots, "width,height". Default is "-s 400,250".'
178 help='Pixel size of plots, "width,height". Default is "-s 400,250".'
179 )
179 )
180 @argument(
180 @argument(
181 '-f', '--format', action='store',
181 '-f', '--format', action='store',
182 help='Plot format (png, svg or jpg).'
182 help='Plot format (png, svg or jpg).'
183 )
183 )
184
184
185 @needs_local_scope
185 @needs_local_scope
186 @argument(
186 @argument(
187 'code',
187 'code',
188 nargs='*',
188 nargs='*',
189 )
189 )
190 @line_cell_magic
190 @line_cell_magic
191 def octave(self, line, cell=None, local_ns=None):
191 def octave(self, line, cell=None, local_ns=None):
192 '''
192 '''
193 Execute code in Octave, and pull some of the results back into the
193 Execute code in Octave, and pull some of the results back into the
194 Python namespace.
194 Python namespace.
195
195
196 In [9]: %octave X = [1 2; 3 4]; mean(X)
196 In [9]: %octave X = [1 2; 3 4]; mean(X)
197 Out[9]: array([[ 2., 3.]])
197 Out[9]: array([[ 2., 3.]])
198
198
199 As a cell, this will run a block of Octave code, without returning any
199 As a cell, this will run a block of Octave code, without returning any
200 value::
200 value::
201
201
202 In [10]: %%octave
202 In [10]: %%octave
203 ....: p = [-2, -1, 0, 1, 2]
203 ....: p = [-2, -1, 0, 1, 2]
204 ....: polyout(p, 'x')
204 ....: polyout(p, 'x')
205
205
206 -2*x^4 - 1*x^3 + 0*x^2 + 1*x^1 + 2
206 -2*x^4 - 1*x^3 + 0*x^2 + 1*x^1 + 2
207
207
208 In the notebook, plots are published as the output of the cell, e.g.
208 In the notebook, plots are published as the output of the cell, e.g.
209
209
210 %octave plot([1 2 3], [4 5 6])
210 %octave plot([1 2 3], [4 5 6])
211
211
212 will create a line plot.
212 will create a line plot.
213
213
214 Objects can be passed back and forth between Octave and IPython via the
214 Objects can be passed back and forth between Octave and IPython via the
215 -i and -o flags in line::
215 -i and -o flags in line::
216
216
217 In [14]: Z = np.array([1, 4, 5, 10])
217 In [14]: Z = np.array([1, 4, 5, 10])
218
218
219 In [15]: %octave -i Z mean(Z)
219 In [15]: %octave -i Z mean(Z)
220 Out[15]: array([ 5.])
220 Out[15]: array([ 5.])
221
221
222
222
223 In [16]: %octave -o W W = Z * mean(Z)
223 In [16]: %octave -o W W = Z * mean(Z)
224 Out[16]: array([ 5., 20., 25., 50.])
224 Out[16]: array([ 5., 20., 25., 50.])
225
225
226 In [17]: W
226 In [17]: W
227 Out[17]: array([ 5., 20., 25., 50.])
227 Out[17]: array([ 5., 20., 25., 50.])
228
228
229 The size and format of output plots can be specified::
229 The size and format of output plots can be specified::
230
230
231 In [18]: %%octave -s 600,800 -f svg
231 In [18]: %%octave -s 600,800 -f svg
232 ...: plot([1, 2, 3]);
232 ...: plot([1, 2, 3]);
233
233
234 '''
234 '''
235 args = parse_argstring(self.octave, line)
235 args = parse_argstring(self.octave, line)
236
236
237 # arguments 'code' in line are prepended to the cell lines
237 # arguments 'code' in line are prepended to the cell lines
238 if cell is None:
238 if cell is None:
239 code = ''
239 code = ''
240 return_output = True
240 return_output = True
241 else:
241 else:
242 code = cell
242 code = cell
243 return_output = False
243 return_output = False
244
244
245 code = ' '.join(args.code) + code
245 code = ' '.join(args.code) + code
246
246
247 # if there is no local namespace then default to an empty dict
247 # if there is no local namespace then default to an empty dict
248 if local_ns is None:
248 if local_ns is None:
249 local_ns = {}
249 local_ns = {}
250
250
251 if args.input:
251 if args.input:
252 for input in ','.join(args.input).split(','):
252 for input in ','.join(args.input).split(','):
253 input = unicode_to_str(input)
253 input = unicode_to_str(input)
254 try:
254 try:
255 val = local_ns[input]
255 val = local_ns[input]
256 except KeyError:
256 except KeyError:
257 val = self.shell.user_ns[input]
257 val = self.shell.user_ns[input]
258 self._oct.put(input, val)
258 self._oct.put(input, val)
259
259
260 # generate plots in a temporary directory
260 # generate plots in a temporary directory
261 plot_dir = tempfile.mkdtemp()
261 plot_dir = tempfile.mkdtemp()
262 if args.size is not None:
262 if args.size is not None:
263 size = args.size
263 size = args.size
264 else:
264 else:
265 size = '400,240'
265 size = '400,240'
266
266
267 if args.format is not None:
267 if args.format is not None:
268 plot_format = args.format
268 plot_format = args.format
269 else:
269 else:
270 plot_format = 'png'
270 plot_format = 'png'
271
271
272 pre_call = '''
272 pre_call = '''
273 global __ipy_figures = [];
273 global __ipy_figures = [];
274 page_screen_output(0);
274 page_screen_output(0);
275
275
276 function fig_create(src, event)
276 function fig_create(src, event)
277 global __ipy_figures;
277 global __ipy_figures;
278 __ipy_figures(size(__ipy_figures) + 1) = src;
278 __ipy_figures(size(__ipy_figures) + 1) = src;
279 set(src, "visible", "off");
279 set(src, "visible", "off");
280 end
280 end
281
281
282 set(0, 'DefaultFigureCreateFcn', @fig_create);
282 set(0, 'DefaultFigureCreateFcn', @fig_create);
283
283
284 close all;
284 close all;
285 clear ans;
285 clear ans;
286
286
287 # ___<end_pre_call>___ #
287 # ___<end_pre_call>___ #
288 '''
288 '''
289
289
290 post_call = '''
290 post_call = '''
291 # ___<start_post_call>___ #
291 # ___<start_post_call>___ #
292
292
293 # Save output of the last execution
293 # Save output of the last execution
294 if exist("ans") == 1
294 if exist("ans") == 1
295 _ = ans;
295 _ = ans;
296 else
296 else
297 _ = nan;
297 _ = nan;
298 end
298 end
299
299
300 for f = __ipy_figures
300 for f = __ipy_figures
301 outfile = sprintf('%(plot_dir)s/__ipy_oct_fig_%%03d.png', f);
301 outfile = sprintf('%(plot_dir)s/__ipy_oct_fig_%%03d.png', f);
302 try
302 try
303 print(f, outfile, '-d%(plot_format)s', '-tight', '-S%(size)s');
303 print(f, outfile, '-d%(plot_format)s', '-tight', '-S%(size)s');
304 end
304 end
305 end
305 end
306
306
307 ''' % locals()
307 ''' % locals()
308
308
309 code = ' '.join((pre_call, code, post_call))
309 code = ' '.join((pre_call, code, post_call))
310 try:
310 try:
311 text_output = self._oct.run(code, verbose=False)
311 text_output = self._oct.run(code, verbose=False)
312 except (oct2py.Oct2PyError) as exception:
312 except (oct2py.Oct2PyError) as exception:
313 msg = exception.message
313 msg = exception.message
314 msg = msg.split('# ___<end_pre_call>___ #')[1]
314 msg = msg.split('# ___<end_pre_call>___ #')[1]
315 msg = msg.split('# ___<start_post_call>___ #')[0]
315 msg = msg.split('# ___<start_post_call>___ #')[0]
316 raise OctaveMagicError('Octave could not complete execution. '
316 raise OctaveMagicError('Octave could not complete execution. '
317 'Traceback (currently broken in oct2py): %s'
317 'Traceback (currently broken in oct2py): %s'
318 % msg)
318 % msg)
319
319
320 key = 'OctaveMagic.Octave'
320 key = 'OctaveMagic.Octave'
321 display_data = []
321 display_data = []
322
322
323 # Publish text output
323 # Publish text output
324 if text_output:
324 if text_output:
325 display_data.append((key, {'text/plain': text_output}))
325 display_data.append((key, {'text/plain': text_output}))
326
326
327 # Publish images
327 # Publish images
328 images = [open(imgfile, 'rb').read() for imgfile in \
328 images = [open(imgfile, 'rb').read() for imgfile in \
329 glob("%s/*" % plot_dir)]
329 glob("%s/*" % plot_dir)]
330 rmtree(plot_dir)
330 rmtree(plot_dir)
331
331
332 plot_mime_type = _mimetypes.get(plot_format, 'image/png')
332 plot_mime_type = _mimetypes.get(plot_format, 'image/png')
333 width, height = [int(s) for s in size.split(',')]
333 width, height = [int(s) for s in size.split(',')]
334 for image in images:
334 for image in images:
335 if plot_format == 'svg':
335 if plot_format == 'svg':
336 image = self._fix_gnuplot_svg_size(image, size=(width, height))
336 image = self._fix_gnuplot_svg_size(image, size=(width, height))
337 display_data.append((key, {plot_mime_type: image}))
337 display_data.append((key, {plot_mime_type: image}))
338
338
339 if args.output:
339 if args.output:
340 for output in ','.join(args.output).split(','):
340 for output in ','.join(args.output).split(','):
341 output = unicode_to_str(output)
341 output = unicode_to_str(output)
342 self.shell.push({output: self._oct.get(output)})
342 self.shell.push({output: self._oct.get(output)})
343
343
344 for source, data in display_data:
344 for source, data in display_data:
345 self._publish_display_data(source, data)
345 self._publish_display_data(source, data)
346
346
347 if return_output:
347 if return_output:
348 ans = self._oct.get('_')
348 ans = self._oct.get('_')
349
349
350 # Unfortunately, Octave doesn't have a "None" object,
350 # Unfortunately, Octave doesn't have a "None" object,
351 # so we can't return any NaN outputs
351 # so we can't return any NaN outputs
352 if np.isscalar(ans) and np.isnan(ans):
352 if np.isscalar(ans) and np.isnan(ans):
353 ans = None
353 ans = None
354
354
355 return ans
355 return ans
356
356
357
357
358 __doc__ = __doc__.format(
358 __doc__ = __doc__.format(
359 OCTAVE_DOC = ' '*8 + OctaveMagics.octave.__doc__,
359 OCTAVE_DOC = ' '*8 + OctaveMagics.octave.__doc__,
360 OCTAVE_PUSH_DOC = ' '*8 + OctaveMagics.octave_push.__doc__,
360 OCTAVE_PUSH_DOC = ' '*8 + OctaveMagics.octave_push.__doc__,
361 OCTAVE_PULL_DOC = ' '*8 + OctaveMagics.octave_pull.__doc__
361 OCTAVE_PULL_DOC = ' '*8 + OctaveMagics.octave_pull.__doc__
362 )
362 )
363
363
364
364
365 _loaded = False
366 def load_ipython_extension(ip):
365 def load_ipython_extension(ip):
367 """Load the extension in IPython."""
366 """Load the extension in IPython."""
368 global _loaded
367 ip.register_magics(OctaveMagics)
369 if not _loaded:
370 ip.register_magics(OctaveMagics)
371 _loaded = True
@@ -1,597 +1,593
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 ======
3 ======
4 Rmagic
4 Rmagic
5 ======
5 ======
6
6
7 Magic command interface for interactive work with R via rpy2
7 Magic command interface for interactive work with R via rpy2
8
8
9 Usage
9 Usage
10 =====
10 =====
11
11
12 ``%R``
12 ``%R``
13
13
14 {R_DOC}
14 {R_DOC}
15
15
16 ``%Rpush``
16 ``%Rpush``
17
17
18 {RPUSH_DOC}
18 {RPUSH_DOC}
19
19
20 ``%Rpull``
20 ``%Rpull``
21
21
22 {RPULL_DOC}
22 {RPULL_DOC}
23
23
24 ``%Rget``
24 ``%Rget``
25
25
26 {RGET_DOC}
26 {RGET_DOC}
27
27
28 """
28 """
29
29
30 #-----------------------------------------------------------------------------
30 #-----------------------------------------------------------------------------
31 # Copyright (C) 2012 The IPython Development Team
31 # Copyright (C) 2012 The IPython Development Team
32 #
32 #
33 # Distributed under the terms of the BSD License. The full license is in
33 # Distributed under the terms of the BSD License. The full license is in
34 # the file COPYING, distributed as part of this software.
34 # the file COPYING, distributed as part of this software.
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36
36
37 import sys
37 import sys
38 import tempfile
38 import tempfile
39 from glob import glob
39 from glob import glob
40 from shutil import rmtree
40 from shutil import rmtree
41 from getopt import getopt
41 from getopt import getopt
42
42
43 # numpy and rpy2 imports
43 # numpy and rpy2 imports
44
44
45 import numpy as np
45 import numpy as np
46
46
47 import rpy2.rinterface as ri
47 import rpy2.rinterface as ri
48 import rpy2.robjects as ro
48 import rpy2.robjects as ro
49 from rpy2.robjects.numpy2ri import numpy2ri
49 from rpy2.robjects.numpy2ri import numpy2ri
50 ro.conversion.py2ri = numpy2ri
50 ro.conversion.py2ri = numpy2ri
51
51
52 # IPython imports
52 # IPython imports
53
53
54 from IPython.core.displaypub import publish_display_data
54 from IPython.core.displaypub import publish_display_data
55 from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic,
55 from IPython.core.magic import (Magics, magics_class, cell_magic, line_magic,
56 line_cell_magic, needs_local_scope)
56 line_cell_magic, needs_local_scope)
57 from IPython.testing.skipdoctest import skip_doctest
57 from IPython.testing.skipdoctest import skip_doctest
58 from IPython.core.magic_arguments import (
58 from IPython.core.magic_arguments import (
59 argument, magic_arguments, parse_argstring
59 argument, magic_arguments, parse_argstring
60 )
60 )
61 from IPython.utils.py3compat import str_to_unicode, unicode_to_str, PY3
61 from IPython.utils.py3compat import str_to_unicode, unicode_to_str, PY3
62
62
63 class RInterpreterError(ri.RRuntimeError):
63 class RInterpreterError(ri.RRuntimeError):
64 """An error when running R code in a %%R magic cell."""
64 """An error when running R code in a %%R magic cell."""
65 def __init__(self, line, err, stdout):
65 def __init__(self, line, err, stdout):
66 self.line = line
66 self.line = line
67 self.err = err.rstrip()
67 self.err = err.rstrip()
68 self.stdout = stdout.rstrip()
68 self.stdout = stdout.rstrip()
69
69
70 def __unicode__(self):
70 def __unicode__(self):
71 s = 'Failed to parse and evaluate line %r.\nR error message: %r' % \
71 s = 'Failed to parse and evaluate line %r.\nR error message: %r' % \
72 (self.line, self.err)
72 (self.line, self.err)
73 if self.stdout and (self.stdout != self.err):
73 if self.stdout and (self.stdout != self.err):
74 s += '\nR stdout:\n' + self.stdout
74 s += '\nR stdout:\n' + self.stdout
75 return s
75 return s
76
76
77 if PY3:
77 if PY3:
78 __str__ = __unicode__
78 __str__ = __unicode__
79 else:
79 else:
80 def __str__(self):
80 def __str__(self):
81 return unicode_to_str(unicode(self), 'utf-8')
81 return unicode_to_str(unicode(self), 'utf-8')
82
82
83 def Rconverter(Robj, dataframe=False):
83 def Rconverter(Robj, dataframe=False):
84 """
84 """
85 Convert an object in R's namespace to one suitable
85 Convert an object in R's namespace to one suitable
86 for ipython's namespace.
86 for ipython's namespace.
87
87
88 For a data.frame, it tries to return a structured array.
88 For a data.frame, it tries to return a structured array.
89 It first checks for colnames, then names.
89 It first checks for colnames, then names.
90 If all are NULL, it returns np.asarray(Robj), else
90 If all are NULL, it returns np.asarray(Robj), else
91 it tries to construct a recarray
91 it tries to construct a recarray
92
92
93 Parameters
93 Parameters
94 ----------
94 ----------
95
95
96 Robj: an R object returned from rpy2
96 Robj: an R object returned from rpy2
97 """
97 """
98 is_data_frame = ro.r('is.data.frame')
98 is_data_frame = ro.r('is.data.frame')
99 colnames = ro.r('colnames')
99 colnames = ro.r('colnames')
100 rownames = ro.r('rownames') # with pandas, these could be used for the index
100 rownames = ro.r('rownames') # with pandas, these could be used for the index
101 names = ro.r('names')
101 names = ro.r('names')
102
102
103 if dataframe:
103 if dataframe:
104 as_data_frame = ro.r('as.data.frame')
104 as_data_frame = ro.r('as.data.frame')
105 cols = colnames(Robj)
105 cols = colnames(Robj)
106 _names = names(Robj)
106 _names = names(Robj)
107 if cols != ri.NULL:
107 if cols != ri.NULL:
108 Robj = as_data_frame(Robj)
108 Robj = as_data_frame(Robj)
109 names = tuple(np.array(cols))
109 names = tuple(np.array(cols))
110 elif _names != ri.NULL:
110 elif _names != ri.NULL:
111 names = tuple(np.array(_names))
111 names = tuple(np.array(_names))
112 else: # failed to find names
112 else: # failed to find names
113 return np.asarray(Robj)
113 return np.asarray(Robj)
114 Robj = np.rec.fromarrays(Robj, names = names)
114 Robj = np.rec.fromarrays(Robj, names = names)
115 return np.asarray(Robj)
115 return np.asarray(Robj)
116
116
117 @magics_class
117 @magics_class
118 class RMagics(Magics):
118 class RMagics(Magics):
119 """A set of magics useful for interactive work with R via rpy2.
119 """A set of magics useful for interactive work with R via rpy2.
120 """
120 """
121
121
122 def __init__(self, shell, Rconverter=Rconverter,
122 def __init__(self, shell, Rconverter=Rconverter,
123 pyconverter=np.asarray,
123 pyconverter=np.asarray,
124 cache_display_data=False):
124 cache_display_data=False):
125 """
125 """
126 Parameters
126 Parameters
127 ----------
127 ----------
128
128
129 shell : IPython shell
129 shell : IPython shell
130
130
131 pyconverter : callable
131 pyconverter : callable
132 To be called on values in ipython namespace before
132 To be called on values in ipython namespace before
133 assigning to variables in rpy2.
133 assigning to variables in rpy2.
134
134
135 cache_display_data : bool
135 cache_display_data : bool
136 If True, the published results of the final call to R are
136 If True, the published results of the final call to R are
137 cached in the variable 'display_cache'.
137 cached in the variable 'display_cache'.
138
138
139 """
139 """
140 super(RMagics, self).__init__(shell)
140 super(RMagics, self).__init__(shell)
141 self.cache_display_data = cache_display_data
141 self.cache_display_data = cache_display_data
142
142
143 self.r = ro.R()
143 self.r = ro.R()
144
144
145 self.Rstdout_cache = []
145 self.Rstdout_cache = []
146 self.pyconverter = pyconverter
146 self.pyconverter = pyconverter
147 self.Rconverter = Rconverter
147 self.Rconverter = Rconverter
148
148
149 def eval(self, line):
149 def eval(self, line):
150 '''
150 '''
151 Parse and evaluate a line with rpy2.
151 Parse and evaluate a line with rpy2.
152 Returns the output to R's stdout() connection
152 Returns the output to R's stdout() connection
153 and the value of eval(parse(line)).
153 and the value of eval(parse(line)).
154 '''
154 '''
155 old_writeconsole = ri.get_writeconsole()
155 old_writeconsole = ri.get_writeconsole()
156 ri.set_writeconsole(self.write_console)
156 ri.set_writeconsole(self.write_console)
157 try:
157 try:
158 value = ri.baseenv['eval'](ri.parse(line))
158 value = ri.baseenv['eval'](ri.parse(line))
159 except (ri.RRuntimeError, ValueError) as exception:
159 except (ri.RRuntimeError, ValueError) as exception:
160 warning_or_other_msg = self.flush() # otherwise next return seems to have copy of error
160 warning_or_other_msg = self.flush() # otherwise next return seems to have copy of error
161 raise RInterpreterError(line, str_to_unicode(str(exception)), warning_or_other_msg)
161 raise RInterpreterError(line, str_to_unicode(str(exception)), warning_or_other_msg)
162 text_output = self.flush()
162 text_output = self.flush()
163 ri.set_writeconsole(old_writeconsole)
163 ri.set_writeconsole(old_writeconsole)
164 return text_output, value
164 return text_output, value
165
165
166 def write_console(self, output):
166 def write_console(self, output):
167 '''
167 '''
168 A hook to capture R's stdout in a cache.
168 A hook to capture R's stdout in a cache.
169 '''
169 '''
170 self.Rstdout_cache.append(output)
170 self.Rstdout_cache.append(output)
171
171
172 def flush(self):
172 def flush(self):
173 '''
173 '''
174 Flush R's stdout cache to a string, returning the string.
174 Flush R's stdout cache to a string, returning the string.
175 '''
175 '''
176 value = ''.join([str_to_unicode(s, 'utf-8') for s in self.Rstdout_cache])
176 value = ''.join([str_to_unicode(s, 'utf-8') for s in self.Rstdout_cache])
177 self.Rstdout_cache = []
177 self.Rstdout_cache = []
178 return value
178 return value
179
179
180 @skip_doctest
180 @skip_doctest
181 @line_magic
181 @line_magic
182 def Rpush(self, line):
182 def Rpush(self, line):
183 '''
183 '''
184 A line-level magic for R that pushes
184 A line-level magic for R that pushes
185 variables from python to rpy2. The line should be made up
185 variables from python to rpy2. The line should be made up
186 of whitespace separated variable names in the IPython
186 of whitespace separated variable names in the IPython
187 namespace::
187 namespace::
188
188
189 In [7]: import numpy as np
189 In [7]: import numpy as np
190
190
191 In [8]: X = np.array([4.5,6.3,7.9])
191 In [8]: X = np.array([4.5,6.3,7.9])
192
192
193 In [9]: X.mean()
193 In [9]: X.mean()
194 Out[9]: 6.2333333333333343
194 Out[9]: 6.2333333333333343
195
195
196 In [10]: %Rpush X
196 In [10]: %Rpush X
197
197
198 In [11]: %R mean(X)
198 In [11]: %R mean(X)
199 Out[11]: array([ 6.23333333])
199 Out[11]: array([ 6.23333333])
200
200
201 '''
201 '''
202
202
203 inputs = line.split(' ')
203 inputs = line.split(' ')
204 for input in inputs:
204 for input in inputs:
205 self.r.assign(input, self.pyconverter(self.shell.user_ns[input]))
205 self.r.assign(input, self.pyconverter(self.shell.user_ns[input]))
206
206
207 @skip_doctest
207 @skip_doctest
208 @magic_arguments()
208 @magic_arguments()
209 @argument(
209 @argument(
210 '-d', '--as_dataframe', action='store_true',
210 '-d', '--as_dataframe', action='store_true',
211 default=False,
211 default=False,
212 help='Convert objects to data.frames before returning to ipython.'
212 help='Convert objects to data.frames before returning to ipython.'
213 )
213 )
214 @argument(
214 @argument(
215 'outputs',
215 'outputs',
216 nargs='*',
216 nargs='*',
217 )
217 )
218 @line_magic
218 @line_magic
219 def Rpull(self, line):
219 def Rpull(self, line):
220 '''
220 '''
221 A line-level magic for R that pulls
221 A line-level magic for R that pulls
222 variables from python to rpy2::
222 variables from python to rpy2::
223
223
224 In [18]: _ = %R x = c(3,4,6.7); y = c(4,6,7); z = c('a',3,4)
224 In [18]: _ = %R x = c(3,4,6.7); y = c(4,6,7); z = c('a',3,4)
225
225
226 In [19]: %Rpull x y z
226 In [19]: %Rpull x y z
227
227
228 In [20]: x
228 In [20]: x
229 Out[20]: array([ 3. , 4. , 6.7])
229 Out[20]: array([ 3. , 4. , 6.7])
230
230
231 In [21]: y
231 In [21]: y
232 Out[21]: array([ 4., 6., 7.])
232 Out[21]: array([ 4., 6., 7.])
233
233
234 In [22]: z
234 In [22]: z
235 Out[22]:
235 Out[22]:
236 array(['a', '3', '4'],
236 array(['a', '3', '4'],
237 dtype='|S1')
237 dtype='|S1')
238
238
239
239
240 If --as_dataframe, then each object is returned as a structured array
240 If --as_dataframe, then each object is returned as a structured array
241 after first passed through "as.data.frame" in R before
241 after first passed through "as.data.frame" in R before
242 being calling self.Rconverter.
242 being calling self.Rconverter.
243 This is useful when a structured array is desired as output, or
243 This is useful when a structured array is desired as output, or
244 when the object in R has mixed data types.
244 when the object in R has mixed data types.
245 See the %%R docstring for more examples.
245 See the %%R docstring for more examples.
246
246
247 Notes
247 Notes
248 -----
248 -----
249
249
250 Beware that R names can have '.' so this is not fool proof.
250 Beware that R names can have '.' so this is not fool proof.
251 To avoid this, don't name your R objects with '.'s...
251 To avoid this, don't name your R objects with '.'s...
252
252
253 '''
253 '''
254 args = parse_argstring(self.Rpull, line)
254 args = parse_argstring(self.Rpull, line)
255 outputs = args.outputs
255 outputs = args.outputs
256 for output in outputs:
256 for output in outputs:
257 self.shell.push({output:self.Rconverter(self.r(output),dataframe=args.as_dataframe)})
257 self.shell.push({output:self.Rconverter(self.r(output),dataframe=args.as_dataframe)})
258
258
259 @skip_doctest
259 @skip_doctest
260 @magic_arguments()
260 @magic_arguments()
261 @argument(
261 @argument(
262 '-d', '--as_dataframe', action='store_true',
262 '-d', '--as_dataframe', action='store_true',
263 default=False,
263 default=False,
264 help='Convert objects to data.frames before returning to ipython.'
264 help='Convert objects to data.frames before returning to ipython.'
265 )
265 )
266 @argument(
266 @argument(
267 'output',
267 'output',
268 nargs=1,
268 nargs=1,
269 type=str,
269 type=str,
270 )
270 )
271 @line_magic
271 @line_magic
272 def Rget(self, line):
272 def Rget(self, line):
273 '''
273 '''
274 Return an object from rpy2, possibly as a structured array (if possible).
274 Return an object from rpy2, possibly as a structured array (if possible).
275 Similar to Rpull except only one argument is accepted and the value is
275 Similar to Rpull except only one argument is accepted and the value is
276 returned rather than pushed to self.shell.user_ns::
276 returned rather than pushed to self.shell.user_ns::
277
277
278 In [3]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
278 In [3]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
279
279
280 In [4]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
280 In [4]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
281
281
282 In [5]: %R -i datapy
282 In [5]: %R -i datapy
283
283
284 In [6]: %Rget datapy
284 In [6]: %Rget datapy
285 Out[6]:
285 Out[6]:
286 array([['1', '2', '3', '4'],
286 array([['1', '2', '3', '4'],
287 ['2', '3', '2', '5'],
287 ['2', '3', '2', '5'],
288 ['a', 'b', 'c', 'e']],
288 ['a', 'b', 'c', 'e']],
289 dtype='|S1')
289 dtype='|S1')
290
290
291 In [7]: %Rget -d datapy
291 In [7]: %Rget -d datapy
292 Out[7]:
292 Out[7]:
293 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
293 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
294 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
294 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
295
295
296 '''
296 '''
297 args = parse_argstring(self.Rget, line)
297 args = parse_argstring(self.Rget, line)
298 output = args.output
298 output = args.output
299 return self.Rconverter(self.r(output[0]),dataframe=args.as_dataframe)
299 return self.Rconverter(self.r(output[0]),dataframe=args.as_dataframe)
300
300
301
301
302 @skip_doctest
302 @skip_doctest
303 @magic_arguments()
303 @magic_arguments()
304 @argument(
304 @argument(
305 '-i', '--input', action='append',
305 '-i', '--input', action='append',
306 help='Names of input variable from shell.user_ns to be assigned to R variables of the same names after calling self.pyconverter. Multiple names can be passed separated only by commas with no whitespace.'
306 help='Names of input variable from shell.user_ns to be assigned to R variables of the same names after calling self.pyconverter. Multiple names can be passed separated only by commas with no whitespace.'
307 )
307 )
308 @argument(
308 @argument(
309 '-o', '--output', action='append',
309 '-o', '--output', action='append',
310 help='Names of variables to be pushed from rpy2 to shell.user_ns after executing cell body and applying self.Rconverter. Multiple names can be passed separated only by commas with no whitespace.'
310 help='Names of variables to be pushed from rpy2 to shell.user_ns after executing cell body and applying self.Rconverter. Multiple names can be passed separated only by commas with no whitespace.'
311 )
311 )
312 @argument(
312 @argument(
313 '-w', '--width', type=int,
313 '-w', '--width', type=int,
314 help='Width of png plotting device sent as an argument to *png* in R.'
314 help='Width of png plotting device sent as an argument to *png* in R.'
315 )
315 )
316 @argument(
316 @argument(
317 '-h', '--height', type=int,
317 '-h', '--height', type=int,
318 help='Height of png plotting device sent as an argument to *png* in R.'
318 help='Height of png plotting device sent as an argument to *png* in R.'
319 )
319 )
320
320
321 @argument(
321 @argument(
322 '-d', '--dataframe', action='append',
322 '-d', '--dataframe', action='append',
323 help='Convert these objects to data.frames and return as structured arrays.'
323 help='Convert these objects to data.frames and return as structured arrays.'
324 )
324 )
325 @argument(
325 @argument(
326 '-u', '--units', type=int,
326 '-u', '--units', type=int,
327 help='Units of png plotting device sent as an argument to *png* in R. One of ["px", "in", "cm", "mm"].'
327 help='Units of png plotting device sent as an argument to *png* in R. One of ["px", "in", "cm", "mm"].'
328 )
328 )
329 @argument(
329 @argument(
330 '-p', '--pointsize', type=int,
330 '-p', '--pointsize', type=int,
331 help='Pointsize of png plotting device sent as an argument to *png* in R.'
331 help='Pointsize of png plotting device sent as an argument to *png* in R.'
332 )
332 )
333 @argument(
333 @argument(
334 '-b', '--bg',
334 '-b', '--bg',
335 help='Background of png plotting device sent as an argument to *png* in R.'
335 help='Background of png plotting device sent as an argument to *png* in R.'
336 )
336 )
337 @argument(
337 @argument(
338 '-n', '--noreturn',
338 '-n', '--noreturn',
339 help='Force the magic to not return anything.',
339 help='Force the magic to not return anything.',
340 action='store_true',
340 action='store_true',
341 default=False
341 default=False
342 )
342 )
343 @argument(
343 @argument(
344 'code',
344 'code',
345 nargs='*',
345 nargs='*',
346 )
346 )
347 @needs_local_scope
347 @needs_local_scope
348 @line_cell_magic
348 @line_cell_magic
349 def R(self, line, cell=None, local_ns=None):
349 def R(self, line, cell=None, local_ns=None):
350 '''
350 '''
351 Execute code in R, and pull some of the results back into the Python namespace.
351 Execute code in R, and pull some of the results back into the Python namespace.
352
352
353 In line mode, this will evaluate an expression and convert the returned value to a Python object.
353 In line mode, this will evaluate an expression and convert the returned value to a Python object.
354 The return value is determined by rpy2's behaviour of returning the result of evaluating the
354 The return value is determined by rpy2's behaviour of returning the result of evaluating the
355 final line.
355 final line.
356
356
357 Multiple R lines can be executed by joining them with semicolons::
357 Multiple R lines can be executed by joining them with semicolons::
358
358
359 In [9]: %R X=c(1,4,5,7); sd(X); mean(X)
359 In [9]: %R X=c(1,4,5,7); sd(X); mean(X)
360 Out[9]: array([ 4.25])
360 Out[9]: array([ 4.25])
361
361
362 As a cell, this will run a block of R code, without bringing anything back by default::
362 As a cell, this will run a block of R code, without bringing anything back by default::
363
363
364 In [10]: %%R
364 In [10]: %%R
365 ....: Y = c(2,4,3,9)
365 ....: Y = c(2,4,3,9)
366 ....: print(summary(lm(Y~X)))
366 ....: print(summary(lm(Y~X)))
367 ....:
367 ....:
368
368
369 Call:
369 Call:
370 lm(formula = Y ~ X)
370 lm(formula = Y ~ X)
371
371
372 Residuals:
372 Residuals:
373 1 2 3 4
373 1 2 3 4
374 0.88 -0.24 -2.28 1.64
374 0.88 -0.24 -2.28 1.64
375
375
376 Coefficients:
376 Coefficients:
377 Estimate Std. Error t value Pr(>|t|)
377 Estimate Std. Error t value Pr(>|t|)
378 (Intercept) 0.0800 2.3000 0.035 0.975
378 (Intercept) 0.0800 2.3000 0.035 0.975
379 X 1.0400 0.4822 2.157 0.164
379 X 1.0400 0.4822 2.157 0.164
380
380
381 Residual standard error: 2.088 on 2 degrees of freedom
381 Residual standard error: 2.088 on 2 degrees of freedom
382 Multiple R-squared: 0.6993,Adjusted R-squared: 0.549
382 Multiple R-squared: 0.6993,Adjusted R-squared: 0.549
383 F-statistic: 4.651 on 1 and 2 DF, p-value: 0.1638
383 F-statistic: 4.651 on 1 and 2 DF, p-value: 0.1638
384
384
385 In the notebook, plots are published as the output of the cell.
385 In the notebook, plots are published as the output of the cell.
386
386
387 %R plot(X, Y)
387 %R plot(X, Y)
388
388
389 will create a scatter plot of X bs Y.
389 will create a scatter plot of X bs Y.
390
390
391 If cell is not None and line has some R code, it is prepended to
391 If cell is not None and line has some R code, it is prepended to
392 the R code in cell.
392 the R code in cell.
393
393
394 Objects can be passed back and forth between rpy2 and python via the -i -o flags in line::
394 Objects can be passed back and forth between rpy2 and python via the -i -o flags in line::
395
395
396 In [14]: Z = np.array([1,4,5,10])
396 In [14]: Z = np.array([1,4,5,10])
397
397
398 In [15]: %R -i Z mean(Z)
398 In [15]: %R -i Z mean(Z)
399 Out[15]: array([ 5.])
399 Out[15]: array([ 5.])
400
400
401
401
402 In [16]: %R -o W W=Z*mean(Z)
402 In [16]: %R -o W W=Z*mean(Z)
403 Out[16]: array([ 5., 20., 25., 50.])
403 Out[16]: array([ 5., 20., 25., 50.])
404
404
405 In [17]: W
405 In [17]: W
406 Out[17]: array([ 5., 20., 25., 50.])
406 Out[17]: array([ 5., 20., 25., 50.])
407
407
408 The return value is determined by these rules:
408 The return value is determined by these rules:
409
409
410 * If the cell is not None, the magic returns None.
410 * If the cell is not None, the magic returns None.
411
411
412 * If the cell evaluates as False, the resulting value is returned
412 * If the cell evaluates as False, the resulting value is returned
413 unless the final line prints something to the console, in
413 unless the final line prints something to the console, in
414 which case None is returned.
414 which case None is returned.
415
415
416 * If the final line results in a NULL value when evaluated
416 * If the final line results in a NULL value when evaluated
417 by rpy2, then None is returned.
417 by rpy2, then None is returned.
418
418
419 * No attempt is made to convert the final value to a structured array.
419 * No attempt is made to convert the final value to a structured array.
420 Use the --dataframe flag or %Rget to push / return a structured array.
420 Use the --dataframe flag or %Rget to push / return a structured array.
421
421
422 * If the -n flag is present, there is no return value.
422 * If the -n flag is present, there is no return value.
423
423
424 * A trailing ';' will also result in no return value as the last
424 * A trailing ';' will also result in no return value as the last
425 value in the line is an empty string.
425 value in the line is an empty string.
426
426
427 The --dataframe argument will attempt to return structured arrays.
427 The --dataframe argument will attempt to return structured arrays.
428 This is useful for dataframes with
428 This is useful for dataframes with
429 mixed data types. Note also that for a data.frame,
429 mixed data types. Note also that for a data.frame,
430 if it is returned as an ndarray, it is transposed::
430 if it is returned as an ndarray, it is transposed::
431
431
432 In [18]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
432 In [18]: dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')]
433
433
434 In [19]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
434 In [19]: datapy = np.array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5, 'e')], dtype=dtype)
435
435
436 In [20]: %%R -o datar
436 In [20]: %%R -o datar
437 datar = datapy
437 datar = datapy
438 ....:
438 ....:
439
439
440 In [21]: datar
440 In [21]: datar
441 Out[21]:
441 Out[21]:
442 array([['1', '2', '3', '4'],
442 array([['1', '2', '3', '4'],
443 ['2', '3', '2', '5'],
443 ['2', '3', '2', '5'],
444 ['a', 'b', 'c', 'e']],
444 ['a', 'b', 'c', 'e']],
445 dtype='|S1')
445 dtype='|S1')
446
446
447 In [22]: %%R -d datar
447 In [22]: %%R -d datar
448 datar = datapy
448 datar = datapy
449 ....:
449 ....:
450
450
451 In [23]: datar
451 In [23]: datar
452 Out[23]:
452 Out[23]:
453 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
453 array([(1, 2.9, 'a'), (2, 3.5, 'b'), (3, 2.1, 'c'), (4, 5.0, 'e')],
454 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
454 dtype=[('x', '<i4'), ('y', '<f8'), ('z', '|S1')])
455
455
456 The --dataframe argument first tries colnames, then names.
456 The --dataframe argument first tries colnames, then names.
457 If both are NULL, it returns an ndarray (i.e. unstructured)::
457 If both are NULL, it returns an ndarray (i.e. unstructured)::
458
458
459 In [1]: %R mydata=c(4,6,8.3); NULL
459 In [1]: %R mydata=c(4,6,8.3); NULL
460
460
461 In [2]: %R -d mydata
461 In [2]: %R -d mydata
462
462
463 In [3]: mydata
463 In [3]: mydata
464 Out[3]: array([ 4. , 6. , 8.3])
464 Out[3]: array([ 4. , 6. , 8.3])
465
465
466 In [4]: %R names(mydata) = c('a','b','c'); NULL
466 In [4]: %R names(mydata) = c('a','b','c'); NULL
467
467
468 In [5]: %R -d mydata
468 In [5]: %R -d mydata
469
469
470 In [6]: mydata
470 In [6]: mydata
471 Out[6]:
471 Out[6]:
472 array((4.0, 6.0, 8.3),
472 array((4.0, 6.0, 8.3),
473 dtype=[('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
473 dtype=[('a', '<f8'), ('b', '<f8'), ('c', '<f8')])
474
474
475 In [7]: %R -o mydata
475 In [7]: %R -o mydata
476
476
477 In [8]: mydata
477 In [8]: mydata
478 Out[8]: array([ 4. , 6. , 8.3])
478 Out[8]: array([ 4. , 6. , 8.3])
479
479
480 '''
480 '''
481
481
482 args = parse_argstring(self.R, line)
482 args = parse_argstring(self.R, line)
483
483
484 # arguments 'code' in line are prepended to
484 # arguments 'code' in line are prepended to
485 # the cell lines
485 # the cell lines
486
486
487 if cell is None:
487 if cell is None:
488 code = ''
488 code = ''
489 return_output = True
489 return_output = True
490 line_mode = True
490 line_mode = True
491 else:
491 else:
492 code = cell
492 code = cell
493 return_output = False
493 return_output = False
494 line_mode = False
494 line_mode = False
495
495
496 code = ' '.join(args.code) + code
496 code = ' '.join(args.code) + code
497
497
498 # if there is no local namespace then default to an empty dict
498 # if there is no local namespace then default to an empty dict
499 if local_ns is None:
499 if local_ns is None:
500 local_ns = {}
500 local_ns = {}
501
501
502 if args.input:
502 if args.input:
503 for input in ','.join(args.input).split(','):
503 for input in ','.join(args.input).split(','):
504 try:
504 try:
505 val = local_ns[input]
505 val = local_ns[input]
506 except KeyError:
506 except KeyError:
507 val = self.shell.user_ns[input]
507 val = self.shell.user_ns[input]
508 self.r.assign(input, self.pyconverter(val))
508 self.r.assign(input, self.pyconverter(val))
509
509
510 png_argdict = dict([(n, getattr(args, n)) for n in ['units', 'height', 'width', 'bg', 'pointsize']])
510 png_argdict = dict([(n, getattr(args, n)) for n in ['units', 'height', 'width', 'bg', 'pointsize']])
511 png_args = ','.join(['%s=%s' % (o,v) for o, v in png_argdict.items() if v is not None])
511 png_args = ','.join(['%s=%s' % (o,v) for o, v in png_argdict.items() if v is not None])
512 # execute the R code in a temporary directory
512 # execute the R code in a temporary directory
513
513
514 tmpd = tempfile.mkdtemp()
514 tmpd = tempfile.mkdtemp()
515 self.r('png("%s/Rplots%%03d.png",%s)' % (tmpd, png_args))
515 self.r('png("%s/Rplots%%03d.png",%s)' % (tmpd, png_args))
516
516
517 text_output = ''
517 text_output = ''
518 if line_mode:
518 if line_mode:
519 for line in code.split(';'):
519 for line in code.split(';'):
520 text_result, result = self.eval(line)
520 text_result, result = self.eval(line)
521 text_output += text_result
521 text_output += text_result
522 if text_result:
522 if text_result:
523 # the last line printed something to the console so we won't return it
523 # the last line printed something to the console so we won't return it
524 return_output = False
524 return_output = False
525 else:
525 else:
526 text_result, result = self.eval(code)
526 text_result, result = self.eval(code)
527 text_output += text_result
527 text_output += text_result
528
528
529 self.r('dev.off()')
529 self.r('dev.off()')
530
530
531 # read out all the saved .png files
531 # read out all the saved .png files
532
532
533 images = [open(imgfile, 'rb').read() for imgfile in glob("%s/Rplots*png" % tmpd)]
533 images = [open(imgfile, 'rb').read() for imgfile in glob("%s/Rplots*png" % tmpd)]
534
534
535 # now publish the images
535 # now publish the images
536 # mimicking IPython/zmq/pylab/backend_inline.py
536 # mimicking IPython/zmq/pylab/backend_inline.py
537 fmt = 'png'
537 fmt = 'png'
538 mimetypes = { 'png' : 'image/png', 'svg' : 'image/svg+xml' }
538 mimetypes = { 'png' : 'image/png', 'svg' : 'image/svg+xml' }
539 mime = mimetypes[fmt]
539 mime = mimetypes[fmt]
540
540
541 # publish the printed R objects, if any
541 # publish the printed R objects, if any
542
542
543 display_data = []
543 display_data = []
544 if text_output:
544 if text_output:
545 display_data.append(('RMagic.R', {'text/plain':text_output}))
545 display_data.append(('RMagic.R', {'text/plain':text_output}))
546
546
547 # flush text streams before sending figures, helps a little with output
547 # flush text streams before sending figures, helps a little with output
548 for image in images:
548 for image in images:
549 # synchronization in the console (though it's a bandaid, not a real sln)
549 # synchronization in the console (though it's a bandaid, not a real sln)
550 sys.stdout.flush(); sys.stderr.flush()
550 sys.stdout.flush(); sys.stderr.flush()
551 display_data.append(('RMagic.R', {mime: image}))
551 display_data.append(('RMagic.R', {mime: image}))
552
552
553 # kill the temporary directory
553 # kill the temporary directory
554 rmtree(tmpd)
554 rmtree(tmpd)
555
555
556 # try to turn every output into a numpy array
556 # try to turn every output into a numpy array
557 # this means that output are assumed to be castable
557 # this means that output are assumed to be castable
558 # as numpy arrays
558 # as numpy arrays
559
559
560 if args.output:
560 if args.output:
561 for output in ','.join(args.output).split(','):
561 for output in ','.join(args.output).split(','):
562 self.shell.push({output:self.Rconverter(self.r(output), dataframe=False)})
562 self.shell.push({output:self.Rconverter(self.r(output), dataframe=False)})
563
563
564 if args.dataframe:
564 if args.dataframe:
565 for output in ','.join(args.dataframe).split(','):
565 for output in ','.join(args.dataframe).split(','):
566 self.shell.push({output:self.Rconverter(self.r(output), dataframe=True)})
566 self.shell.push({output:self.Rconverter(self.r(output), dataframe=True)})
567
567
568 for tag, disp_d in display_data:
568 for tag, disp_d in display_data:
569 publish_display_data(tag, disp_d)
569 publish_display_data(tag, disp_d)
570
570
571 # this will keep a reference to the display_data
571 # this will keep a reference to the display_data
572 # which might be useful to other objects who happen to use
572 # which might be useful to other objects who happen to use
573 # this method
573 # this method
574
574
575 if self.cache_display_data:
575 if self.cache_display_data:
576 self.display_cache = display_data
576 self.display_cache = display_data
577
577
578 # if in line mode and return_output, return the result as an ndarray
578 # if in line mode and return_output, return the result as an ndarray
579 if return_output and not args.noreturn:
579 if return_output and not args.noreturn:
580 if result != ri.NULL:
580 if result != ri.NULL:
581 return self.Rconverter(result, dataframe=False)
581 return self.Rconverter(result, dataframe=False)
582
582
583 __doc__ = __doc__.format(
583 __doc__ = __doc__.format(
584 R_DOC = ' '*8 + RMagics.R.__doc__,
584 R_DOC = ' '*8 + RMagics.R.__doc__,
585 RPUSH_DOC = ' '*8 + RMagics.Rpush.__doc__,
585 RPUSH_DOC = ' '*8 + RMagics.Rpush.__doc__,
586 RPULL_DOC = ' '*8 + RMagics.Rpull.__doc__,
586 RPULL_DOC = ' '*8 + RMagics.Rpull.__doc__,
587 RGET_DOC = ' '*8 + RMagics.Rget.__doc__
587 RGET_DOC = ' '*8 + RMagics.Rget.__doc__
588 )
588 )
589
589
590
590
591 _loaded = False
592 def load_ipython_extension(ip):
591 def load_ipython_extension(ip):
593 """Load the extension in IPython."""
592 """Load the extension in IPython."""
594 global _loaded
593 ip.register_magics(RMagics)
595 if not _loaded:
596 ip.register_magics(RMagics)
597 _loaded = True
@@ -1,220 +1,214
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """
2 """
3 %store magic for lightweight persistence.
3 %store magic for lightweight persistence.
4
4
5 Stores variables, aliases and macros in IPython's database.
5 Stores variables, aliases and macros in IPython's database.
6
6
7 To automatically restore stored variables at startup, add this to your
7 To automatically restore stored variables at startup, add this to your
8 :file:`ipython_config.py` file::
8 :file:`ipython_config.py` file::
9
9
10 c.StoreMagic.autorestore = True
10 c.StoreMagic.autorestore = True
11 """
11 """
12 #-----------------------------------------------------------------------------
12 #-----------------------------------------------------------------------------
13 # Copyright (c) 2012, The IPython Development Team.
13 # Copyright (c) 2012, The IPython Development Team.
14 #
14 #
15 # Distributed under the terms of the Modified BSD License.
15 # Distributed under the terms of the Modified BSD License.
16 #
16 #
17 # The full license is in the file COPYING.txt, distributed with this software.
17 # The full license is in the file COPYING.txt, distributed with this software.
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 #-----------------------------------------------------------------------------
20 #-----------------------------------------------------------------------------
21 # Imports
21 # Imports
22 #-----------------------------------------------------------------------------
22 #-----------------------------------------------------------------------------
23
23
24 # Stdlib
24 # Stdlib
25 import inspect, os, sys, textwrap
25 import inspect, os, sys, textwrap
26
26
27 # Our own
27 # Our own
28 from IPython.core.error import UsageError
28 from IPython.core.error import UsageError
29 from IPython.core.fakemodule import FakeModule
29 from IPython.core.fakemodule import FakeModule
30 from IPython.core.magic import Magics, magics_class, line_magic
30 from IPython.core.magic import Magics, magics_class, line_magic
31 from IPython.testing.skipdoctest import skip_doctest
31 from IPython.testing.skipdoctest import skip_doctest
32
32
33 #-----------------------------------------------------------------------------
33 #-----------------------------------------------------------------------------
34 # Functions and classes
34 # Functions and classes
35 #-----------------------------------------------------------------------------
35 #-----------------------------------------------------------------------------
36
36
37 def restore_aliases(ip):
37 def restore_aliases(ip):
38 staliases = ip.db.get('stored_aliases', {})
38 staliases = ip.db.get('stored_aliases', {})
39 for k,v in staliases.items():
39 for k,v in staliases.items():
40 #print "restore alias",k,v # dbg
40 #print "restore alias",k,v # dbg
41 #self.alias_table[k] = v
41 #self.alias_table[k] = v
42 ip.alias_manager.define_alias(k,v)
42 ip.alias_manager.define_alias(k,v)
43
43
44
44
45 def refresh_variables(ip):
45 def refresh_variables(ip):
46 db = ip.db
46 db = ip.db
47 for key in db.keys('autorestore/*'):
47 for key in db.keys('autorestore/*'):
48 # strip autorestore
48 # strip autorestore
49 justkey = os.path.basename(key)
49 justkey = os.path.basename(key)
50 try:
50 try:
51 obj = db[key]
51 obj = db[key]
52 except KeyError:
52 except KeyError:
53 print "Unable to restore variable '%s', ignoring (use %%store -d to forget!)" % justkey
53 print "Unable to restore variable '%s', ignoring (use %%store -d to forget!)" % justkey
54 print "The error was:", sys.exc_info()[0]
54 print "The error was:", sys.exc_info()[0]
55 else:
55 else:
56 #print "restored",justkey,"=",obj #dbg
56 #print "restored",justkey,"=",obj #dbg
57 ip.user_ns[justkey] = obj
57 ip.user_ns[justkey] = obj
58
58
59
59
60 def restore_dhist(ip):
60 def restore_dhist(ip):
61 ip.user_ns['_dh'] = ip.db.get('dhist',[])
61 ip.user_ns['_dh'] = ip.db.get('dhist',[])
62
62
63
63
64 def restore_data(ip):
64 def restore_data(ip):
65 refresh_variables(ip)
65 refresh_variables(ip)
66 restore_aliases(ip)
66 restore_aliases(ip)
67 restore_dhist(ip)
67 restore_dhist(ip)
68
68
69
69
70 @magics_class
70 @magics_class
71 class StoreMagics(Magics):
71 class StoreMagics(Magics):
72 """Lightweight persistence for python variables.
72 """Lightweight persistence for python variables.
73
73
74 Provides the %store magic."""
74 Provides the %store magic."""
75
75
76 @skip_doctest
76 @skip_doctest
77 @line_magic
77 @line_magic
78 def store(self, parameter_s=''):
78 def store(self, parameter_s=''):
79 """Lightweight persistence for python variables.
79 """Lightweight persistence for python variables.
80
80
81 Example::
81 Example::
82
82
83 In [1]: l = ['hello',10,'world']
83 In [1]: l = ['hello',10,'world']
84 In [2]: %store l
84 In [2]: %store l
85 In [3]: exit
85 In [3]: exit
86
86
87 (IPython session is closed and started again...)
87 (IPython session is closed and started again...)
88
88
89 ville@badger:~$ ipython
89 ville@badger:~$ ipython
90 In [1]: l
90 In [1]: l
91 Out[1]: ['hello', 10, 'world']
91 Out[1]: ['hello', 10, 'world']
92
92
93 Usage:
93 Usage:
94
94
95 * ``%store`` - Show list of all variables and their current
95 * ``%store`` - Show list of all variables and their current
96 values
96 values
97 * ``%store spam`` - Store the *current* value of the variable spam
97 * ``%store spam`` - Store the *current* value of the variable spam
98 to disk
98 to disk
99 * ``%store -d spam`` - Remove the variable and its value from storage
99 * ``%store -d spam`` - Remove the variable and its value from storage
100 * ``%store -z`` - Remove all variables from storage
100 * ``%store -z`` - Remove all variables from storage
101 * ``%store -r`` - Refresh all variables from store (delete
101 * ``%store -r`` - Refresh all variables from store (delete
102 current vals)
102 current vals)
103 * ``%store foo >a.txt`` - Store value of foo to new file a.txt
103 * ``%store foo >a.txt`` - Store value of foo to new file a.txt
104 * ``%store foo >>a.txt`` - Append value of foo to file a.txt
104 * ``%store foo >>a.txt`` - Append value of foo to file a.txt
105
105
106 It should be noted that if you change the value of a variable, you
106 It should be noted that if you change the value of a variable, you
107 need to %store it again if you want to persist the new value.
107 need to %store it again if you want to persist the new value.
108
108
109 Note also that the variables will need to be pickleable; most basic
109 Note also that the variables will need to be pickleable; most basic
110 python types can be safely %store'd.
110 python types can be safely %store'd.
111
111
112 Also aliases can be %store'd across sessions.
112 Also aliases can be %store'd across sessions.
113 """
113 """
114
114
115 opts,argsl = self.parse_options(parameter_s,'drz',mode='string')
115 opts,argsl = self.parse_options(parameter_s,'drz',mode='string')
116 args = argsl.split(None,1)
116 args = argsl.split(None,1)
117 ip = self.shell
117 ip = self.shell
118 db = ip.db
118 db = ip.db
119 # delete
119 # delete
120 if 'd' in opts:
120 if 'd' in opts:
121 try:
121 try:
122 todel = args[0]
122 todel = args[0]
123 except IndexError:
123 except IndexError:
124 raise UsageError('You must provide the variable to forget')
124 raise UsageError('You must provide the variable to forget')
125 else:
125 else:
126 try:
126 try:
127 del db['autorestore/' + todel]
127 del db['autorestore/' + todel]
128 except:
128 except:
129 raise UsageError("Can't delete variable '%s'" % todel)
129 raise UsageError("Can't delete variable '%s'" % todel)
130 # reset
130 # reset
131 elif 'z' in opts:
131 elif 'z' in opts:
132 for k in db.keys('autorestore/*'):
132 for k in db.keys('autorestore/*'):
133 del db[k]
133 del db[k]
134
134
135 elif 'r' in opts:
135 elif 'r' in opts:
136 refresh_variables(ip)
136 refresh_variables(ip)
137
137
138
138
139 # run without arguments -> list variables & values
139 # run without arguments -> list variables & values
140 elif not args:
140 elif not args:
141 vars = db.keys('autorestore/*')
141 vars = db.keys('autorestore/*')
142 vars.sort()
142 vars.sort()
143 if vars:
143 if vars:
144 size = max(map(len, vars))
144 size = max(map(len, vars))
145 else:
145 else:
146 size = 0
146 size = 0
147
147
148 print 'Stored variables and their in-db values:'
148 print 'Stored variables and their in-db values:'
149 fmt = '%-'+str(size)+'s -> %s'
149 fmt = '%-'+str(size)+'s -> %s'
150 get = db.get
150 get = db.get
151 for var in vars:
151 for var in vars:
152 justkey = os.path.basename(var)
152 justkey = os.path.basename(var)
153 # print 30 first characters from every var
153 # print 30 first characters from every var
154 print fmt % (justkey, repr(get(var, '<unavailable>'))[:50])
154 print fmt % (justkey, repr(get(var, '<unavailable>'))[:50])
155
155
156 # default action - store the variable
156 # default action - store the variable
157 else:
157 else:
158 # %store foo >file.txt or >>file.txt
158 # %store foo >file.txt or >>file.txt
159 if len(args) > 1 and args[1].startswith('>'):
159 if len(args) > 1 and args[1].startswith('>'):
160 fnam = os.path.expanduser(args[1].lstrip('>').lstrip())
160 fnam = os.path.expanduser(args[1].lstrip('>').lstrip())
161 if args[1].startswith('>>'):
161 if args[1].startswith('>>'):
162 fil = open(fnam, 'a')
162 fil = open(fnam, 'a')
163 else:
163 else:
164 fil = open(fnam, 'w')
164 fil = open(fnam, 'w')
165 obj = ip.ev(args[0])
165 obj = ip.ev(args[0])
166 print "Writing '%s' (%s) to file '%s'." % (args[0],
166 print "Writing '%s' (%s) to file '%s'." % (args[0],
167 obj.__class__.__name__, fnam)
167 obj.__class__.__name__, fnam)
168
168
169
169
170 if not isinstance (obj, basestring):
170 if not isinstance (obj, basestring):
171 from pprint import pprint
171 from pprint import pprint
172 pprint(obj, fil)
172 pprint(obj, fil)
173 else:
173 else:
174 fil.write(obj)
174 fil.write(obj)
175 if not obj.endswith('\n'):
175 if not obj.endswith('\n'):
176 fil.write('\n')
176 fil.write('\n')
177
177
178 fil.close()
178 fil.close()
179 return
179 return
180
180
181 # %store foo
181 # %store foo
182 try:
182 try:
183 obj = ip.user_ns[args[0]]
183 obj = ip.user_ns[args[0]]
184 except KeyError:
184 except KeyError:
185 # it might be an alias
185 # it might be an alias
186 # This needs to be refactored to use the new AliasManager stuff.
186 # This needs to be refactored to use the new AliasManager stuff.
187 if args[0] in ip.alias_manager:
187 if args[0] in ip.alias_manager:
188 name = args[0]
188 name = args[0]
189 nargs, cmd = ip.alias_manager.alias_table[ name ]
189 nargs, cmd = ip.alias_manager.alias_table[ name ]
190 staliases = db.get('stored_aliases',{})
190 staliases = db.get('stored_aliases',{})
191 staliases[ name ] = cmd
191 staliases[ name ] = cmd
192 db['stored_aliases'] = staliases
192 db['stored_aliases'] = staliases
193 print "Alias stored: %s (%s)" % (name, cmd)
193 print "Alias stored: %s (%s)" % (name, cmd)
194 return
194 return
195 else:
195 else:
196 raise UsageError("Unknown variable '%s'" % args[0])
196 raise UsageError("Unknown variable '%s'" % args[0])
197
197
198 else:
198 else:
199 if isinstance(inspect.getmodule(obj), FakeModule):
199 if isinstance(inspect.getmodule(obj), FakeModule):
200 print textwrap.dedent("""\
200 print textwrap.dedent("""\
201 Warning:%s is %s
201 Warning:%s is %s
202 Proper storage of interactively declared classes (or instances
202 Proper storage of interactively declared classes (or instances
203 of those classes) is not possible! Only instances
203 of those classes) is not possible! Only instances
204 of classes in real modules on file system can be %%store'd.
204 of classes in real modules on file system can be %%store'd.
205 """ % (args[0], obj) )
205 """ % (args[0], obj) )
206 return
206 return
207 #pickled = pickle.dumps(obj)
207 #pickled = pickle.dumps(obj)
208 db[ 'autorestore/' + args[0] ] = obj
208 db[ 'autorestore/' + args[0] ] = obj
209 print "Stored '%s' (%s)" % (args[0], obj.__class__.__name__)
209 print "Stored '%s' (%s)" % (args[0], obj.__class__.__name__)
210
210
211
211
212 _loaded = False
213
214
215 def load_ipython_extension(ip):
212 def load_ipython_extension(ip):
216 """Load the extension in IPython."""
213 """Load the extension in IPython."""
217 global _loaded
214 ip.register_magics(StoreMagics)
218 if not _loaded:
219 ip.register_magics(StoreMagics)
220 _loaded = True
General Comments 0
You need to be logged in to leave comments. Login now