Remove uses of iterkeys
Thomas Kluyver
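The change is mechanical: `dict.iterkeys()` exists only on Python 2, while iterating a dict directly yields its keys on both Python 2 and 3. A minimal sketch of the equivalence (illustrative, not part of the patch):

    d = {'a': 1, 'b': 2}
    assert sorted(d) == sorted(d.keys())  # iterating a dict yields its keys
    assert 'a' in d                       # membership tests the keys, in O(1)
    # d.iterkeys()                        # Python 2 only; AttributeError on 3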
@@ -1,338 +1,338 b''
1 1 """Implementations for various useful completers.
2 2
3 3 These are all loaded by default by IPython.
4 4 """
5 5 #-----------------------------------------------------------------------------
6 6 # Copyright (C) 2010-2011 The IPython Development Team.
7 7 #
8 8 # Distributed under the terms of the BSD License.
9 9 #
10 10 # The full license is in the file COPYING.txt, distributed with this software.
11 11 #-----------------------------------------------------------------------------
12 12
13 13 #-----------------------------------------------------------------------------
14 14 # Imports
15 15 #-----------------------------------------------------------------------------
16 16 from __future__ import print_function
17 17
18 18 # Stdlib imports
19 19 import glob
20 20 import imp
21 21 import inspect
22 22 import os
23 23 import re
24 24 import sys
25 25
26 26 # Third-party imports
27 27 from time import time
28 28 from zipimport import zipimporter
29 29
30 30 # Our own imports
31 31 from IPython.core.completer import expand_user, compress_user
32 32 from IPython.core.error import TryNext
33 33 from IPython.utils._process_common import arg_split
34 34 from IPython.utils.py3compat import string_types
35 35
36 36 # FIXME: this should be pulled in with the right call via the component system
37 37 from IPython import get_ipython
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Globals and constants
41 41 #-----------------------------------------------------------------------------
42 42
43 43 # Time in seconds after which the rootmodules will be stored permanently in the
44 44 # ipython ip.db database (kept in the user's .ipython dir).
45 45 TIMEOUT_STORAGE = 2
46 46
47 47 # Time in seconds after which we give up
48 48 TIMEOUT_GIVEUP = 20
49 49
50 50 # Regular expression for the python import statement
51 51 import_re = re.compile(r'(?P<name>[a-zA-Z_][a-zA-Z0-9_]*?)'
52 52 r'(?P<package>[/\\]__init__)?'
53 53 r'(?P<suffix>%s)$' %
54 54 r'|'.join(re.escape(s[0]) for s in imp.get_suffixes()))
55 55
56 56 # RE for the ipython %run command (python + ipython scripts)
57 57 magic_run_re = re.compile(r'.*(\.ipy|\.py[w]?)$')
58 58
59 59 #-----------------------------------------------------------------------------
60 60 # Local utilities
61 61 #-----------------------------------------------------------------------------
62 62
63 63 def module_list(path):
64 64 """
65 65 Return a list of the names of the modules available in the given
66 66 folder.
67 67 """
68 68 # sys.path has the cwd as an empty string, but isdir/listdir need it as '.'
69 69 if path == '':
70 70 path = '.'
71 71
72 72 # A few local constants to be used in loops below
73 73 pjoin = os.path.join
74 74
75 75 if os.path.isdir(path):
76 76 # Build a list of all files in the directory and all files
77 77 # in its subdirectories. For performance reasons, do not
78 78 # recurse more than one level into subdirectories.
79 79 files = []
80 80 for root, dirs, nondirs in os.walk(path):
81 81 subdir = root[len(path)+1:]
82 82 if subdir:
83 83 files.extend(pjoin(subdir, f) for f in nondirs)
84 84 dirs[:] = [] # Do not recurse into additional subdirectories.
85 85 else:
86 86 files.extend(nondirs)
87 87
88 88 else:
89 89 try:
90 90 files = list(zipimporter(path)._files.keys())
91 91 except:
92 92 files = []
93 93
94 94 # Build a list of modules which match the import_re regex.
95 95 modules = []
96 96 for f in files:
97 97 m = import_re.match(f)
98 98 if m:
99 99 modules.append(m.group('name'))
100 100 return list(set(modules))
101 101
102 102
103 103 def get_root_modules():
104 104 """
105 105 Returns a list containing the names of all the modules available in the
106 106 folders of the pythonpath.
107 107
108 108 ip.db['rootmodules_cache'] maps sys.path entries to list of modules.
109 109 """
110 110 ip = get_ipython()
111 111 rootmodules_cache = ip.db.get('rootmodules_cache', {})
112 112 rootmodules = list(sys.builtin_module_names)
113 113 start_time = time()
114 114 store = False
115 115 for path in sys.path:
116 116 try:
117 117 modules = rootmodules_cache[path]
118 118 except KeyError:
119 119 modules = module_list(path)
120 120 try:
121 121 modules.remove('__init__')
122 122 except ValueError:
123 123 pass
124 124 if path not in ('', '.'): # cwd modules should not be cached
125 125 rootmodules_cache[path] = modules
126 126 if time() - start_time > TIMEOUT_STORAGE and not store:
127 127 store = True
128 128 print("\nCaching the list of root modules, please wait!")
129 129 print("(This will only be done once - type '%rehashx' to "
130 130 "reset cache!)\n")
131 131 sys.stdout.flush()
132 132 if time() - start_time > TIMEOUT_GIVEUP:
133 133 print("This is taking too long, we give up.\n")
134 134 return []
135 135 rootmodules.extend(modules)
136 136 if store:
137 137 ip.db['rootmodules_cache'] = rootmodules_cache
138 138 rootmodules = list(set(rootmodules))
139 139 return rootmodules
140 140
141 141
142 142 def is_importable(module, attr, only_modules):
143 143 if only_modules:
144 144 return inspect.ismodule(getattr(module, attr))
145 145 else:
146 146 return not(attr[:2] == '__' and attr[-2:] == '__')
147 147
148 148
149 149 def try_import(mod, only_modules=False):
150 150 try:
151 151 m = __import__(mod)
152 152 except:
153 153 return []
154 154 mods = mod.split('.')
155 155 for module in mods[1:]:
156 156 m = getattr(m, module)
157 157
158 158 m_is_init = hasattr(m, '__file__') and '__init__' in m.__file__
159 159
160 160 completions = []
161 161 if (not hasattr(m, '__file__')) or (not only_modules) or m_is_init:
162 162 completions.extend( [attr for attr in dir(m) if
163 163 is_importable(m, attr, only_modules)])
164 164
165 165 completions.extend(getattr(m, '__all__', []))
166 166 if m_is_init:
167 167 completions.extend(module_list(os.path.dirname(m.__file__)))
168 168 completions = set(completions)
169 169 if '__init__' in completions:
170 170 completions.remove('__init__')
171 171 return list(completions)
172 172
173 173
174 174 #-----------------------------------------------------------------------------
175 175 # Completion-related functions.
176 176 #-----------------------------------------------------------------------------
177 177
178 178 def quick_completer(cmd, completions):
179 179 """ Easily create a trivial completer for a command.
180 180
181 181 Takes either a list of completions, or all completions as a single string
182 182 (which will be split on whitespace).
183 183
184 184 Example::
185 185
186 186 [d:\ipython]|1> import ipy_completers
187 187 [d:\ipython]|2> ipy_completers.quick_completer('foo', ['bar','baz'])
188 188 [d:\ipython]|3> foo b<TAB>
189 189 bar baz
190 190 [d:\ipython]|3> foo ba
191 191 """
192 192
193 193 if isinstance(completions, string_types):
194 194 completions = completions.split()
195 195
196 196 def do_complete(self, event):
197 197 return completions
198 198
199 199 get_ipython().set_hook('complete_command',do_complete, str_key = cmd)
200 200
201 201 def module_completion(line):
202 202 """
203 203 Returns a list containing the completion possibilities for an import line.
204 204
205 205 The line looks like this:
206 206 'import xml.d'
207 207 'from xml.dom import'
208 208 """
209 209
210 210 words = line.split(' ')
211 211 nwords = len(words)
212 212
213 213 # from whatever <tab> -> 'import '
214 214 if nwords == 3 and words[0] == 'from':
215 215 return ['import ']
216 216
217 217 # 'from xy<tab>' or 'import xy<tab>'
218 218 if nwords < 3 and (words[0] in ['import','from']) :
219 219 if nwords == 1:
220 220 return get_root_modules()
221 221 mod = words[1].split('.')
222 222 if len(mod) < 2:
223 223 return get_root_modules()
224 224 completion_list = try_import('.'.join(mod[:-1]), True)
225 225 return ['.'.join(mod[:-1] + [el]) for el in completion_list]
226 226
227 227 # 'from xyz import abc<tab>'
228 228 if nwords >= 3 and words[0] == 'from':
229 229 mod = words[1]
230 230 return try_import(mod)
231 231
232 232 #-----------------------------------------------------------------------------
233 233 # Completers
234 234 #-----------------------------------------------------------------------------
235 235 # These all have the func(self, event) signature to be used as custom
236 236 # completers
237 237
238 238 def module_completer(self,event):
239 239 """Give completions after user has typed 'import ...' or 'from ...'"""
240 240
241 241 # This works in all versions of python. While 2.5 has
242 242 # pkgutil.walk_packages(), that particular routine is fairly dangerous,
243 243 # since it imports *EVERYTHING* on sys.path. That is: a) very slow b) full
244 244 # of possibly problematic side effects.
245 245 # This searches the folders in sys.path for available modules.
246 246
247 247 return module_completion(event.line)
248 248
249 249 # FIXME: there's a lot of logic common to the run, cd and builtin file
250 250 # completers, that is currently reimplemented in each.
251 251
252 252 def magic_run_completer(self, event):
253 253 """Complete files that end in .py or .ipy for the %run command.
254 254 """
255 255 comps = arg_split(event.line, strict=False)
256 256 relpath = (len(comps) > 1 and comps[-1] or '').strip("'\"")
257 257
258 258 #print("\nev=", event) # dbg
259 259 #print("rp=", relpath) # dbg
260 260 #print('comps=', comps) # dbg
261 261
262 262 lglob = glob.glob
263 263 isdir = os.path.isdir
264 264 relpath, tilde_expand, tilde_val = expand_user(relpath)
265 265
266 266 dirs = [f.replace('\\','/') + "/" for f in lglob(relpath+'*') if isdir(f)]
267 267
268 268 # Find if the user has already typed the first filename, after which we
269 269 # should complete on all files, since after the first one other files may
270 270 # be arguments to the input script.
271 271
272 272 if any(magic_run_re.match(c) for c in comps):
273 273 pys = [f.replace('\\','/') for f in lglob('*')]
274 274 else:
275 275 pys = [f.replace('\\','/')
276 276 for f in lglob(relpath+'*.py') + lglob(relpath+'*.ipy') +
277 277 lglob(relpath + '*.pyw')]
278 278 #print('run comp:', dirs+pys) # dbg
279 279 return [compress_user(p, tilde_expand, tilde_val) for p in dirs+pys]
280 280
281 281
282 282 def cd_completer(self, event):
283 283 """Completer function for cd, which only returns directories."""
284 284 ip = get_ipython()
285 285 relpath = event.symbol
286 286
287 287 #print(event) # dbg
288 288 if event.line.endswith('-b') or ' -b ' in event.line:
289 289 # return only bookmark completions
290 290 bkms = self.db.get('bookmarks', None)
291 291 if bkms:
292 292 return bkms.keys()
293 293 else:
294 294 return []
295 295
296 296 if event.symbol == '-':
297 297 width_dh = str(len(str(len(ip.user_ns['_dh']) + 1)))
298 298 # jump in directory history by number
299 299 fmt = '-%0' + width_dh +'d [%s]'
300 300 ents = [ fmt % (i,s) for i,s in enumerate(ip.user_ns['_dh'])]
301 301 if len(ents) > 1:
302 302 return ents
303 303 return []
304 304
305 305 if event.symbol.startswith('--'):
306 306 return ["--" + os.path.basename(d) for d in ip.user_ns['_dh']]
307 307
308 308 # Expand ~ in path and normalize directory separators.
309 309 relpath, tilde_expand, tilde_val = expand_user(relpath)
310 310 relpath = relpath.replace('\\','/')
311 311
312 312 found = []
313 313 for d in [f.replace('\\','/') + '/' for f in glob.glob(relpath+'*')
314 314 if os.path.isdir(f)]:
315 315 if ' ' in d:
316 316 # we don't want to deal with any of that, complex code
317 317 # for this is elsewhere
318 318 raise TryNext
319 319
320 320 found.append(d)
321 321
322 322 if not found:
323 323 if os.path.isdir(relpath):
324 324 return [compress_user(relpath, tilde_expand, tilde_val)]
325 325
326 326 # if no completions so far, try bookmarks
327 bks = self.db.get('bookmarks',{}).iterkeys()
327 bks = self.db.get('bookmarks',{})
328 328 bkmatches = [s for s in bks if s.startswith(event.symbol)]
329 329 if bkmatches:
330 330 return bkmatches
331 331
332 332 raise TryNext
333 333
334 334 return [compress_user(p, tilde_expand, tilde_val) for p in found]
335 335
336 336 def reset_completer(self, event):
337 337 "A completer for %reset magic"
338 338 return '-f -s in out array dhist'.split()
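The cd_completer hunk above drops `.iterkeys()` from the bookmarks lookup; the comprehension that follows iterates the dict directly, which yields keys either way. A reduced sketch with a stand-in bookmarks dict (values hypothetical):

    bookmarks = {'proj': '/home/user/proj', 'scratch': '/tmp/scratch'}
    symbol = 'pr'
    bkmatches = [s for s in bookmarks if s.startswith(symbol)]
    assert bkmatches == ['proj']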
@@ -1,789 +1,789 b''
1 1 #!/usr/bin/env python
2 2 """An interactive kernel that talks to frontends over 0MQ."""
3 3
4 4 #-----------------------------------------------------------------------------
5 5 # Imports
6 6 #-----------------------------------------------------------------------------
7 7 from __future__ import print_function
8 8
9 9 # Standard library imports
10 10 import sys
11 11 import time
12 12 import traceback
13 13 import logging
14 14 import uuid
15 15
16 16 from datetime import datetime
17 17 from signal import (
18 18 signal, default_int_handler, SIGINT
19 19 )
20 20
21 21 # System library imports
22 22 import zmq
23 23 from zmq.eventloop import ioloop
24 24 from zmq.eventloop.zmqstream import ZMQStream
25 25
26 26 # Local imports
27 27 from IPython.config.configurable import Configurable
28 28 from IPython.core.error import StdinNotImplementedError
29 29 from IPython.core import release
30 30 from IPython.utils import py3compat
31 31 from IPython.utils.py3compat import builtin_mod, unicode_type, string_types
32 32 from IPython.utils.jsonutil import json_clean
33 33 from IPython.utils.traitlets import (
34 34 Any, Instance, Float, Dict, List, Set, Integer, Unicode,
35 35 Type
36 36 )
37 37
38 38 from .serialize import serialize_object, unpack_apply_message
39 39 from .session import Session
40 40 from .zmqshell import ZMQInteractiveShell
41 41
42 42
43 43 #-----------------------------------------------------------------------------
44 44 # Main kernel class
45 45 #-----------------------------------------------------------------------------
46 46
47 47 protocol_version = list(release.kernel_protocol_version_info)
48 48 ipython_version = list(release.version_info)
49 49 language_version = list(sys.version_info[:3])
50 50
51 51
52 52 class Kernel(Configurable):
53 53
54 54 #---------------------------------------------------------------------------
55 55 # Kernel interface
56 56 #---------------------------------------------------------------------------
57 57
58 58 # attribute to override with a GUI
59 59 eventloop = Any(None)
60 60 def _eventloop_changed(self, name, old, new):
61 61 """schedule call to eventloop from IOLoop"""
62 62 loop = ioloop.IOLoop.instance()
63 63 loop.add_timeout(time.time()+0.1, self.enter_eventloop)
64 64
65 65 shell = Instance('IPython.core.interactiveshell.InteractiveShellABC')
66 66 shell_class = Type(ZMQInteractiveShell)
67 67
68 68 session = Instance(Session)
69 69 profile_dir = Instance('IPython.core.profiledir.ProfileDir')
70 70 shell_streams = List()
71 71 control_stream = Instance(ZMQStream)
72 72 iopub_socket = Instance(zmq.Socket)
73 73 stdin_socket = Instance(zmq.Socket)
74 74 log = Instance(logging.Logger)
75 75
76 76 user_module = Any()
77 77 def _user_module_changed(self, name, old, new):
78 78 if self.shell is not None:
79 79 self.shell.user_module = new
80 80
81 81 user_ns = Instance(dict, args=None, allow_none=True)
82 82 def _user_ns_changed(self, name, old, new):
83 83 if self.shell is not None:
84 84 self.shell.user_ns = new
85 85 self.shell.init_user_ns()
86 86
87 87 # identities:
88 88 int_id = Integer(-1)
89 89 ident = Unicode()
90 90
91 91 def _ident_default(self):
92 92 return unicode_type(uuid.uuid4())
93 93
94 94
95 95 # Private interface
96 96
97 97 # Time to sleep after flushing the stdout/err buffers in each execute
98 98 # cycle. While this introduces a hard limit on the minimal latency of the
99 99 # execute cycle, it helps prevent output synchronization problems for
100 100 # clients.
101 101 # Units are in seconds. The minimum zmq latency on local host is probably
102 102 # ~150 microseconds, set this to 500us for now. We may need to increase it
103 103 # a little if it's not enough after more interactive testing.
104 104 _execute_sleep = Float(0.0005, config=True)
105 105
106 106 # Frequency of the kernel's event loop.
107 107 # Units are in seconds, kernel subclasses for GUI toolkits may need to
108 108 # adapt to milliseconds.
109 109 _poll_interval = Float(0.05, config=True)
110 110
111 111 # If the shutdown was requested over the network, we leave here the
112 112 # necessary reply message so it can be sent by our registered atexit
113 113 # handler. This ensures that the reply is only sent to clients truly at
114 114 # the end of our shutdown process (which happens after the underlying
115 115 # IPython shell's own shutdown).
116 116 _shutdown_message = None
117 117
118 118 # This is a dict of port number that the kernel is listening on. It is set
119 119 # by record_ports and used by connect_request.
120 120 _recorded_ports = Dict()
121 121
122 122 # A reference to the Python builtin 'raw_input' function.
123 123 # (i.e., __builtin__.raw_input for Python 2.7, builtins.input for Python 3)
124 124 _sys_raw_input = Any()
125 125 _sys_eval_input = Any()
126 126
127 127 # set of aborted msg_ids
128 128 aborted = Set()
129 129
130 130
131 131 def __init__(self, **kwargs):
132 132 super(Kernel, self).__init__(**kwargs)
133 133
134 134 # Initialize the InteractiveShell subclass
135 135 self.shell = self.shell_class.instance(parent=self,
136 136 profile_dir = self.profile_dir,
137 137 user_module = self.user_module,
138 138 user_ns = self.user_ns,
139 139 kernel = self,
140 140 )
141 141 self.shell.displayhook.session = self.session
142 142 self.shell.displayhook.pub_socket = self.iopub_socket
143 143 self.shell.displayhook.topic = self._topic('pyout')
144 144 self.shell.display_pub.session = self.session
145 145 self.shell.display_pub.pub_socket = self.iopub_socket
146 146 self.shell.data_pub.session = self.session
147 147 self.shell.data_pub.pub_socket = self.iopub_socket
148 148
149 149 # TMP - hack while developing
150 150 self.shell._reply_content = None
151 151
152 152 # Build dict of handlers for message types
153 153 msg_types = [ 'execute_request', 'complete_request',
154 154 'object_info_request', 'history_request',
155 155 'kernel_info_request',
156 156 'connect_request', 'shutdown_request',
157 157 'apply_request',
158 158 ]
159 159 self.shell_handlers = {}
160 160 for msg_type in msg_types:
161 161 self.shell_handlers[msg_type] = getattr(self, msg_type)
162 162
163 163 comm_msg_types = [ 'comm_open', 'comm_msg', 'comm_close' ]
164 164 comm_manager = self.shell.comm_manager
165 165 for msg_type in comm_msg_types:
166 166 self.shell_handlers[msg_type] = getattr(comm_manager, msg_type)
167 167
168 168 control_msg_types = msg_types + [ 'clear_request', 'abort_request' ]
169 169 self.control_handlers = {}
170 170 for msg_type in control_msg_types:
171 171 self.control_handlers[msg_type] = getattr(self, msg_type)
172 172
173 173
174 174 def dispatch_control(self, msg):
175 175 """dispatch control requests"""
176 176 idents,msg = self.session.feed_identities(msg, copy=False)
177 177 try:
178 178 msg = self.session.unserialize(msg, content=True, copy=False)
179 179 except:
180 180 self.log.error("Invalid Control Message", exc_info=True)
181 181 return
182 182
183 183 self.log.debug("Control received: %s", msg)
184 184
185 185 header = msg['header']
186 186 msg_id = header['msg_id']
187 187 msg_type = header['msg_type']
188 188
189 189 handler = self.control_handlers.get(msg_type, None)
190 190 if handler is None:
191 191 self.log.error("UNKNOWN CONTROL MESSAGE TYPE: %r", msg_type)
192 192 else:
193 193 try:
194 194 handler(self.control_stream, idents, msg)
195 195 except Exception:
196 196 self.log.error("Exception in control handler:", exc_info=True)
197 197
198 198 def dispatch_shell(self, stream, msg):
199 199 """dispatch shell requests"""
200 200 # flush control requests first
201 201 if self.control_stream:
202 202 self.control_stream.flush()
203 203
204 204 idents,msg = self.session.feed_identities(msg, copy=False)
205 205 try:
206 206 msg = self.session.unserialize(msg, content=True, copy=False)
207 207 except:
208 208 self.log.error("Invalid Message", exc_info=True)
209 209 return
210 210
211 211 header = msg['header']
212 212 msg_id = header['msg_id']
213 213 msg_type = msg['header']['msg_type']
214 214
215 215 # Print some info about this message and leave a '--->' marker, so it's
216 216 # easier to trace visually the message chain when debugging. Each
217 217 # handler prints its message at the end.
218 218 self.log.debug('\n*** MESSAGE TYPE:%s***', msg_type)
219 219 self.log.debug(' Content: %s\n --->\n ', msg['content'])
220 220
221 221 if msg_id in self.aborted:
222 222 self.aborted.remove(msg_id)
223 223 # is it safe to assume a msg_id will not be resubmitted?
224 224 reply_type = msg_type.split('_')[0] + '_reply'
225 225 status = {'status' : 'aborted'}
226 226 md = {'engine' : self.ident}
227 227 md.update(status)
228 228 reply_msg = self.session.send(stream, reply_type, metadata=md,
229 229 content=status, parent=msg, ident=idents)
230 230 return
231 231
232 232 handler = self.shell_handlers.get(msg_type, None)
233 233 if handler is None:
234 234 self.log.error("UNKNOWN MESSAGE TYPE: %r", msg_type)
235 235 else:
236 236 # ensure default_int_handler during handler call
237 237 sig = signal(SIGINT, default_int_handler)
238 238 try:
239 239 handler(stream, idents, msg)
240 240 except Exception:
241 241 self.log.error("Exception in message handler:", exc_info=True)
242 242 finally:
243 243 signal(SIGINT, sig)
244 244
245 245 def enter_eventloop(self):
246 246 """enter eventloop"""
247 247 self.log.info("entering eventloop")
248 248 # restore default_int_handler
249 249 signal(SIGINT, default_int_handler)
250 250 while self.eventloop is not None:
251 251 try:
252 252 self.eventloop(self)
253 253 except KeyboardInterrupt:
254 254 # Ctrl-C shouldn't crash the kernel
255 255 self.log.error("KeyboardInterrupt caught in kernel")
256 256 continue
257 257 else:
258 258 # eventloop exited cleanly, this means we should stop (right?)
259 259 self.eventloop = None
260 260 break
261 261 self.log.info("exiting eventloop")
262 262
263 263 def start(self):
264 264 """register dispatchers for streams"""
265 265 self.shell.exit_now = False
266 266 if self.control_stream:
267 267 self.control_stream.on_recv(self.dispatch_control, copy=False)
268 268
269 269 def make_dispatcher(stream):
270 270 def dispatcher(msg):
271 271 return self.dispatch_shell(stream, msg)
272 272 return dispatcher
273 273
274 274 for s in self.shell_streams:
275 275 s.on_recv(make_dispatcher(s), copy=False)
276 276
277 277 # publish idle status
278 278 self._publish_status('starting')
279 279
280 280 def do_one_iteration(self):
281 281 """step eventloop just once"""
282 282 if self.control_stream:
283 283 self.control_stream.flush()
284 284 for stream in self.shell_streams:
285 285 # handle at most one request per iteration
286 286 stream.flush(zmq.POLLIN, 1)
287 287 stream.flush(zmq.POLLOUT)
288 288
289 289
290 290 def record_ports(self, ports):
291 291 """Record the ports that this kernel is using.
292 292
293 293 The creator of the Kernel instance must call this method if they
294 294 want the :meth:`connect_request` method to return the port numbers.
295 295 """
296 296 self._recorded_ports = ports
297 297
298 298 #---------------------------------------------------------------------------
299 299 # Kernel request handlers
300 300 #---------------------------------------------------------------------------
301 301
302 302 def _make_metadata(self, other=None):
303 303 """init metadata dict, for execute/apply_reply"""
304 304 new_md = {
305 305 'dependencies_met' : True,
306 306 'engine' : self.ident,
307 307 'started': datetime.now(),
308 308 }
309 309 if other:
310 310 new_md.update(other)
311 311 return new_md
312 312
313 313 def _publish_pyin(self, code, parent, execution_count):
314 314 """Publish the code request on the pyin stream."""
315 315
316 316 self.session.send(self.iopub_socket, u'pyin',
317 317 {u'code':code, u'execution_count': execution_count},
318 318 parent=parent, ident=self._topic('pyin')
319 319 )
320 320
321 321 def _publish_status(self, status, parent=None):
322 322 """send status (busy/idle) on IOPub"""
323 323 self.session.send(self.iopub_socket,
324 324 u'status',
325 325 {u'execution_state': status},
326 326 parent=parent,
327 327 ident=self._topic('status'),
328 328 )
329 329
330 330
331 331 def execute_request(self, stream, ident, parent):
332 332 """handle an execute_request"""
333 333
334 334 self._publish_status(u'busy', parent)
335 335
336 336 try:
337 337 content = parent[u'content']
338 338 code = content[u'code']
339 339 silent = content[u'silent']
340 340 store_history = content.get(u'store_history', not silent)
341 341 except:
342 342 self.log.error("Got bad msg: ")
343 343 self.log.error("%s", parent)
344 344 return
345 345
346 346 md = self._make_metadata(parent['metadata'])
347 347
348 348 shell = self.shell # we'll need this a lot here
349 349
350 350 # Replace raw_input. Note that it is not sufficient to replace
351 351 # raw_input in the user namespace.
352 352 if content.get('allow_stdin', False):
353 353 raw_input = lambda prompt='': self._raw_input(prompt, ident, parent)
354 354 input = lambda prompt='': eval(raw_input(prompt))
355 355 else:
356 356 raw_input = input = lambda prompt='' : self._no_raw_input()
357 357
358 358 if py3compat.PY3:
359 359 self._sys_raw_input = builtin_mod.input
360 360 builtin_mod.input = raw_input
361 361 else:
362 362 self._sys_raw_input = builtin_mod.raw_input
363 363 self._sys_eval_input = builtin_mod.input
364 364 builtin_mod.raw_input = raw_input
365 365 builtin_mod.input = input
366 366
367 367 # Set the parent message of the display hook and out streams.
368 368 shell.set_parent(parent)
369 369
370 370 # Re-broadcast our input for the benefit of listening clients, and
371 371 # start computing output
372 372 if not silent:
373 373 self._publish_pyin(code, parent, shell.execution_count)
374 374
375 375 reply_content = {}
376 376 try:
377 377 # FIXME: the shell calls the exception handler itself.
378 378 shell.run_cell(code, store_history=store_history, silent=silent)
379 379 except:
380 380 status = u'error'
381 381 # FIXME: this code right now isn't being used yet by default,
382 382 # because the run_cell() call above directly fires off exception
383 383 # reporting. This code, therefore, is only active in the scenario
384 384 # where runlines itself has an unhandled exception. We need to
385 385 # uniformize this, for all exception construction to come from a
386 386 # single location in the codebase.
387 387 etype, evalue, tb = sys.exc_info()
388 388 tb_list = traceback.format_exception(etype, evalue, tb)
389 389 reply_content.update(shell._showtraceback(etype, evalue, tb_list))
390 390 else:
391 391 status = u'ok'
392 392 finally:
393 393 # Restore raw_input.
394 394 if py3compat.PY3:
395 395 builtin_mod.input = self._sys_raw_input
396 396 else:
397 397 builtin_mod.raw_input = self._sys_raw_input
398 398 builtin_mod.input = self._sys_eval_input
399 399
400 400 reply_content[u'status'] = status
401 401
402 402 # Return the execution counter so clients can display prompts
403 403 reply_content['execution_count'] = shell.execution_count - 1
404 404
405 405 # FIXME - fish exception info out of shell, possibly left there by
406 406 # runlines. We'll need to clean up this logic later.
407 407 if shell._reply_content is not None:
408 408 reply_content.update(shell._reply_content)
409 409 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='execute')
410 410 reply_content['engine_info'] = e_info
411 411 # reset after use
412 412 shell._reply_content = None
413 413
414 414 if 'traceback' in reply_content:
415 415 self.log.info("Exception in execute request:\n%s", '\n'.join(reply_content['traceback']))
416 416
417 417
418 418 # At this point, we can tell whether the main code execution succeeded
419 419 # or not. If it did, we proceed to evaluate user_variables/expressions
420 420 if reply_content['status'] == 'ok':
421 421 reply_content[u'user_variables'] = \
422 422 shell.user_variables(content.get(u'user_variables', []))
423 423 reply_content[u'user_expressions'] = \
424 424 shell.user_expressions(content.get(u'user_expressions', {}))
425 425 else:
426 426 # If there was an error, don't even try to compute variables or
427 427 # expressions
428 428 reply_content[u'user_variables'] = {}
429 429 reply_content[u'user_expressions'] = {}
430 430
431 431 # Payloads should be retrieved regardless of outcome, so we can both
432 432 # recover partial output (that could have been generated early in a
433 433 # block, before an error) and clear the payload system always.
434 434 reply_content[u'payload'] = shell.payload_manager.read_payload()
435 435 # Be aggressive about clearing the payload because we don't want
436 436 # it to sit in memory until the next execute_request comes in.
437 437 shell.payload_manager.clear_payload()
438 438
439 439 # Flush output before sending the reply.
440 440 sys.stdout.flush()
441 441 sys.stderr.flush()
442 442 # FIXME: on rare occasions, the flush doesn't seem to make it to the
443 443 # clients... This seems to mitigate the problem, but we definitely need
444 444 # to better understand what's going on.
445 445 if self._execute_sleep:
446 446 time.sleep(self._execute_sleep)
447 447
448 448 # Send the reply.
449 449 reply_content = json_clean(reply_content)
450 450
451 451 md['status'] = reply_content['status']
452 452 if reply_content['status'] == 'error' and \
453 453 reply_content['ename'] == 'UnmetDependency':
454 454 md['dependencies_met'] = False
455 455
456 456 reply_msg = self.session.send(stream, u'execute_reply',
457 457 reply_content, parent, metadata=md,
458 458 ident=ident)
459 459
460 460 self.log.debug("%s", reply_msg)
461 461
462 462 if not silent and reply_msg['content']['status'] == u'error':
463 463 self._abort_queues()
464 464
465 465 self._publish_status(u'idle', parent)
466 466
467 467 def complete_request(self, stream, ident, parent):
468 468 txt, matches = self._complete(parent)
469 469 matches = {'matches' : matches,
470 470 'matched_text' : txt,
471 471 'status' : 'ok'}
472 472 matches = json_clean(matches)
473 473 completion_msg = self.session.send(stream, 'complete_reply',
474 474 matches, parent, ident)
475 475 self.log.debug("%s", completion_msg)
476 476
477 477 def object_info_request(self, stream, ident, parent):
478 478 content = parent['content']
479 479 object_info = self.shell.object_inspect(content['oname'],
480 480 detail_level = content.get('detail_level', 0)
481 481 )
482 482 # Before we send this object over, we scrub it for JSON usage
483 483 oinfo = json_clean(object_info)
484 484 msg = self.session.send(stream, 'object_info_reply',
485 485 oinfo, parent, ident)
486 486 self.log.debug("%s", msg)
487 487
488 488 def history_request(self, stream, ident, parent):
489 489 # We need to pull these out, as passing **kwargs doesn't work with
490 490 # unicode keys before Python 2.6.5.
491 491 hist_access_type = parent['content']['hist_access_type']
492 492 raw = parent['content']['raw']
493 493 output = parent['content']['output']
494 494 if hist_access_type == 'tail':
495 495 n = parent['content']['n']
496 496 hist = self.shell.history_manager.get_tail(n, raw=raw, output=output,
497 497 include_latest=True)
498 498
499 499 elif hist_access_type == 'range':
500 500 session = parent['content']['session']
501 501 start = parent['content']['start']
502 502 stop = parent['content']['stop']
503 503 hist = self.shell.history_manager.get_range(session, start, stop,
504 504 raw=raw, output=output)
505 505
506 506 elif hist_access_type == 'search':
507 507 n = parent['content'].get('n')
508 508 unique = parent['content'].get('unique', False)
509 509 pattern = parent['content']['pattern']
510 510 hist = self.shell.history_manager.search(
511 511 pattern, raw=raw, output=output, n=n, unique=unique)
512 512
513 513 else:
514 514 hist = []
515 515 hist = list(hist)
516 516 content = {'history' : hist}
517 517 content = json_clean(content)
518 518 msg = self.session.send(stream, 'history_reply',
519 519 content, parent, ident)
520 520 self.log.debug("Sending history reply with %i entries", len(hist))
521 521
522 522 def connect_request(self, stream, ident, parent):
523 523 if self._recorded_ports is not None:
524 524 content = self._recorded_ports.copy()
525 525 else:
526 526 content = {}
527 527 msg = self.session.send(stream, 'connect_reply',
528 528 content, parent, ident)
529 529 self.log.debug("%s", msg)
530 530
531 531 def kernel_info_request(self, stream, ident, parent):
532 532 vinfo = {
533 533 'protocol_version': protocol_version,
534 534 'ipython_version': ipython_version,
535 535 'language_version': language_version,
536 536 'language': 'python',
537 537 }
538 538 msg = self.session.send(stream, 'kernel_info_reply',
539 539 vinfo, parent, ident)
540 540 self.log.debug("%s", msg)
541 541
542 542 def shutdown_request(self, stream, ident, parent):
543 543 self.shell.exit_now = True
544 544 content = dict(status='ok')
545 545 content.update(parent['content'])
546 546 self.session.send(stream, u'shutdown_reply', content, parent, ident=ident)
547 547 # same content, but different msg_id for broadcasting on IOPub
548 548 self._shutdown_message = self.session.msg(u'shutdown_reply',
549 549 content, parent
550 550 )
551 551
552 552 self._at_shutdown()
553 553 # call sys.exit after a short delay
554 554 loop = ioloop.IOLoop.instance()
555 555 loop.add_timeout(time.time()+0.1, loop.stop)
556 556
557 557 #---------------------------------------------------------------------------
558 558 # Engine methods
559 559 #---------------------------------------------------------------------------
560 560
561 561 def apply_request(self, stream, ident, parent):
562 562 try:
563 563 content = parent[u'content']
564 564 bufs = parent[u'buffers']
565 565 msg_id = parent['header']['msg_id']
566 566 except:
567 567 self.log.error("Got bad msg: %s", parent, exc_info=True)
568 568 return
569 569
570 570 self._publish_status(u'busy', parent)
571 571
572 572 # Set the parent message of the display hook and out streams.
573 573 shell = self.shell
574 574 shell.set_parent(parent)
575 575
576 576 # pyin_msg = self.session.msg(u'pyin',{u'code':code}, parent=parent)
577 577 # self.iopub_socket.send(pyin_msg)
578 578 # self.session.send(self.iopub_socket, u'pyin', {u'code':code},parent=parent)
579 579 md = self._make_metadata(parent['metadata'])
580 580 try:
581 581 working = shell.user_ns
582 582
583 583 prefix = "_"+str(msg_id).replace("-","")+"_"
584 584
585 585 f,args,kwargs = unpack_apply_message(bufs, working, copy=False)
586 586
587 587 fname = getattr(f, '__name__', 'f')
588 588
589 589 fname = prefix+"f"
590 590 argname = prefix+"args"
591 591 kwargname = prefix+"kwargs"
592 592 resultname = prefix+"result"
593 593
594 594 ns = { fname : f, argname : args, kwargname : kwargs , resultname : None }
595 595 # print ns
596 596 working.update(ns)
597 597 code = "%s = %s(*%s,**%s)" % (resultname, fname, argname, kwargname)
598 598 try:
599 599 exec(code, shell.user_global_ns, shell.user_ns)
600 600 result = working.get(resultname)
601 601 finally:
602 for key in ns.iterkeys():
602 for key in ns:
603 603 working.pop(key)
604 604
605 605 result_buf = serialize_object(result,
606 606 buffer_threshold=self.session.buffer_threshold,
607 607 item_threshold=self.session.item_threshold,
608 608 )
609 609
610 610 except:
611 611 # invoke IPython traceback formatting
612 612 shell.showtraceback()
613 613 # FIXME - fish exception info out of shell, possibly left there by
614 614 # run_code. We'll need to clean up this logic later.
615 615 reply_content = {}
616 616 if shell._reply_content is not None:
617 617 reply_content.update(shell._reply_content)
618 618 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method='apply')
619 619 reply_content['engine_info'] = e_info
620 620 # reset after use
621 621 shell._reply_content = None
622 622
623 623 self.session.send(self.iopub_socket, u'pyerr', reply_content, parent=parent,
624 624 ident=self._topic('pyerr'))
625 625 self.log.info("Exception in apply request:\n%s", '\n'.join(reply_content['traceback']))
626 626 result_buf = []
627 627
628 628 if reply_content['ename'] == 'UnmetDependency':
629 629 md['dependencies_met'] = False
630 630 else:
631 631 reply_content = {'status' : 'ok'}
632 632
633 633 # put 'ok'/'error' status in header, for scheduler introspection:
634 634 md['status'] = reply_content['status']
635 635
636 636 # flush i/o
637 637 sys.stdout.flush()
638 638 sys.stderr.flush()
639 639
640 640 reply_msg = self.session.send(stream, u'apply_reply', reply_content,
641 641 parent=parent, ident=ident,buffers=result_buf, metadata=md)
642 642
643 643 self._publish_status(u'idle', parent)
644 644
645 645 #---------------------------------------------------------------------------
646 646 # Control messages
647 647 #---------------------------------------------------------------------------
648 648
649 649 def abort_request(self, stream, ident, parent):
650 650 """abort a specifig msg by id"""
651 651 msg_ids = parent['content'].get('msg_ids', None)
652 652 if isinstance(msg_ids, string_types):
653 653 msg_ids = [msg_ids]
654 654 if not msg_ids:
655 655 self._abort_queues()
656 656 for mid in msg_ids:
657 657 self.aborted.add(str(mid))
658 658
659 659 content = dict(status='ok')
660 660 reply_msg = self.session.send(stream, 'abort_reply', content=content,
661 661 parent=parent, ident=ident)
662 662 self.log.debug("%s", reply_msg)
663 663
664 664 def clear_request(self, stream, idents, parent):
665 665 """Clear our namespace."""
666 666 self.shell.reset(False)
667 667 msg = self.session.send(stream, 'clear_reply', ident=idents, parent=parent,
668 668 content = dict(status='ok'))
669 669
670 670
671 671 #---------------------------------------------------------------------------
672 672 # Protected interface
673 673 #---------------------------------------------------------------------------
674 674
675 675 def _wrap_exception(self, method=None):
676 676 # import here, because _wrap_exception is only used in parallel,
677 677 # and parallel has higher min pyzmq version
678 678 from IPython.parallel.error import wrap_exception
679 679 e_info = dict(engine_uuid=self.ident, engine_id=self.int_id, method=method)
680 680 content = wrap_exception(e_info)
681 681 return content
682 682
683 683 def _topic(self, topic):
684 684 """prefixed topic for IOPub messages"""
685 685 if self.int_id >= 0:
686 686 base = "engine.%i" % self.int_id
687 687 else:
688 688 base = "kernel.%s" % self.ident
689 689
690 690 return py3compat.cast_bytes("%s.%s" % (base, topic))
691 691
692 692 def _abort_queues(self):
693 693 for stream in self.shell_streams:
694 694 if stream:
695 695 self._abort_queue(stream)
696 696
697 697 def _abort_queue(self, stream):
698 698 poller = zmq.Poller()
699 699 poller.register(stream.socket, zmq.POLLIN)
700 700 while True:
701 701 idents,msg = self.session.recv(stream, zmq.NOBLOCK, content=True)
702 702 if msg is None:
703 703 return
704 704
705 705 self.log.info("Aborting:")
706 706 self.log.info("%s", msg)
707 707 msg_type = msg['header']['msg_type']
708 708 reply_type = msg_type.split('_')[0] + '_reply'
709 709
710 710 status = {'status' : 'aborted'}
711 711 md = {'engine' : self.ident}
712 712 md.update(status)
713 713 reply_msg = self.session.send(stream, reply_type, metadata=md,
714 714 content=status, parent=msg, ident=idents)
715 715 self.log.debug("%s", reply_msg)
716 716 # We need to wait a bit for requests to come in. This can probably
717 717 # be set shorter for true asynchronous clients.
718 718 poller.poll(50)
719 719
720 720
721 721 def _no_raw_input(self):
722 722 """Raise StdinNotImplentedError if active frontend doesn't support
723 723 stdin."""
724 724 raise StdinNotImplementedError("raw_input was called, but this "
725 725 "frontend does not support stdin.")
726 726
727 727 def _raw_input(self, prompt, ident, parent):
728 728 # Flush output before making the request.
729 729 sys.stderr.flush()
730 730 sys.stdout.flush()
731 731 # flush the stdin socket, to purge stale replies
732 732 while True:
733 733 try:
734 734 self.stdin_socket.recv_multipart(zmq.NOBLOCK)
735 735 except zmq.ZMQError as e:
736 736 if e.errno == zmq.EAGAIN:
737 737 break
738 738 else:
739 739 raise
740 740
741 741 # Send the input request.
742 742 content = json_clean(dict(prompt=prompt))
743 743 self.session.send(self.stdin_socket, u'input_request', content, parent,
744 744 ident=ident)
745 745
746 746 # Await a response.
747 747 while True:
748 748 try:
749 749 ident, reply = self.session.recv(self.stdin_socket, 0)
750 750 except Exception:
751 751 self.log.warn("Invalid Message:", exc_info=True)
752 752 except KeyboardInterrupt:
753 753 # re-raise KeyboardInterrupt, to truncate traceback
754 754 raise KeyboardInterrupt
755 755 else:
756 756 break
757 757 try:
758 758 value = py3compat.unicode_to_str(reply['content']['value'])
759 759 except:
760 760 self.log.error("Got bad raw_input reply: ")
761 761 self.log.error("%s", parent)
762 762 value = ''
763 763 if value == '\x04':
764 764 # EOF
765 765 raise EOFError
766 766 return value
767 767
768 768 def _complete(self, msg):
769 769 c = msg['content']
770 770 try:
771 771 cpos = int(c['cursor_pos'])
772 772 except:
773 773 # If we don't get something that we can convert to an integer, at
774 774 # least attempt the completion guessing the cursor is at the end of
775 775 # the text if there is any, and otherwise at the end of the line
776 776 cpos = len(c['text'])
777 777 if cpos==0:
778 778 cpos = len(c['line'])
779 779 return self.shell.complete(c['text'], c['line'], cpos)
780 780
781 781 def _at_shutdown(self):
782 782 """Actions taken at shutdown by the kernel, called by python's atexit.
783 783 """
784 784 # io.rprint("Kernel at_shutdown") # dbg
785 785 if self._shutdown_message is not None:
786 786 self.session.send(self.iopub_socket, self._shutdown_message, ident=self._topic('shutdown'))
787 787 self.log.debug("%s", self._shutdown_message)
788 788 [ s.flush(zmq.POLLOUT) for s in self.shell_streams ]
789 789
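In apply_request above, the cleanup loop now iterates `ns` directly. This is safe because the loop mutates `working` (the user namespace), not `ns` itself. A reduced sketch of the inject/execute/pop pattern (names hypothetical):

    working = {'x': 1}                     # stands in for shell.user_ns
    ns = {'_tmp_f': len, '_tmp_args': ([1, 2],), '_tmp_result': None}
    working.update(ns)                     # inject temporaries
    try:
        exec("_tmp_result = _tmp_f(*_tmp_args)", {}, working)
        result = working.get('_tmp_result')
    finally:
        for key in ns:                     # iterate ns; pop from working
            working.pop(key)
    assert result == 2 and working == {'x': 1}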
@@ -1,198 +1,198 b''
1 1 """serialization utilities for apply messages
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 try:
19 19 import cPickle
20 20 pickle = cPickle
21 21 except:
22 22 cPickle = None
23 23 import pickle
24 24
25 25
26 26 # IPython imports
27 27 from IPython.utils import py3compat
28 28 from IPython.utils.data import flatten
29 29 from IPython.utils.pickleutil import (
30 30 can, uncan, can_sequence, uncan_sequence, CannedObject,
31 31 istype, sequence_types,
32 32 )
33 33
34 34 if py3compat.PY3:
35 35 buffer = memoryview
36 36
37 37 #-----------------------------------------------------------------------------
38 38 # Serialization Functions
39 39 #-----------------------------------------------------------------------------
40 40
41 41 # default values for the thresholds:
42 42 MAX_ITEMS = 64
43 43 MAX_BYTES = 1024
44 44
45 45 def _extract_buffers(obj, threshold=MAX_BYTES):
46 46 """extract buffers larger than a certain threshold"""
47 47 buffers = []
48 48 if isinstance(obj, CannedObject) and obj.buffers:
49 49 for i,buf in enumerate(obj.buffers):
50 50 if len(buf) > threshold:
51 51 # buffer larger than threshold, prevent pickling
52 52 obj.buffers[i] = None
53 53 buffers.append(buf)
54 54 elif isinstance(buf, buffer):
55 55 # buffer too small for separate send, coerce to bytes
56 56 # because pickling buffer objects just results in broken pointers
57 57 obj.buffers[i] = bytes(buf)
58 58 return buffers
59 59
60 60 def _restore_buffers(obj, buffers):
61 61 """restore buffers extracted by """
62 62 if isinstance(obj, CannedObject) and obj.buffers:
63 63 for i,buf in enumerate(obj.buffers):
64 64 if buf is None:
65 65 obj.buffers[i] = buffers.pop(0)
66 66
67 67 def serialize_object(obj, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
68 68 """Serialize an object into a list of sendable buffers.
69 69
70 70 Parameters
71 71 ----------
72 72
73 73 obj : object
74 74 The object to be serialized
75 75 buffer_threshold : int
76 76 The threshold (in bytes) for pulling out data buffers
77 77 to avoid pickling them.
78 78 item_threshold : int
79 79 The maximum number of items over which canning will iterate.
80 80 Containers (lists, dicts) larger than this will be pickled without
81 81 introspection.
82 82
83 83 Returns
84 84 -------
85 85 [bufs] : list of buffers representing the serialized object.
86 86 """
87 87 buffers = []
88 88 if istype(obj, sequence_types) and len(obj) < item_threshold:
89 89 cobj = can_sequence(obj)
90 90 for c in cobj:
91 91 buffers.extend(_extract_buffers(c, buffer_threshold))
92 92 elif istype(obj, dict) and len(obj) < item_threshold:
93 93 cobj = {}
94 for k in sorted(obj.iterkeys()):
94 for k in sorted(obj):
95 95 c = can(obj[k])
96 96 buffers.extend(_extract_buffers(c, buffer_threshold))
97 97 cobj[k] = c
98 98 else:
99 99 cobj = can(obj)
100 100 buffers.extend(_extract_buffers(cobj, buffer_threshold))
101 101
102 102 buffers.insert(0, pickle.dumps(cobj,-1))
103 103 return buffers
104 104
105 105 def unserialize_object(buffers, g=None):
106 106 """reconstruct an object serialized by serialize_object from data buffers.
107 107
108 108 Parameters
109 109 ----------
110 110
111 111 bufs : list of buffers/bytes
112 112
113 113 g : globals to be used when uncanning
114 114
115 115 Returns
116 116 -------
117 117
118 118 (newobj, bufs) : unpacked object, and the list of remaining unused buffers.
119 119 """
120 120 bufs = list(buffers)
121 121 pobj = bufs.pop(0)
122 122 if not isinstance(pobj, bytes):
123 123 # a zmq message
124 124 pobj = bytes(pobj)
125 125 canned = pickle.loads(pobj)
126 126 if istype(canned, sequence_types) and len(canned) < MAX_ITEMS:
127 127 for c in canned:
128 128 _restore_buffers(c, bufs)
129 129 newobj = uncan_sequence(canned, g)
130 130 elif istype(canned, dict) and len(canned) < MAX_ITEMS:
131 131 newobj = {}
132 for k in sorted(canned.iterkeys()):
132 for k in sorted(canned):
133 133 c = canned[k]
134 134 _restore_buffers(c, bufs)
135 135 newobj[k] = uncan(c, g)
136 136 else:
137 137 _restore_buffers(canned, bufs)
138 138 newobj = uncan(canned, g)
139 139
140 140 return newobj, bufs
141 141
142 142 def pack_apply_message(f, args, kwargs, buffer_threshold=MAX_BYTES, item_threshold=MAX_ITEMS):
143 143 """pack up a function, args, and kwargs to be sent over the wire
144 144
145 145 Each element of args/kwargs will be canned for special treatment,
146 146 but inspection will not go any deeper than that.
147 147
148 148 Any object whose data is larger than `threshold` will not have its data copied
149 149 (only numpy arrays and bytes/buffers support zero-copy)
150 150
151 151 Message will be a list of bytes/buffers of the format:
152 152
153 153 [ cf, pinfo, <arg_bufs>, <kwarg_bufs> ]
154 154
155 155 With length at least two + len(args) + len(kwargs)
156 156 """
157 157
158 158 arg_bufs = flatten(serialize_object(arg, buffer_threshold, item_threshold) for arg in args)
159 159
160 160 kw_keys = sorted(kwargs.keys())
161 161 kwarg_bufs = flatten(serialize_object(kwargs[key], buffer_threshold, item_threshold) for key in kw_keys)
162 162
163 163 info = dict(nargs=len(args), narg_bufs=len(arg_bufs), kw_keys=kw_keys)
164 164
165 165 msg = [pickle.dumps(can(f),-1)]
166 166 msg.append(pickle.dumps(info, -1))
167 167 msg.extend(arg_bufs)
168 168 msg.extend(kwarg_bufs)
169 169
170 170 return msg
171 171
172 172 def unpack_apply_message(bufs, g=None, copy=True):
173 173 """unpack f,args,kwargs from buffers packed by pack_apply_message()
174 174 Returns: original f,args,kwargs"""
175 175 bufs = list(bufs) # allow us to pop
176 176 assert len(bufs) >= 2, "not enough buffers!"
177 177 if not copy:
178 178 for i in range(2):
179 179 bufs[i] = bufs[i].bytes
180 180 f = uncan(pickle.loads(bufs.pop(0)), g)
181 181 info = pickle.loads(bufs.pop(0))
182 182 arg_bufs, kwarg_bufs = bufs[:info['narg_bufs']], bufs[info['narg_bufs']:]
183 183
184 184 args = []
185 185 for i in range(info['nargs']):
186 186 arg, arg_bufs = unserialize_object(arg_bufs, g)
187 187 args.append(arg)
188 188 args = tuple(args)
189 189 assert not arg_bufs, "Shouldn't be any arg bufs left over"
190 190
191 191 kwargs = {}
192 192 for key in info['kw_keys']:
193 193 kwarg, kwarg_bufs = unserialize_object(kwarg_bufs, g)
194 194 kwargs[key] = kwarg
195 195 assert not kwarg_bufs, "Shouldn't be any kwarg bufs left over"
196 196
197 197 return f,args,kwargs
198 198
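The serialize.py hunks rely on the fact that sorting a dict sorts its keys, so `sorted(obj.iterkeys())` and `sorted(obj)` produce the same list and canning still visits keys in a deterministic order. A one-line check (illustrative):

    d = {'b': 2, 'a': 1, 'c': 3}
    assert sorted(d) == ['a', 'b', 'c'] == sorted(d.keys())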
@@ -1,1854 +1,1854 b''
1 1 """A semi-synchronous Client for the ZMQ cluster
2 2
3 3 Authors:
4 4
5 5 * MinRK
6 6 """
7 7 from __future__ import print_function
8 8 #-----------------------------------------------------------------------------
9 9 # Copyright (C) 2010-2011 The IPython Development Team
10 10 #
11 11 # Distributed under the terms of the BSD License. The full license is in
12 12 # the file COPYING, distributed as part of this software.
13 13 #-----------------------------------------------------------------------------
14 14
15 15 #-----------------------------------------------------------------------------
16 16 # Imports
17 17 #-----------------------------------------------------------------------------
18 18
19 19 import os
20 20 import json
21 21 import sys
22 22 from threading import Thread, Event
23 23 import time
24 24 import warnings
25 25 from datetime import datetime
26 26 from getpass import getpass
27 27 from pprint import pprint
28 28
29 29 pjoin = os.path.join
30 30
31 31 import zmq
32 32 # from zmq.eventloop import ioloop, zmqstream
33 33
34 34 from IPython.config.configurable import MultipleInstanceError
35 35 from IPython.core.application import BaseIPythonApplication
36 36 from IPython.core.profiledir import ProfileDir, ProfileDirError
37 37
38 38 from IPython.utils.capture import RichOutput
39 39 from IPython.utils.coloransi import TermColors
40 40 from IPython.utils.jsonutil import rekey
41 41 from IPython.utils.localinterfaces import localhost, is_local_ip
42 42 from IPython.utils.path import get_ipython_dir
43 43 from IPython.utils.py3compat import cast_bytes, string_types, xrange
44 44 from IPython.utils.traitlets import (HasTraits, Integer, Instance, Unicode,
45 45 Dict, List, Bool, Set, Any)
46 46 from IPython.external.decorator import decorator
47 47 from IPython.external.ssh import tunnel
48 48
49 49 from IPython.parallel import Reference
50 50 from IPython.parallel import error
51 51 from IPython.parallel import util
52 52
53 53 from IPython.kernel.zmq.session import Session, Message
54 54 from IPython.kernel.zmq import serialize
55 55
56 56 from .asyncresult import AsyncResult, AsyncHubResult
57 57 from .view import DirectView, LoadBalancedView
58 58
59 59 #--------------------------------------------------------------------------
60 60 # Decorators for Client methods
61 61 #--------------------------------------------------------------------------
62 62
63 63 @decorator
64 64 def spin_first(f, self, *args, **kwargs):
65 65 """Call spin() to sync state prior to calling the method."""
66 66 self.spin()
67 67 return f(self, *args, **kwargs)
68 68
69 69
70 70 #--------------------------------------------------------------------------
71 71 # Classes
72 72 #--------------------------------------------------------------------------
73 73
74 74
75 75 class ExecuteReply(RichOutput):
76 76 """wrapper for finished Execute results"""
77 77 def __init__(self, msg_id, content, metadata):
78 78 self.msg_id = msg_id
79 79 self._content = content
80 80 self.execution_count = content['execution_count']
81 81 self.metadata = metadata
82 82
83 83 # RichOutput overrides
84 84
85 85 @property
86 86 def source(self):
87 87 pyout = self.metadata['pyout']
88 88 if pyout:
89 89 return pyout.get('source', '')
90 90
91 91 @property
92 92 def data(self):
93 93 pyout = self.metadata['pyout']
94 94 if pyout:
95 95 return pyout.get('data', {})
96 96
97 97 @property
98 98 def _metadata(self):
99 99 pyout = self.metadata['pyout']
100 100 if pyout:
101 101 return pyout.get('metadata', {})
102 102
103 103 def display(self):
104 104 from IPython.display import publish_display_data
105 105 publish_display_data(self.source, self.data, self.metadata)
106 106
107 107 def _repr_mime_(self, mime):
108 108 if mime not in self.data:
109 109 return
110 110 data = self.data[mime]
111 111 if mime in self._metadata:
112 112 return data, self._metadata[mime]
113 113 else:
114 114 return data
115 115
116 116 def __getitem__(self, key):
117 117 return self.metadata[key]
118 118
119 119 def __getattr__(self, key):
120 120 if key not in self.metadata:
121 121 raise AttributeError(key)
122 122 return self.metadata[key]
123 123
124 124 def __repr__(self):
125 125 pyout = self.metadata['pyout'] or {'data':{}}
126 126 text_out = pyout['data'].get('text/plain', '')
127 127 if len(text_out) > 32:
128 128 text_out = text_out[:29] + '...'
129 129
130 130 return "<ExecuteReply[%i]: %s>" % (self.execution_count, text_out)
131 131
132 132 def _repr_pretty_(self, p, cycle):
133 133 pyout = self.metadata['pyout'] or {'data':{}}
134 134 text_out = pyout['data'].get('text/plain', '')
135 135
136 136 if not text_out:
137 137 return
138 138
139 139 try:
140 140 ip = get_ipython()
141 141 except NameError:
142 142 colors = "NoColor"
143 143 else:
144 144 colors = ip.colors
145 145
146 146 if colors == "NoColor":
147 147 out = normal = ""
148 148 else:
149 149 out = TermColors.Red
150 150 normal = TermColors.Normal
151 151
152 152 if '\n' in text_out and not text_out.startswith('\n'):
153 153 # add newline for multiline reprs
154 154 text_out = '\n' + text_out
155 155
156 156 p.text(
157 157 out + u'Out[%i:%i]: ' % (
158 158 self.metadata['engine_id'], self.execution_count
159 159 ) + normal + text_out
160 160 )
161 161
162 162
163 163 class Metadata(dict):
164 164 """Subclass of dict for initializing metadata values.
165 165
166 166 Attribute access works on keys.
167 167
168 168 These objects have a strict set of keys - errors will raise if you try
169 169 to add new keys.
170 170 """
171 171 def __init__(self, *args, **kwargs):
172 172 dict.__init__(self)
173 173 md = {'msg_id' : None,
174 174 'submitted' : None,
175 175 'started' : None,
176 176 'completed' : None,
177 177 'received' : None,
178 178 'engine_uuid' : None,
179 179 'engine_id' : None,
180 180 'follow' : None,
181 181 'after' : None,
182 182 'status' : None,
183 183
184 184 'pyin' : None,
185 185 'pyout' : None,
186 186 'pyerr' : None,
187 187 'stdout' : '',
188 188 'stderr' : '',
189 189 'outputs' : [],
190 190 'data': {},
191 191 'outputs_ready' : False,
192 192 }
193 193 self.update(md)
194 194 self.update(dict(*args, **kwargs))
195 195
196 196 def __getattr__(self, key):
197 197 """getattr aliased to getitem"""
198 if key in self.iterkeys():
198 if key in self:
199 199 return self[key]
200 200 else:
201 201 raise AttributeError(key)
202 202
203 203 def __setattr__(self, key, value):
204 204 """setattr aliased to setitem, with strict"""
205 if key in self.iterkeys():
205 if key in self:
206 206 self[key] = value
207 207 else:
208 208 raise AttributeError(key)
209 209
210 210 def __setitem__(self, key, value):
211 211 """strict static key enforcement"""
212 if key in self.iterkeys():
212 if key in self:
213 213 dict.__setitem__(self, key, value)
214 214 else:
215 215 raise KeyError(key)
216 216
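Besides Python 3 compatibility, `key in self` is an O(1) hash lookup, whereas `key in self.iterkeys()` scanned the keys linearly. A sketch of the strict-key behavior this enables (assuming the Metadata class above):

    md = Metadata(msg_id='abc123')
    md.status = 'ok'          # known key: attribute access aliases item access
    try:
        md.bogus = 1          # unknown key: rejected by __setattr__
    except AttributeError:
        pass
    assert md.status == 'ok' and 'bogus' not in md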
217 217
218 218 class Client(HasTraits):
219 219 """A semi-synchronous client to the IPython ZMQ cluster
220 220
221 221 Parameters
222 222 ----------
223 223
224 224 url_file : str/unicode; path to ipcontroller-client.json
225 225 This JSON file should contain all the information needed to connect to a cluster,
226 226 and is likely the only argument needed.
227 227 Connection information for the Hub's registration. If a json connector
228 228 file is given, then likely no further configuration is necessary.
229 229 [Default: use profile]
230 230 profile : bytes
231 231 The name of the Cluster profile to be used to find connector information.
232 232 If run from an IPython application, the default profile will be the same
233 233 as the running application, otherwise it will be 'default'.
234 234 cluster_id : str
235 235         String id to be added to runtime files, to prevent name collisions when using
236 236 multiple clusters with a single profile simultaneously.
237 237 When set, will look for files named like: 'ipcontroller-<cluster_id>-client.json'
238 238 Since this is text inserted into filenames, typical recommendations apply:
239 239 Simple character strings are ideal, and spaces are not recommended (but
240 240         should generally work).
241 241 context : zmq.Context
242 242 Pass an existing zmq.Context instance, otherwise the client will create its own.
243 243 debug : bool
244 244 flag for lots of message printing for debug purposes
245 245 timeout : int/float
246 246 time (in seconds) to wait for connection replies from the Hub
247 247 [Default: 10]
248 248
249 249 #-------------- session related args ----------------
250 250
251 251 config : Config object
252 252 If specified, this will be relayed to the Session for configuration
253 253 username : str
254 254 set username for the session object
255 255
256 256 #-------------- ssh related args ----------------
257 257 # These are args for configuring the ssh tunnel to be used
258 258 # credentials are used to forward connections over ssh to the Controller
259 259 # Note that the ip given in `addr` needs to be relative to sshserver
260 260 # The most basic case is to leave addr as pointing to localhost (127.0.0.1),
261 261 # and set sshserver as the same machine the Controller is on. However,
262 262 # the only requirement is that sshserver is able to see the Controller
263 263 # (i.e. is within the same trusted network).
264 264
265 265 sshserver : str
266 266 A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port'
267 267 If keyfile or password is specified, and this is not, it will default to
268 268 the ip given in addr.
269 269 sshkey : str; path to ssh private key file
270 270 This specifies a key to be used in ssh login, default None.
271 271 Regular default ssh keys will be used without specifying this argument.
272 272 password : str
273 273 Your ssh password to sshserver. Note that if this is left None,
274 274 you will be prompted for it if passwordless key based login is unavailable.
275 275 paramiko : bool
276 276 flag for whether to use paramiko instead of shell ssh for tunneling.
277 277         [default: True on win32, False otherwise]
278 278
279 279
280 280 Attributes
281 281 ----------
282 282
283 283 ids : list of int engine IDs
284 284 requesting the ids attribute always synchronizes
285 285 the registration state. To request ids without synchronization,
286 286         use the semi-private _ids attribute.
287 287
288 288 history : list of msg_ids
289 289 a list of msg_ids, keeping track of all the execution
290 290 messages you have submitted in order.
291 291
292 292 outstanding : set of msg_ids
293 293 a set of msg_ids that have been submitted, but whose
294 294 results have not yet been received.
295 295
296 296 results : dict
297 297 a dict of all our results, keyed by msg_id
298 298
299 299 block : bool
300 300 determines default behavior when block not specified
301 301 in execution methods
302 302
303 303 Methods
304 304 -------
305 305
306 306 spin
307 307 flushes incoming results and registration state changes
308 308         control methods spin as well, and requesting `ids` also keeps the state up to date
309 309
310 310 wait
311 311 wait on one or more msg_ids
312 312
313 313 execution methods
314 314 apply
315 315 legacy: execute, run
316 316
317 317 data movement
318 318 push, pull, scatter, gather
319 319
320 320 query methods
321 321 queue_status, get_result, purge, result_status
322 322
323 323 control methods
324 324 abort, shutdown
325 325
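    Examples
    --------

    A minimal connection sketch (assumes an ipcontroller is already running;
    paths and profile names are illustrative)::

        rc = Client()                     # connect using the default profile
        rc = Client(profile='mycluster')  # or a named profile
        rc = Client('/path/to/ipcontroller-client.json')  # or an explicit url_file
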
326 326 """
327 327
328 328
329 329 block = Bool(False)
330 330 outstanding = Set()
331 331 results = Instance('collections.defaultdict', (dict,))
332 332 metadata = Instance('collections.defaultdict', (Metadata,))
333 333 history = List()
334 334 debug = Bool(False)
335 335 _spin_thread = Any()
336 336 _stop_spinning = Any()
337 337
338 338 profile=Unicode()
339 339 def _profile_default(self):
340 340 if BaseIPythonApplication.initialized():
341 341 # an IPython app *might* be running, try to get its profile
342 342 try:
343 343 return BaseIPythonApplication.instance().profile
344 344 except (AttributeError, MultipleInstanceError):
345 345 # could be a *different* subclass of config.Application,
346 346 # which would raise one of these two errors.
347 347 return u'default'
348 348 else:
349 349 return u'default'
350 350
351 351
352 352 _outstanding_dict = Instance('collections.defaultdict', (set,))
353 353 _ids = List()
354 354 _connected=Bool(False)
355 355 _ssh=Bool(False)
356 356 _context = Instance('zmq.Context')
357 357 _config = Dict()
358 358 _engines=Instance(util.ReverseDict, (), {})
359 359 # _hub_socket=Instance('zmq.Socket')
360 360 _query_socket=Instance('zmq.Socket')
361 361 _control_socket=Instance('zmq.Socket')
362 362 _iopub_socket=Instance('zmq.Socket')
363 363 _notification_socket=Instance('zmq.Socket')
364 364 _mux_socket=Instance('zmq.Socket')
365 365 _task_socket=Instance('zmq.Socket')
366 366 _task_scheme=Unicode()
367 367 _closed = False
368 368 _ignored_control_replies=Integer(0)
369 369 _ignored_hub_replies=Integer(0)
370 370
371 371 def __new__(self, *args, **kw):
372 372 # don't raise on positional args
373 373 return HasTraits.__new__(self, **kw)
374 374
375 375 def __init__(self, url_file=None, profile=None, profile_dir=None, ipython_dir=None,
376 376 context=None, debug=False,
377 377 sshserver=None, sshkey=None, password=None, paramiko=None,
378 378 timeout=10, cluster_id=None, **extra_args
379 379 ):
380 380 if profile:
381 381 super(Client, self).__init__(debug=debug, profile=profile)
382 382 else:
383 383 super(Client, self).__init__(debug=debug)
384 384 if context is None:
385 385 context = zmq.Context.instance()
386 386 self._context = context
387 387 self._stop_spinning = Event()
388 388
389 389 if 'url_or_file' in extra_args:
390 390 url_file = extra_args['url_or_file']
391 391 warnings.warn("url_or_file arg no longer supported, use url_file", DeprecationWarning)
392 392
393 393 if url_file and util.is_url(url_file):
394 394 raise ValueError("single urls cannot be specified, url-files must be used.")
395 395
396 396 self._setup_profile_dir(self.profile, profile_dir, ipython_dir)
397 397
398 398 if self._cd is not None:
399 399 if url_file is None:
400 400 if not cluster_id:
401 401 client_json = 'ipcontroller-client.json'
402 402 else:
403 403 client_json = 'ipcontroller-%s-client.json' % cluster_id
404 404 url_file = pjoin(self._cd.security_dir, client_json)
405 405 if url_file is None:
406 406 raise ValueError(
407 407 "I can't find enough information to connect to a hub!"
408 408 " Please specify at least one of url_file or profile."
409 409 )
410 410
411 411 with open(url_file) as f:
412 412 cfg = json.load(f)
413 413
414 414 self._task_scheme = cfg['task_scheme']
415 415
416 416 # sync defaults from args, json:
417 417 if sshserver:
418 418 cfg['ssh'] = sshserver
419 419
420 420 location = cfg.setdefault('location', None)
421 421
422 422 proto,addr = cfg['interface'].split('://')
423 423 addr = util.disambiguate_ip_address(addr, location)
424 424 cfg['interface'] = "%s://%s" % (proto, addr)
425 425
426 426 # turn interface,port into full urls:
427 427 for key in ('control', 'task', 'mux', 'iopub', 'notification', 'registration'):
428 428 cfg[key] = cfg['interface'] + ':%i' % cfg[key]
429 429
430 430 url = cfg['registration']
431 431
432 432 if location is not None and addr == localhost():
433 433 # location specified, and connection is expected to be local
434 434 if not is_local_ip(location) and not sshserver:
435 435 # load ssh from JSON *only* if the controller is not on
436 436 # this machine
437 437 sshserver=cfg['ssh']
438 438 if not is_local_ip(location) and not sshserver:
439 439 # warn if no ssh specified, but SSH is probably needed
440 440 # This is only a warning, because the most likely cause
441 441 # is a local Controller on a laptop whose IP is dynamic
442 442 warnings.warn("""
443 443 Controller appears to be listening on localhost, but not on this machine.
444 444 If this is true, you should specify Client(...,sshserver='you@%s')
445 445 or instruct your controller to listen on an external IP."""%location,
446 446 RuntimeWarning)
447 447 elif not sshserver:
448 448 # otherwise sync with cfg
449 449 sshserver = cfg['ssh']
450 450
451 451 self._config = cfg
452 452
453 453 self._ssh = bool(sshserver or sshkey or password)
454 454 if self._ssh and sshserver is None:
455 455 # default to ssh via localhost
456 456 sshserver = addr
457 457 if self._ssh and password is None:
458 458 if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko):
459 459 password=False
460 460 else:
461 461 password = getpass("SSH Password for %s: "%sshserver)
462 462 ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko)
463 463
464 464 # configure and construct the session
465 465 try:
466 466 extra_args['packer'] = cfg['pack']
467 467 extra_args['unpacker'] = cfg['unpack']
468 468 extra_args['key'] = cast_bytes(cfg['key'])
469 469 extra_args['signature_scheme'] = cfg['signature_scheme']
470 470 except KeyError as exc:
471 471 msg = '\n'.join([
472 472 "Connection file is invalid (missing '{}'), possibly from an old version of IPython.",
473 473 "If you are reusing connection files, remove them and start ipcontroller again."
474 474 ])
475 475 raise ValueError(msg.format(exc.message))
476 476
477 477 self.session = Session(**extra_args)
478 478
479 479 self._query_socket = self._context.socket(zmq.DEALER)
480 480
481 481 if self._ssh:
482 482 tunnel.tunnel_connection(self._query_socket, cfg['registration'], sshserver, **ssh_kwargs)
483 483 else:
484 484 self._query_socket.connect(cfg['registration'])
485 485
486 486 self.session.debug = self.debug
487 487
488 488 self._notification_handlers = {'registration_notification' : self._register_engine,
489 489 'unregistration_notification' : self._unregister_engine,
490 490 'shutdown_notification' : lambda msg: self.close(),
491 491 }
492 492 self._queue_handlers = {'execute_reply' : self._handle_execute_reply,
493 493 'apply_reply' : self._handle_apply_reply}
494 494
495 495 try:
496 496 self._connect(sshserver, ssh_kwargs, timeout)
497 497 except:
498 498 self.close(linger=0)
499 499 raise
500 500
501 501 # last step: setup magics, if we are in IPython:
502 502
503 503 try:
504 504 ip = get_ipython()
505 505 except NameError:
506 506 return
507 507 else:
508 508 if 'px' not in ip.magics_manager.magics:
509 509 # in IPython but we are the first Client.
510 510 # activate a default view for parallel magics.
511 511 self.activate()
512 512
513 513 def __del__(self):
514 514 """cleanup sockets, but _not_ context."""
515 515 self.close()
516 516
517 517 def _setup_profile_dir(self, profile, profile_dir, ipython_dir):
518 518 if ipython_dir is None:
519 519 ipython_dir = get_ipython_dir()
520 520 if profile_dir is not None:
521 521 try:
522 522 self._cd = ProfileDir.find_profile_dir(profile_dir)
523 523 return
524 524 except ProfileDirError:
525 525 pass
526 526 elif profile is not None:
527 527 try:
528 528 self._cd = ProfileDir.find_profile_dir_by_name(
529 529 ipython_dir, profile)
530 530 return
531 531 except ProfileDirError:
532 532 pass
533 533 self._cd = None
534 534
535 535 def _update_engines(self, engines):
536 536 """Update our engines dict and _ids from a dict of the form: {id:uuid}."""
537 537 for k,v in engines.iteritems():
538 538 eid = int(k)
539 539 if eid not in self._engines:
540 540 self._ids.append(eid)
541 541 self._engines[eid] = v
542 542 self._ids = sorted(self._ids)
543 543 if sorted(self._engines.keys()) != range(len(self._engines)) and \
544 544 self._task_scheme == 'pure' and self._task_socket:
545 545 self._stop_scheduling_tasks()
546 546
547 547 def _stop_scheduling_tasks(self):
548 548 """Stop scheduling tasks because an engine has been unregistered
549 549 from a pure ZMQ scheduler.
550 550 """
551 551 self._task_socket.close()
552 552 self._task_socket = None
553 553 msg = "An engine has been unregistered, and we are using pure " +\
554 554 "ZMQ task scheduling. Task farming will be disabled."
555 555 if self.outstanding:
556 556 msg += " If you were running tasks when this happened, " +\
557 557 "some `outstanding` msg_ids may never resolve."
558 558 warnings.warn(msg, RuntimeWarning)
559 559
560 560 def _build_targets(self, targets):
561 561 """Turn valid target IDs or 'all' into two lists:
562 562 (int_ids, uuids).
563 563 """
564 564 if not self._ids:
565 565 # flush notification socket if no engines yet, just in case
566 566 if not self.ids:
567 567 raise error.NoEnginesRegistered("Can't build targets without any engines")
568 568
569 569 if targets is None:
570 570 targets = self._ids
571 571 elif isinstance(targets, string_types):
572 572 if targets.lower() == 'all':
573 573 targets = self._ids
574 574 else:
575 575 raise TypeError("%r not valid str target, must be 'all'"%(targets))
576 576 elif isinstance(targets, int):
577 577 if targets < 0:
578 578 targets = self.ids[targets]
579 579 if targets not in self._ids:
580 580 raise IndexError("No such engine: %i"%targets)
581 581 targets = [targets]
582 582
583 583 if isinstance(targets, slice):
584 584 indices = range(len(self._ids))[targets]
585 585 ids = self.ids
586 586 targets = [ ids[i] for i in indices ]
587 587
588 588 if not isinstance(targets, (tuple, list, xrange)):
589 589 raise TypeError("targets by int/slice/collection of ints only, not %s"%(type(targets)))
590 590
591 591 return [cast_bytes(self._engines[t]) for t in targets], list(targets)
592 592
593 593 def _connect(self, sshserver, ssh_kwargs, timeout):
594 594 """setup all our socket connections to the cluster. This is called from
595 595 __init__."""
596 596
597 597 # Maybe allow reconnecting?
598 598 if self._connected:
599 599 return
600 600 self._connected=True
601 601
602 602 def connect_socket(s, url):
603 603 if self._ssh:
604 604 return tunnel.tunnel_connection(s, url, sshserver, **ssh_kwargs)
605 605 else:
606 606 return s.connect(url)
607 607
608 608 self.session.send(self._query_socket, 'connection_request')
609 609 # use Poller because zmq.select has wrong units in pyzmq 2.1.7
610 610 poller = zmq.Poller()
611 611 poller.register(self._query_socket, zmq.POLLIN)
612 612 # poll expects milliseconds, timeout is seconds
613 613 evts = poller.poll(timeout*1000)
614 614 if not evts:
615 615 raise error.TimeoutError("Hub connection request timed out")
616 616 idents,msg = self.session.recv(self._query_socket,mode=0)
617 617 if self.debug:
618 618 pprint(msg)
619 619 content = msg['content']
620 620 # self._config['registration'] = dict(content)
621 621 cfg = self._config
622 622 if content['status'] == 'ok':
623 623 self._mux_socket = self._context.socket(zmq.DEALER)
624 624 connect_socket(self._mux_socket, cfg['mux'])
625 625
626 626 self._task_socket = self._context.socket(zmq.DEALER)
627 627 connect_socket(self._task_socket, cfg['task'])
628 628
629 629 self._notification_socket = self._context.socket(zmq.SUB)
630 630 self._notification_socket.setsockopt(zmq.SUBSCRIBE, b'')
631 631 connect_socket(self._notification_socket, cfg['notification'])
632 632
633 633 self._control_socket = self._context.socket(zmq.DEALER)
634 634 connect_socket(self._control_socket, cfg['control'])
635 635
636 636 self._iopub_socket = self._context.socket(zmq.SUB)
637 637 self._iopub_socket.setsockopt(zmq.SUBSCRIBE, b'')
638 638 connect_socket(self._iopub_socket, cfg['iopub'])
639 639
640 640 self._update_engines(dict(content['engines']))
641 641 else:
642 642 self._connected = False
643 643 raise Exception("Failed to connect!")
644 644
645 645 #--------------------------------------------------------------------------
646 646 # handlers and callbacks for incoming messages
647 647 #--------------------------------------------------------------------------
648 648
649 649 def _unwrap_exception(self, content):
650 650 """unwrap exception, and remap engine_id to int."""
651 651 e = error.unwrap_exception(content)
652 652 # print e.traceback
653 653 if e.engine_info:
654 654 e_uuid = e.engine_info['engine_uuid']
655 655 eid = self._engines[e_uuid]
656 656 e.engine_info['engine_id'] = eid
657 657 return e
658 658
659 659 def _extract_metadata(self, msg):
660 660 header = msg['header']
661 661 parent = msg['parent_header']
662 662 msg_meta = msg['metadata']
663 663 content = msg['content']
664 664 md = {'msg_id' : parent['msg_id'],
665 665 'received' : datetime.now(),
666 666 'engine_uuid' : msg_meta.get('engine', None),
667 667 'follow' : msg_meta.get('follow', []),
668 668 'after' : msg_meta.get('after', []),
669 669 'status' : content['status'],
670 670 }
671 671
672 672 if md['engine_uuid'] is not None:
673 673 md['engine_id'] = self._engines.get(md['engine_uuid'], None)
674 674
675 675 if 'date' in parent:
676 676 md['submitted'] = parent['date']
677 677 if 'started' in msg_meta:
678 678 md['started'] = msg_meta['started']
679 679 if 'date' in header:
680 680 md['completed'] = header['date']
681 681 return md
682 682
683 683 def _register_engine(self, msg):
684 684 """Register a new engine, and update our connection info."""
685 685 content = msg['content']
686 686 eid = content['id']
687 687 d = {eid : content['uuid']}
688 688 self._update_engines(d)
689 689
690 690 def _unregister_engine(self, msg):
691 691 """Unregister an engine that has died."""
692 692 content = msg['content']
693 693 eid = int(content['id'])
694 694 if eid in self._ids:
695 695 self._ids.remove(eid)
696 696 uuid = self._engines.pop(eid)
697 697
698 698 self._handle_stranded_msgs(eid, uuid)
699 699
700 700 if self._task_socket and self._task_scheme == 'pure':
701 701 self._stop_scheduling_tasks()
702 702
703 703 def _handle_stranded_msgs(self, eid, uuid):
704 704 """Handle messages known to be on an engine when the engine unregisters.
705 705
706 706 It is possible that this will fire prematurely - that is, an engine will
707 707 go down after completing a result, and the client will be notified
708 708 of the unregistration and later receive the successful result.
709 709 """
710 710
711 711 outstanding = self._outstanding_dict[uuid]
712 712
713 713 for msg_id in list(outstanding):
714 714 if msg_id in self.results:
715 715                 # we already have this result
716 716 continue
717 717 try:
718 718 raise error.EngineError("Engine %r died while running task %r"%(eid, msg_id))
719 719 except:
720 720 content = error.wrap_exception()
721 721 # build a fake message:
722 722 msg = self.session.msg('apply_reply', content=content)
723 723 msg['parent_header']['msg_id'] = msg_id
724 724 msg['metadata']['engine'] = uuid
725 725 self._handle_apply_reply(msg)
726 726
727 727 def _handle_execute_reply(self, msg):
728 728 """Save the reply to an execute_request into our results.
729 729
730 730 execute messages are never actually used. apply is used instead.
731 731 """
732 732
733 733 parent = msg['parent_header']
734 734 msg_id = parent['msg_id']
735 735 if msg_id not in self.outstanding:
736 736 if msg_id in self.history:
737 737 print(("got stale result: %s"%msg_id))
738 738 else:
739 739 print(("got unknown result: %s"%msg_id))
740 740 else:
741 741 self.outstanding.remove(msg_id)
742 742
743 743 content = msg['content']
744 744 header = msg['header']
745 745
746 746 # construct metadata:
747 747 md = self.metadata[msg_id]
748 748 md.update(self._extract_metadata(msg))
749 749 # is this redundant?
750 750 self.metadata[msg_id] = md
751 751
752 752 e_outstanding = self._outstanding_dict[md['engine_uuid']]
753 753 if msg_id in e_outstanding:
754 754 e_outstanding.remove(msg_id)
755 755
756 756 # construct result:
757 757 if content['status'] == 'ok':
758 758 self.results[msg_id] = ExecuteReply(msg_id, content, md)
759 759 elif content['status'] == 'aborted':
760 760 self.results[msg_id] = error.TaskAborted(msg_id)
761 761 elif content['status'] == 'resubmitted':
762 762 # TODO: handle resubmission
763 763 pass
764 764 else:
765 765 self.results[msg_id] = self._unwrap_exception(content)
766 766
767 767 def _handle_apply_reply(self, msg):
768 768 """Save the reply to an apply_request into our results."""
769 769 parent = msg['parent_header']
770 770 msg_id = parent['msg_id']
771 771 if msg_id not in self.outstanding:
772 772 if msg_id in self.history:
773 773 print(("got stale result: %s"%msg_id))
774 774 print(self.results[msg_id])
775 775 print(msg)
776 776 else:
777 777 print(("got unknown result: %s"%msg_id))
778 778 else:
779 779 self.outstanding.remove(msg_id)
780 780 content = msg['content']
781 781 header = msg['header']
782 782
783 783 # construct metadata:
784 784 md = self.metadata[msg_id]
785 785 md.update(self._extract_metadata(msg))
786 786 # is this redundant?
787 787 self.metadata[msg_id] = md
788 788
789 789 e_outstanding = self._outstanding_dict[md['engine_uuid']]
790 790 if msg_id in e_outstanding:
791 791 e_outstanding.remove(msg_id)
792 792
793 793 # construct result:
794 794 if content['status'] == 'ok':
795 795 self.results[msg_id] = serialize.unserialize_object(msg['buffers'])[0]
796 796 elif content['status'] == 'aborted':
797 797 self.results[msg_id] = error.TaskAborted(msg_id)
798 798 elif content['status'] == 'resubmitted':
799 799 # TODO: handle resubmission
800 800 pass
801 801 else:
802 802 self.results[msg_id] = self._unwrap_exception(content)
803 803
804 804 def _flush_notifications(self):
805 805 """Flush notifications of engine registrations waiting
806 806 in ZMQ queue."""
807 807 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
808 808 while msg is not None:
809 809 if self.debug:
810 810 pprint(msg)
811 811 msg_type = msg['header']['msg_type']
812 812 handler = self._notification_handlers.get(msg_type, None)
813 813 if handler is None:
814 814 raise Exception("Unhandled message type: %s" % msg_type)
815 815 else:
816 816 handler(msg)
817 817 idents,msg = self.session.recv(self._notification_socket, mode=zmq.NOBLOCK)
818 818
819 819 def _flush_results(self, sock):
820 820 """Flush task or queue results waiting in ZMQ queue."""
821 821 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
822 822 while msg is not None:
823 823 if self.debug:
824 824 pprint(msg)
825 825 msg_type = msg['header']['msg_type']
826 826 handler = self._queue_handlers.get(msg_type, None)
827 827 if handler is None:
828 828 raise Exception("Unhandled message type: %s" % msg_type)
829 829 else:
830 830 handler(msg)
831 831 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
832 832
833 833 def _flush_control(self, sock):
834 834 """Flush replies from the control channel waiting
835 835 in the ZMQ queue.
836 836
837 837 Currently: ignore them."""
838 838 if self._ignored_control_replies <= 0:
839 839 return
840 840 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
841 841 while msg is not None:
842 842 self._ignored_control_replies -= 1
843 843 if self.debug:
844 844 pprint(msg)
845 845 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
846 846
847 847 def _flush_ignored_control(self):
848 848 """flush ignored control replies"""
849 849 while self._ignored_control_replies > 0:
850 850 self.session.recv(self._control_socket)
851 851 self._ignored_control_replies -= 1
852 852
853 853 def _flush_ignored_hub_replies(self):
854 854 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
855 855 while msg is not None:
856 856 ident,msg = self.session.recv(self._query_socket, mode=zmq.NOBLOCK)
857 857
858 858 def _flush_iopub(self, sock):
859 859 """Flush replies from the iopub channel waiting
860 860 in the ZMQ queue.
861 861 """
862 862 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
863 863 while msg is not None:
864 864 if self.debug:
865 865 pprint(msg)
866 866 parent = msg['parent_header']
867 867 # ignore IOPub messages with no parent.
868 868 # Caused by print statements or warnings from before the first execution.
869 869 if not parent:
870 870 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
871 871 continue
872 872 msg_id = parent['msg_id']
873 873 content = msg['content']
874 874 header = msg['header']
875 875 msg_type = msg['header']['msg_type']
876 876
877 877 # init metadata:
878 878 md = self.metadata[msg_id]
879 879
880 880 if msg_type == 'stream':
881 881 name = content['name']
882 882 s = md[name] or ''
883 883 md[name] = s + content['data']
884 884 elif msg_type == 'pyerr':
885 885 md.update({'pyerr' : self._unwrap_exception(content)})
886 886 elif msg_type == 'pyin':
887 887 md.update({'pyin' : content['code']})
888 888 elif msg_type == 'display_data':
889 889 md['outputs'].append(content)
890 890 elif msg_type == 'pyout':
891 891 md['pyout'] = content
892 892 elif msg_type == 'data_message':
893 893 data, remainder = serialize.unserialize_object(msg['buffers'])
894 894 md['data'].update(data)
895 895 elif msg_type == 'status':
896 896 # idle message comes after all outputs
897 897 if content['execution_state'] == 'idle':
898 898 md['outputs_ready'] = True
899 899 else:
900 900                 # any other, unhandled msg_type
901 901 pass
902 902
903 903             # redundant?
904 904 self.metadata[msg_id] = md
905 905
906 906 idents,msg = self.session.recv(sock, mode=zmq.NOBLOCK)
907 907
908 908 #--------------------------------------------------------------------------
909 909 # len, getitem
910 910 #--------------------------------------------------------------------------
911 911
912 912 def __len__(self):
913 913 """len(client) returns # of engines."""
914 914 return len(self.ids)
915 915
916 916 def __getitem__(self, key):
917 917 """index access returns DirectView multiplexer objects
918 918
919 919 Must be int, slice, or list/tuple/xrange of ints"""
920 920 if not isinstance(key, (int, slice, tuple, list, xrange)):
921 921 raise TypeError("key by int/slice/iterable of ints only, not %s"%(type(key)))
922 922 else:
923 923 return self.direct_view(key)
924 924
925 925 #--------------------------------------------------------------------------
926 926 # Begin public methods
927 927 #--------------------------------------------------------------------------
928 928
929 929 @property
930 930 def ids(self):
931 931 """Always up-to-date ids property."""
932 932 self._flush_notifications()
933 933 # always copy:
934 934 return list(self._ids)
935 935
936 936 def activate(self, targets='all', suffix=''):
937 937 """Create a DirectView and register it with IPython magics
938 938
939 939 Defines the magics `%px, %autopx, %pxresult, %%px`
940 940
941 941 Parameters
942 942 ----------
943 943
944 944 targets: int, list of ints, or 'all'
945 945 The engines on which the view's magics will run
946 946 suffix: str [default: '']
947 947 The suffix, if any, for the magics. This allows you to have
948 948 multiple views associated with parallel magics at the same time.
949 949
950 950 e.g. ``rc.activate(targets=0, suffix='0')`` will give you
951 951 the magics ``%px0``, ``%pxresult0``, etc. for running magics just
952 952 on engine 0.
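
        A usage sketch (assumes a connected Client ``rc`` with registered
        engines, run inside an IPython session)::

            view = rc.activate(targets='all')
            # %px, %%px, %autopx and %pxresult now execute on all engines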
953 953 """
954 954 view = self.direct_view(targets)
955 955 view.block = True
956 956 view.activate(suffix)
957 957 return view
958 958
959 959 def close(self, linger=None):
960 960 """Close my zmq Sockets
961 961
962 962         If `linger` is specified, set the zmq LINGER socket option,
963 963         which allows pending messages to be discarded.
964 964 """
965 965 if self._closed:
966 966 return
967 967 self.stop_spin_thread()
968 968 snames = [ trait for trait in self.trait_names() if trait.endswith("socket") ]
969 969 for name in snames:
970 970 socket = getattr(self, name)
971 971 if socket is not None and not socket.closed:
972 972 if linger is not None:
973 973 socket.close(linger=linger)
974 974 else:
975 975 socket.close()
976 976 self._closed = True
977 977
978 978 def _spin_every(self, interval=1):
979 979 """target func for use in spin_thread"""
980 980 while True:
981 981 if self._stop_spinning.is_set():
982 982 return
983 983 time.sleep(interval)
984 984 self.spin()
985 985
986 986 def spin_thread(self, interval=1):
987 987 """call Client.spin() in a background thread on some regular interval
988 988
989 989 This helps ensure that messages don't pile up too much in the zmq queue
990 990 while you are working on other things, or just leaving an idle terminal.
991 991
992 992 It also helps limit potential padding of the `received` timestamp
993 993 on AsyncResult objects, used for timings.
994 994
995 995 Parameters
996 996 ----------
997 997
998 998 interval : float, optional
999 999 The interval on which to spin the client in the background thread
1000 1000 (simply passed to time.sleep).
1001 1001
1002 1002 Notes
1003 1003 -----
1004 1004
1005 1005 For precision timing, you may want to use this method to put a bound
1006 1006 on the jitter (in seconds) in `received` timestamps used
1007 1007 in AsyncResult.wall_time.
1008 1008
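        A usage sketch (the interval value is illustrative)::

            rc.spin_thread(interval=0.05)  # flush queues in the background
            # ... submit and time tasks ...
            rc.stop_spin_thread()
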
1009 1009 """
1010 1010 if self._spin_thread is not None:
1011 1011 self.stop_spin_thread()
1012 1012 self._stop_spinning.clear()
1013 1013 self._spin_thread = Thread(target=self._spin_every, args=(interval,))
1014 1014 self._spin_thread.daemon = True
1015 1015 self._spin_thread.start()
1016 1016
1017 1017 def stop_spin_thread(self):
1018 1018 """stop background spin_thread, if any"""
1019 1019 if self._spin_thread is not None:
1020 1020 self._stop_spinning.set()
1021 1021 self._spin_thread.join()
1022 1022 self._spin_thread = None
1023 1023
1024 1024 def spin(self):
1025 1025 """Flush any registration notifications and execution results
1026 1026 waiting in the ZMQ queue.
1027 1027 """
1028 1028 if self._notification_socket:
1029 1029 self._flush_notifications()
1030 1030 if self._iopub_socket:
1031 1031 self._flush_iopub(self._iopub_socket)
1032 1032 if self._mux_socket:
1033 1033 self._flush_results(self._mux_socket)
1034 1034 if self._task_socket:
1035 1035 self._flush_results(self._task_socket)
1036 1036 if self._control_socket:
1037 1037 self._flush_control(self._control_socket)
1038 1038 if self._query_socket:
1039 1039 self._flush_ignored_hub_replies()
1040 1040
1041 1041 def wait(self, jobs=None, timeout=-1):
1042 1042 """waits on one or more `jobs`, for up to `timeout` seconds.
1043 1043
1044 1044 Parameters
1045 1045 ----------
1046 1046
1047 1047 jobs : int, str, or list of ints and/or strs, or one or more AsyncResult objects
1048 1048 ints are indices to self.history
1049 1049 strs are msg_ids
1050 1050 default: wait on all outstanding messages
1051 1051 timeout : float
1052 1052 a time in seconds, after which to give up.
1053 1053 default is -1, which means no timeout
1054 1054
1055 1055 Returns
1056 1056 -------
1057 1057
1058 1058 True : when all msg_ids are done
1059 1059 False : timeout reached, some msg_ids still outstanding
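
        A usage sketch (``ar`` is any AsyncResult; the timeout is illustrative)::

            rc.wait(timeout=10)  # block up to 10s for all outstanding jobs
            rc.wait(ar)          # or wait only on one AsyncResult's messages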
1060 1060 """
1061 1061 tic = time.time()
1062 1062 if jobs is None:
1063 1063 theids = self.outstanding
1064 1064 else:
1065 1065 if isinstance(jobs, string_types + (int, AsyncResult)):
1066 1066 jobs = [jobs]
1067 1067 theids = set()
1068 1068 for job in jobs:
1069 1069 if isinstance(job, int):
1070 1070 # index access
1071 1071 job = self.history[job]
1072 1072 elif isinstance(job, AsyncResult):
1073 1073 map(theids.add, job.msg_ids)
1074 1074 continue
1075 1075 theids.add(job)
1076 1076 if not theids.intersection(self.outstanding):
1077 1077 return True
1078 1078 self.spin()
1079 1079 while theids.intersection(self.outstanding):
1080 1080 if timeout >= 0 and ( time.time()-tic ) > timeout:
1081 1081 break
1082 1082 time.sleep(1e-3)
1083 1083 self.spin()
1084 1084 return len(theids.intersection(self.outstanding)) == 0
1085 1085
1086 1086 #--------------------------------------------------------------------------
1087 1087 # Control methods
1088 1088 #--------------------------------------------------------------------------
1089 1089
1090 1090 @spin_first
1091 1091 def clear(self, targets=None, block=None):
1092 1092 """Clear the namespace in target(s)."""
1093 1093 block = self.block if block is None else block
1094 1094 targets = self._build_targets(targets)[0]
1095 1095 for t in targets:
1096 1096 self.session.send(self._control_socket, 'clear_request', content={}, ident=t)
1097 1097 error = False
1098 1098 if block:
1099 1099 self._flush_ignored_control()
1100 1100 for i in range(len(targets)):
1101 1101 idents,msg = self.session.recv(self._control_socket,0)
1102 1102 if self.debug:
1103 1103 pprint(msg)
1104 1104 if msg['content']['status'] != 'ok':
1105 1105 error = self._unwrap_exception(msg['content'])
1106 1106 else:
1107 1107 self._ignored_control_replies += len(targets)
1108 1108 if error:
1109 1109 raise error
1110 1110
1111 1111
1112 1112 @spin_first
1113 1113 def abort(self, jobs=None, targets=None, block=None):
1114 1114 """Abort specific jobs from the execution queues of target(s).
1115 1115
1116 1116 This is a mechanism to prevent jobs that have already been submitted
1117 1117 from executing.
1118 1118
1119 1119 Parameters
1120 1120 ----------
1121 1121
1122 1122 jobs : msg_id, list of msg_ids, or AsyncResult
1123 1123 The jobs to be aborted
1124 1124
1125 1125 If unspecified/None: abort all outstanding jobs.
1126 1126
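        A usage sketch (``ar`` is an AsyncResult for still-queued jobs)::

            rc.abort()         # abort all outstanding jobs on all engines
            rc.abort(jobs=ar)  # or only the jobs behind one AsyncResult
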
1127 1127 """
1128 1128 block = self.block if block is None else block
1129 1129 jobs = jobs if jobs is not None else list(self.outstanding)
1130 1130 targets = self._build_targets(targets)[0]
1131 1131
1132 1132 msg_ids = []
1133 1133 if isinstance(jobs, string_types + (AsyncResult,)):
1134 1134 jobs = [jobs]
1135 1135 bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult,)), jobs)
1136 1136 if bad_ids:
1137 1137 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1138 1138 for j in jobs:
1139 1139 if isinstance(j, AsyncResult):
1140 1140 msg_ids.extend(j.msg_ids)
1141 1141 else:
1142 1142 msg_ids.append(j)
1143 1143 content = dict(msg_ids=msg_ids)
1144 1144 for t in targets:
1145 1145 self.session.send(self._control_socket, 'abort_request',
1146 1146 content=content, ident=t)
1147 1147 error = False
1148 1148 if block:
1149 1149 self._flush_ignored_control()
1150 1150 for i in range(len(targets)):
1151 1151 idents,msg = self.session.recv(self._control_socket,0)
1152 1152 if self.debug:
1153 1153 pprint(msg)
1154 1154 if msg['content']['status'] != 'ok':
1155 1155 error = self._unwrap_exception(msg['content'])
1156 1156 else:
1157 1157 self._ignored_control_replies += len(targets)
1158 1158 if error:
1159 1159 raise error
1160 1160
1161 1161 @spin_first
1162 1162 def shutdown(self, targets='all', restart=False, hub=False, block=None):
1163 1163 """Terminates one or more engine processes, optionally including the hub.
1164 1164
1165 1165 Parameters
1166 1166 ----------
1167 1167
1168 1168 targets: list of ints or 'all' [default: all]
1169 1169 Which engines to shutdown.
1170 1170 hub: bool [default: False]
1171 1171 Whether to include the Hub. hub=True implies targets='all'.
1172 1172 block: bool [default: self.block]
1173 1173 Whether to wait for clean shutdown replies or not.
1174 1174 restart: bool [default: False]
1175 1175 NOT IMPLEMENTED
1176 1176 whether to restart engines after shutting them down.
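
        A usage sketch::

            rc.shutdown(targets=[0, 1])  # stop two engines
            rc.shutdown(hub=True)        # stop every engine and the Hub itself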
1177 1177 """
1178 1178 from IPython.parallel.error import NoEnginesRegistered
1179 1179 if restart:
1180 1180 raise NotImplementedError("Engine restart is not yet implemented")
1181 1181
1182 1182 block = self.block if block is None else block
1183 1183 if hub:
1184 1184 targets = 'all'
1185 1185 try:
1186 1186 targets = self._build_targets(targets)[0]
1187 1187 except NoEnginesRegistered:
1188 1188 targets = []
1189 1189 for t in targets:
1190 1190 self.session.send(self._control_socket, 'shutdown_request',
1191 1191 content={'restart':restart},ident=t)
1192 1192 error = False
1193 1193 if block or hub:
1194 1194 self._flush_ignored_control()
1195 1195 for i in range(len(targets)):
1196 1196 idents,msg = self.session.recv(self._control_socket, 0)
1197 1197 if self.debug:
1198 1198 pprint(msg)
1199 1199 if msg['content']['status'] != 'ok':
1200 1200 error = self._unwrap_exception(msg['content'])
1201 1201 else:
1202 1202 self._ignored_control_replies += len(targets)
1203 1203
1204 1204 if hub:
1205 1205 time.sleep(0.25)
1206 1206 self.session.send(self._query_socket, 'shutdown_request')
1207 1207 idents,msg = self.session.recv(self._query_socket, 0)
1208 1208 if self.debug:
1209 1209 pprint(msg)
1210 1210 if msg['content']['status'] != 'ok':
1211 1211 error = self._unwrap_exception(msg['content'])
1212 1212
1213 1213 if error:
1214 1214 raise error
1215 1215
1216 1216 #--------------------------------------------------------------------------
1217 1217 # Execution related methods
1218 1218 #--------------------------------------------------------------------------
1219 1219
1220 1220 def _maybe_raise(self, result):
1221 1221 """wrapper for maybe raising an exception if apply failed."""
1222 1222 if isinstance(result, error.RemoteError):
1223 1223 raise result
1224 1224
1225 1225 return result
1226 1226
1227 1227 def send_apply_request(self, socket, f, args=None, kwargs=None, metadata=None, track=False,
1228 1228 ident=None):
1229 1229 """construct and send an apply message via a socket.
1230 1230
1231 1231 This is the principal method with which all engine execution is performed by views.
1232 1232 """
1233 1233
1234 1234 if self._closed:
1235 1235 raise RuntimeError("Client cannot be used after its sockets have been closed")
1236 1236
1237 1237 # defaults:
1238 1238 args = args if args is not None else []
1239 1239 kwargs = kwargs if kwargs is not None else {}
1240 1240 metadata = metadata if metadata is not None else {}
1241 1241
1242 1242 # validate arguments
1243 1243 if not callable(f) and not isinstance(f, Reference):
1244 1244 raise TypeError("f must be callable, not %s"%type(f))
1245 1245 if not isinstance(args, (tuple, list)):
1246 1246 raise TypeError("args must be tuple or list, not %s"%type(args))
1247 1247 if not isinstance(kwargs, dict):
1248 1248 raise TypeError("kwargs must be dict, not %s"%type(kwargs))
1249 1249 if not isinstance(metadata, dict):
1250 1250 raise TypeError("metadata must be dict, not %s"%type(metadata))
1251 1251
1252 1252 bufs = serialize.pack_apply_message(f, args, kwargs,
1253 1253 buffer_threshold=self.session.buffer_threshold,
1254 1254 item_threshold=self.session.item_threshold,
1255 1255 )
1256 1256
1257 1257 msg = self.session.send(socket, "apply_request", buffers=bufs, ident=ident,
1258 1258 metadata=metadata, track=track)
1259 1259
1260 1260 msg_id = msg['header']['msg_id']
1261 1261 self.outstanding.add(msg_id)
1262 1262 if ident:
1263 1263 # possibly routed to a specific engine
1264 1264 if isinstance(ident, list):
1265 1265 ident = ident[-1]
1266 1266 if ident in self._engines.values():
1267 1267 # save for later, in case of engine death
1268 1268 self._outstanding_dict[ident].add(msg_id)
1269 1269 self.history.append(msg_id)
1270 1270 self.metadata[msg_id]['submitted'] = datetime.now()
1271 1271
1272 1272 return msg
1273 1273
1274 1274 def send_execute_request(self, socket, code, silent=True, metadata=None, ident=None):
1275 1275 """construct and send an execute request via a socket.
1276 1276
1277 1277 """
1278 1278
1279 1279 if self._closed:
1280 1280 raise RuntimeError("Client cannot be used after its sockets have been closed")
1281 1281
1282 1282 # defaults:
1283 1283 metadata = metadata if metadata is not None else {}
1284 1284
1285 1285 # validate arguments
1286 1286 if not isinstance(code, string_types):
1287 1287 raise TypeError("code must be text, not %s" % type(code))
1288 1288 if not isinstance(metadata, dict):
1289 1289 raise TypeError("metadata must be dict, not %s" % type(metadata))
1290 1290
1291 1291 content = dict(code=code, silent=bool(silent), user_variables=[], user_expressions={})
1292 1292
1293 1293
1294 1294 msg = self.session.send(socket, "execute_request", content=content, ident=ident,
1295 1295 metadata=metadata)
1296 1296
1297 1297 msg_id = msg['header']['msg_id']
1298 1298 self.outstanding.add(msg_id)
1299 1299 if ident:
1300 1300 # possibly routed to a specific engine
1301 1301 if isinstance(ident, list):
1302 1302 ident = ident[-1]
1303 1303 if ident in self._engines.values():
1304 1304 # save for later, in case of engine death
1305 1305 self._outstanding_dict[ident].add(msg_id)
1306 1306 self.history.append(msg_id)
1307 1307 self.metadata[msg_id]['submitted'] = datetime.now()
1308 1308
1309 1309 return msg
1310 1310
1311 1311 #--------------------------------------------------------------------------
1312 1312 # construct a View object
1313 1313 #--------------------------------------------------------------------------
1314 1314
1315 1315 def load_balanced_view(self, targets=None):
1316 1316         """construct a LoadBalancedView object.
1317 1317
1318 1318 If no arguments are specified, create a LoadBalancedView
1319 1319 using all engines.
1320 1320
1321 1321 Parameters
1322 1322 ----------
1323 1323
1324 1324 targets: list,slice,int,etc. [default: use all engines]
1325 1325 The subset of engines across which to load-balance
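
        A usage sketch::

            lview = rc.load_balanced_view()        # balance across all engines
            lview = rc.load_balanced_view([0, 2])  # or across a subset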
1326 1326 """
1327 1327 if targets == 'all':
1328 1328 targets = None
1329 1329 if targets is not None:
1330 1330 targets = self._build_targets(targets)[1]
1331 1331 return LoadBalancedView(client=self, socket=self._task_socket, targets=targets)
1332 1332
1333 1333 def direct_view(self, targets='all'):
1334 1334 """construct a DirectView object.
1335 1335
1336 1336 If no targets are specified, create a DirectView using all engines.
1337 1337
1338 1338 rc.direct_view('all') is distinguished from rc[:] in that 'all' will
1339 1339 evaluate the target engines at each execution, whereas rc[:] will connect to
1340 1340 all *current* engines, and that list will not change.
1341 1341
1342 1342 That is, 'all' will always use all engines, whereas rc[:] will not use
1343 1343 engines added after the DirectView is constructed.
1344 1344
1345 1345 Parameters
1346 1346 ----------
1347 1347
1348 1348 targets: list,slice,int,etc. [default: use all engines]
1349 1349 The engines to use for the View
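
        A usage sketch::

            dv = rc.direct_view('all')  # always targets all current engines
            dv0 = rc.direct_view(0)     # a single-engine view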
1350 1350 """
1351 1351 single = isinstance(targets, int)
1352 1352 # allow 'all' to be lazily evaluated at each execution
1353 1353 if targets != 'all':
1354 1354 targets = self._build_targets(targets)[1]
1355 1355 if single:
1356 1356 targets = targets[0]
1357 1357 return DirectView(client=self, socket=self._mux_socket, targets=targets)
1358 1358
1359 1359 #--------------------------------------------------------------------------
1360 1360 # Query methods
1361 1361 #--------------------------------------------------------------------------
1362 1362
1363 1363 @spin_first
1364 1364 def get_result(self, indices_or_msg_ids=None, block=None):
1365 1365 """Retrieve a result by msg_id or history index, wrapped in an AsyncResult object.
1366 1366
1367 1367 If the client already has the results, no request to the Hub will be made.
1368 1368
1369 1369 This is a convenient way to construct AsyncResult objects, which are wrappers
1370 1370 that include metadata about execution, and allow for awaiting results that
1371 1371 were not submitted by this Client.
1372 1372
1373 1373 It can also be a convenient way to retrieve the metadata associated with
1374 1374         blocking execution, since it always retrieves the metadata as well.
1375 1375
1376 1376 Examples
1377 1377 --------
1378 1378 ::
1379 1379
1380 1380             In [10]: ar = client.get_result(-1)  # AsyncResult for the most recent task
1381 1381
1382 1382 Parameters
1383 1383 ----------
1384 1384
1385 1385 indices_or_msg_ids : integer history index, str msg_id, or list of either
1386 1386             The indices or msg_ids of the results to be retrieved
1387 1387
1388 1388 block : bool
1389 1389 Whether to wait for the result to be done
1390 1390
1391 1391 Returns
1392 1392 -------
1393 1393
1394 1394 AsyncResult
1395 1395 A single AsyncResult object will always be returned.
1396 1396
1397 1397 AsyncHubResult
1398 1398 A subclass of AsyncResult that retrieves results from the Hub
1399 1399
1400 1400 """
1401 1401 block = self.block if block is None else block
1402 1402 if indices_or_msg_ids is None:
1403 1403 indices_or_msg_ids = -1
1404 1404
1405 1405 single_result = False
1406 1406 if not isinstance(indices_or_msg_ids, (list,tuple)):
1407 1407 indices_or_msg_ids = [indices_or_msg_ids]
1408 1408 single_result = True
1409 1409
1410 1410 theids = []
1411 1411 for id in indices_or_msg_ids:
1412 1412 if isinstance(id, int):
1413 1413 id = self.history[id]
1414 1414 if not isinstance(id, string_types):
1415 1415 raise TypeError("indices must be str or int, not %r"%id)
1416 1416 theids.append(id)
1417 1417
1418 1418 local_ids = filter(lambda msg_id: msg_id in self.outstanding or msg_id in self.results, theids)
1419 1419 remote_ids = filter(lambda msg_id: msg_id not in local_ids, theids)
1420 1420
1421 1421         # given single msg_id initially, get_result should get the result itself,
1422 1422 # not a length-one list
1423 1423 if single_result:
1424 1424 theids = theids[0]
1425 1425
1426 1426 if remote_ids:
1427 1427 ar = AsyncHubResult(self, msg_ids=theids)
1428 1428 else:
1429 1429 ar = AsyncResult(self, msg_ids=theids)
1430 1430
1431 1431 if block:
1432 1432 ar.wait()
1433 1433
1434 1434 return ar
1435 1435
1436 1436 @spin_first
1437 1437 def resubmit(self, indices_or_msg_ids=None, metadata=None, block=None):
1438 1438 """Resubmit one or more tasks.
1439 1439
1440 1440 in-flight tasks may not be resubmitted.
1441 1441
1442 1442 Parameters
1443 1443 ----------
1444 1444
1445 1445 indices_or_msg_ids : integer history index, str msg_id, or list of either
1446 1446             The indices or msg_ids of the tasks to be resubmitted
1447 1447
1448 1448 block : bool
1449 1449 Whether to wait for the result to be done
1450 1450
1451 1451 Returns
1452 1452 -------
1453 1453
1454 1454 AsyncHubResult
1455 1455 A subclass of AsyncResult that retrieves results from the Hub
1456 1456
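        A usage sketch::

            ar = rc.resubmit(-1)  # resubmit the most recent task; returns a new AsyncHubResult
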
1457 1457 """
1458 1458 block = self.block if block is None else block
1459 1459 if indices_or_msg_ids is None:
1460 1460 indices_or_msg_ids = -1
1461 1461
1462 1462 if not isinstance(indices_or_msg_ids, (list,tuple)):
1463 1463 indices_or_msg_ids = [indices_or_msg_ids]
1464 1464
1465 1465 theids = []
1466 1466 for id in indices_or_msg_ids:
1467 1467 if isinstance(id, int):
1468 1468 id = self.history[id]
1469 1469 if not isinstance(id, string_types):
1470 1470 raise TypeError("indices must be str or int, not %r"%id)
1471 1471 theids.append(id)
1472 1472
1473 1473 content = dict(msg_ids = theids)
1474 1474
1475 1475 self.session.send(self._query_socket, 'resubmit_request', content)
1476 1476
1477 1477 zmq.select([self._query_socket], [], [])
1478 1478 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1479 1479 if self.debug:
1480 1480 pprint(msg)
1481 1481 content = msg['content']
1482 1482 if content['status'] != 'ok':
1483 1483 raise self._unwrap_exception(content)
1484 1484 mapping = content['resubmitted']
1485 1485 new_ids = [ mapping[msg_id] for msg_id in theids ]
1486 1486
1487 1487 ar = AsyncHubResult(self, msg_ids=new_ids)
1488 1488
1489 1489 if block:
1490 1490 ar.wait()
1491 1491
1492 1492 return ar
1493 1493
1494 1494 @spin_first
1495 1495 def result_status(self, msg_ids, status_only=True):
1496 1496 """Check on the status of the result(s) of the apply request with `msg_ids`.
1497 1497
1498 1498 If status_only is False, then the actual results will be retrieved, else
1499 1499 only the status of the results will be checked.
1500 1500
1501 1501 Parameters
1502 1502 ----------
1503 1503
1504 1504 msg_ids : list of msg_ids
1505 1505 if int:
1506 1506 Passed as index to self.history for convenience.
1507 1507 status_only : bool (default: True)
1508 1508 if False:
1509 1509 Retrieve the actual results of completed tasks.
1510 1510
1511 1511 Returns
1512 1512 -------
1513 1513
1514 1514 results : dict
1515 1515 There will always be the keys 'pending' and 'completed', which will
1516 1516 be lists of msg_ids that are incomplete or complete. If `status_only`
1517 1517 is False, then completed results will be keyed by their `msg_id`.
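
        A usage sketch (checks the ten most recent submissions)::

            status = rc.result_status(rc.history[-10:])
            status['pending'], status['completed']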
1518 1518 """
1519 1519 if not isinstance(msg_ids, (list,tuple)):
1520 1520 msg_ids = [msg_ids]
1521 1521
1522 1522 theids = []
1523 1523 for msg_id in msg_ids:
1524 1524 if isinstance(msg_id, int):
1525 1525 msg_id = self.history[msg_id]
1526 1526 if not isinstance(msg_id, string_types):
1527 1527 raise TypeError("msg_ids must be str, not %r"%msg_id)
1528 1528 theids.append(msg_id)
1529 1529
1530 1530 completed = []
1531 1531 local_results = {}
1532 1532
1533 1533 # comment this block out to temporarily disable local shortcut:
1534 1534 for msg_id in theids:
1535 1535 if msg_id in self.results:
1536 1536 completed.append(msg_id)
1537 1537 local_results[msg_id] = self.results[msg_id]
1538 1538 theids.remove(msg_id)
1539 1539
1540 1540 if theids: # some not locally cached
1541 1541 content = dict(msg_ids=theids, status_only=status_only)
1542 1542 msg = self.session.send(self._query_socket, "result_request", content=content)
1543 1543 zmq.select([self._query_socket], [], [])
1544 1544 idents,msg = self.session.recv(self._query_socket, zmq.NOBLOCK)
1545 1545 if self.debug:
1546 1546 pprint(msg)
1547 1547 content = msg['content']
1548 1548 if content['status'] != 'ok':
1549 1549 raise self._unwrap_exception(content)
1550 1550 buffers = msg['buffers']
1551 1551 else:
1552 1552 content = dict(completed=[],pending=[])
1553 1553
1554 1554 content['completed'].extend(completed)
1555 1555
1556 1556 if status_only:
1557 1557 return content
1558 1558
1559 1559 failures = []
1560 1560 # load cached results into result:
1561 1561 content.update(local_results)
1562 1562
1563 1563 # update cache with results:
1564 1564 for msg_id in sorted(theids):
1565 1565 if msg_id in content['completed']:
1566 1566 rec = content[msg_id]
1567 1567 parent = rec['header']
1568 1568 header = rec['result_header']
1569 1569 rcontent = rec['result_content']
1570 1570 iodict = rec['io']
1571 1571 if isinstance(rcontent, str):
1572 1572 rcontent = self.session.unpack(rcontent)
1573 1573
1574 1574 md = self.metadata[msg_id]
1575 1575 md_msg = dict(
1576 1576 content=rcontent,
1577 1577 parent_header=parent,
1578 1578 header=header,
1579 1579 metadata=rec['result_metadata'],
1580 1580 )
1581 1581 md.update(self._extract_metadata(md_msg))
1582 1582 if rec.get('received'):
1583 1583 md['received'] = rec['received']
1584 1584 md.update(iodict)
1585 1585
1586 1586 if rcontent['status'] == 'ok':
1587 1587 if header['msg_type'] == 'apply_reply':
1588 1588 res,buffers = serialize.unserialize_object(buffers)
1589 1589 elif header['msg_type'] == 'execute_reply':
1590 1590 res = ExecuteReply(msg_id, rcontent, md)
1591 1591 else:
1592 1592 raise KeyError("unhandled msg type: %r" % header['msg_type'])
1593 1593 else:
1594 1594 res = self._unwrap_exception(rcontent)
1595 1595 failures.append(res)
1596 1596
1597 1597 self.results[msg_id] = res
1598 1598 content[msg_id] = res
1599 1599
1600 1600 if len(theids) == 1 and failures:
1601 1601 raise failures[0]
1602 1602
1603 1603 error.collect_exceptions(failures, "result_status")
1604 1604 return content
1605 1605
1606 1606 @spin_first
1607 1607 def queue_status(self, targets='all', verbose=False):
1608 1608 """Fetch the status of engine queues.
1609 1609
1610 1610 Parameters
1611 1611 ----------
1612 1612
1613 1613 targets : int/str/list of ints/strs
1614 1614 the engines whose states are to be queried.
1615 1615 default : all
1616 1616 verbose : bool
1617 1617 Whether to return lengths only, or lists of ids for each element
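
        A usage sketch::

            rc.queue_status()           # per-engine queue info, keyed by engine id
            rc.queue_status(targets=0)  # info for a single engine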
1618 1618 """
1619 1619 if targets == 'all':
1620 1620 # allow 'all' to be evaluated on the engine
1621 1621 engine_ids = None
1622 1622 else:
1623 1623 engine_ids = self._build_targets(targets)[1]
1624 1624 content = dict(targets=engine_ids, verbose=verbose)
1625 1625 self.session.send(self._query_socket, "queue_request", content=content)
1626 1626 idents,msg = self.session.recv(self._query_socket, 0)
1627 1627 if self.debug:
1628 1628 pprint(msg)
1629 1629 content = msg['content']
1630 1630 status = content.pop('status')
1631 1631 if status != 'ok':
1632 1632 raise self._unwrap_exception(content)
1633 1633 content = rekey(content)
1634 1634 if isinstance(targets, int):
1635 1635 return content[targets]
1636 1636 else:
1637 1637 return content
1638 1638
1639 1639 def _build_msgids_from_target(self, targets=None):
1640 1640 """Build a list of msg_ids from the list of engine targets"""
1641 1641 if not targets: # needed as _build_targets otherwise uses all engines
1642 1642 return []
1643 1643 target_ids = self._build_targets(targets)[0]
1644 1644 return filter(lambda md_id: self.metadata[md_id]["engine_uuid"] in target_ids, self.metadata)
1645 1645
1646 1646 def _build_msgids_from_jobs(self, jobs=None):
1647 1647 """Build a list of msg_ids from "jobs" """
1648 1648 if not jobs:
1649 1649 return []
1650 1650 msg_ids = []
1651 1651 if isinstance(jobs, string_types + (AsyncResult,)):
1652 1652 jobs = [jobs]
1653 1653         bad_ids = filter(lambda obj: not isinstance(obj, string_types + (AsyncResult,)), jobs)
1654 1654 if bad_ids:
1655 1655 raise TypeError("Invalid msg_id type %r, expected str or AsyncResult"%bad_ids[0])
1656 1656 for j in jobs:
1657 1657 if isinstance(j, AsyncResult):
1658 1658 msg_ids.extend(j.msg_ids)
1659 1659 else:
1660 1660 msg_ids.append(j)
1661 1661 return msg_ids
1662 1662
1663 1663 def purge_local_results(self, jobs=[], targets=[]):
1664 1664 """Clears the client caches of results and frees such memory.
1665 1665
1666 1666 Individual results can be purged by msg_id, or the entire
1667 1667 history of specific targets can be purged.
1668 1668
1669 1669         Use `purge_local_results('all')` to scrub everything from the Client's db.
1670 1670
1671 1671 The client must have no outstanding tasks before purging the caches.
1672 1672 Raises `AssertionError` if there are still outstanding tasks.
1673 1673
1674 1674 After this call all `AsyncResults` are invalid and should be discarded.
1675 1675
1676 1676 If you must "reget" the results, you can still do so by using
1677 1677 `client.get_result(msg_id)` or `client.get_result(asyncresult)`. This will
1678 1678 redownload the results from the hub if they are still available
1679 1679         (i.e. `client.purge_hub_results(...)` has not been called).
1680 1680
1681 1681 Parameters
1682 1682 ----------
1683 1683
1684 1684 jobs : str or list of str or AsyncResult objects
1685 1685 the msg_ids whose results should be purged.
1686 1686 targets : int/str/list of ints/strs
1687 1687 The targets, by int_id, whose entire results are to be purged.
1688 1688
1689 1689 default : None
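
        A usage sketch (requires that no tasks are outstanding; ``ar`` is an
        AsyncResult)::

            rc.purge_local_results('all')    # drop every locally cached result
            rc.purge_local_results(jobs=ar)  # or only one AsyncResult's entries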
1690 1690 """
1691 1691 assert not self.outstanding, "Can't purge a client with outstanding tasks!"
1692 1692
1693 1693 if not targets and not jobs:
1694 1694 raise ValueError("Must specify at least one of `targets` and `jobs`")
1695 1695
1696 1696 if jobs == 'all':
1697 1697 self.results.clear()
1698 1698 self.metadata.clear()
1699 1699 return
1700 1700 else:
1701 1701 msg_ids = []
1702 1702 msg_ids.extend(self._build_msgids_from_target(targets))
1703 1703 msg_ids.extend(self._build_msgids_from_jobs(jobs))
1704 1704 map(self.results.pop, msg_ids)
1705 1705 map(self.metadata.pop, msg_ids)
1706 1706
1707 1707
1708 1708 @spin_first
1709 1709 def purge_hub_results(self, jobs=[], targets=[]):
1710 1710 """Tell the Hub to forget results.
1711 1711
1712 1712 Individual results can be purged by msg_id, or the entire
1713 1713 history of specific targets can be purged.
1714 1714
1715 1715 Use `purge_results('all')` to scrub everything from the Hub's db.
1716 1716
1717 1717 Parameters
1718 1718 ----------
1719 1719
1720 1720 jobs : str or list of str or AsyncResult objects
1721 1721 the msg_ids whose results should be forgotten.
1722 1722 targets : int/str/list of ints/strs
1723 1723 The targets, by int_id, whose entire history is to be purged.
1724 1724
1725 1725 default : None
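
        A usage sketch::

            rc.purge_hub_results('all')      # forget every result in the Hub's db
            rc.purge_hub_results(targets=0)  # or only engine 0's history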
1726 1726 """
1727 1727 if not targets and not jobs:
1728 1728 raise ValueError("Must specify at least one of `targets` and `jobs`")
1729 1729 if targets:
1730 1730 targets = self._build_targets(targets)[1]
1731 1731
1732 1732 # construct msg_ids from jobs
1733 1733 if jobs == 'all':
1734 1734 msg_ids = jobs
1735 1735 else:
1736 1736 msg_ids = self._build_msgids_from_jobs(jobs)
1737 1737
1738 1738 content = dict(engine_ids=targets, msg_ids=msg_ids)
1739 1739 self.session.send(self._query_socket, "purge_request", content=content)
1740 1740 idents, msg = self.session.recv(self._query_socket, 0)
1741 1741 if self.debug:
1742 1742 pprint(msg)
1743 1743 content = msg['content']
1744 1744 if content['status'] != 'ok':
1745 1745 raise self._unwrap_exception(content)
1746 1746
1747 1747 def purge_results(self, jobs=[], targets=[]):
1748 1748 """Clears the cached results from both the hub and the local client
1749 1749
1750 1750 Individual results can be purged by msg_id, or the entire
1751 1751 history of specific targets can be purged.
1752 1752
1753 1753 Use `purge_results('all')` to scrub every cached result from both the Hub's and
1754 1754 the Client's db.
1755 1755
1756 1756 Equivalent to calling both `purge_hub_results()` and `purge_local_results()` with
1757 1757 the same arguments.
1758 1758
1759 1759 Parameters
1760 1760 ----------
1761 1761
1762 1762 jobs : str or list of str or AsyncResult objects
1763 1763 the msg_ids whose results should be forgotten.
1764 1764 targets : int/str/list of ints/strs
1765 1765 The targets, by int_id, whose entire history is to be purged.
1766 1766
1767 1767 default : None
1768 1768 """
1769 1769 self.purge_local_results(jobs=jobs, targets=targets)
1770 1770 self.purge_hub_results(jobs=jobs, targets=targets)
1771 1771
1772 1772 def purge_everything(self):
1773 1773 """Clears all content from previous Tasks from both the hub and the local client
1774 1774
1775 1775 In addition to calling `purge_results("all")` it also deletes the history and
1776 1776 other bookkeeping lists.
1777 1777 """
1778 1778 self.purge_results("all")
1779 1779 self.history = []
1780 1780 self.session.digest_history.clear()
1781 1781
1782 1782 @spin_first
1783 1783 def hub_history(self):
1784 1784 """Get the Hub's history
1785 1785
1786 1786 Just like the Client, the Hub has a history, which is a list of msg_ids.
1787 1787 This will contain the history of all clients, and, depending on configuration,
1788 1788 may contain history across multiple cluster sessions.
1789 1789
1790 1790 Any msg_id returned here is a valid argument to `get_result`.
1791 1791
1792 1792 Returns
1793 1793 -------
1794 1794
1795 1795 msg_ids : list of strs
1796 1796 list of all msg_ids, ordered by task submission time.
1797 1797 """
1798 1798
1799 1799 self.session.send(self._query_socket, "history_request", content={})
1800 1800 idents, msg = self.session.recv(self._query_socket, 0)
1801 1801
1802 1802 if self.debug:
1803 1803 pprint(msg)
1804 1804 content = msg['content']
1805 1805 if content['status'] != 'ok':
1806 1806 raise self._unwrap_exception(content)
1807 1807 else:
1808 1808 return content['history']
1809 1809
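For example, a sketch of replaying the most recent task the Hub remembers (assuming a connected Client `rc` as above):

    msg_ids = rc.hub_history()            # every msg_id known to the Hub
    if msg_ids:
        ar = rc.get_result(msg_ids[-1])   # any returned id is a valid argument
        print(ar.get())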
1810 1810 @spin_first
1811 1811 def db_query(self, query, keys=None):
1812 1812 """Query the Hub's TaskRecord database
1813 1813
1814 1814 This will return a list of task record dicts that match `query`
1815 1815
1816 1816 Parameters
1817 1817 ----------
1818 1818
1819 1819 query : mongodb query dict
1820 1820 The search dict. See mongodb query docs for details.
1821 1821 keys : list of strs [optional]
1822 1822 The subset of keys to be returned. The default is to fetch everything but buffers.
1823 1823 'msg_id' will *always* be included.
1824 1824 """
1825 1825 if isinstance(keys, string_types):
1826 1826 keys = [keys]
1827 1827 content = dict(query=query, keys=keys)
1828 1828 self.session.send(self._query_socket, "db_request", content=content)
1829 1829 idents, msg = self.session.recv(self._query_socket, 0)
1830 1830 if self.debug:
1831 1831 pprint(msg)
1832 1832 content = msg['content']
1833 1833 if content['status'] != 'ok':
1834 1834 raise self._unwrap_exception(content)
1835 1835
1836 1836 records = content['records']
1837 1837
1838 1838 buffer_lens = content['buffer_lens']
1839 1839 result_buffer_lens = content['result_buffer_lens']
1840 1840 buffers = msg['buffers']
1841 1841 has_bufs = buffer_lens is not None
1842 1842 has_rbufs = result_buffer_lens is not None
1843 1843 for i,rec in enumerate(records):
1844 1844 # relink buffers
1845 1845 if has_bufs:
1846 1846 blen = buffer_lens[i]
1847 1847 rec['buffers'], buffers = buffers[:blen],buffers[blen:]
1848 1848 if has_rbufs:
1849 1849 blen = result_buffer_lens[i]
1850 1850 rec['result_buffers'], buffers = buffers[:blen],buffers[blen:]
1851 1851
1852 1852 return records
1853 1853
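A sketch of a typical query (assuming a connected Client `rc`; the record field and `$ne` operator follow the mongodb query conventions referenced above):

    # fetch only two keys from every task that has completed
    recs = rc.db_query({'completed': {'$ne': None}},
                       keys=['msg_id', 'completed'])
    for rec in recs:
        print(rec['msg_id'], rec['completed'])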
1854 1854 __all__ = [ 'Client' ]
@@ -1,369 +1,369 b''
1 1 """some generic utilities for dealing with classes, urls, and serialization
2 2
3 3 Authors:
4 4
5 5 * Min RK
6 6 """
7 7 #-----------------------------------------------------------------------------
8 8 # Copyright (C) 2010-2011 The IPython Development Team
9 9 #
10 10 # Distributed under the terms of the BSD License. The full license is in
11 11 # the file COPYING, distributed as part of this software.
12 12 #-----------------------------------------------------------------------------
13 13
14 14 #-----------------------------------------------------------------------------
15 15 # Imports
16 16 #-----------------------------------------------------------------------------
17 17
18 18 # Standard library imports.
19 19 import logging
20 20 import os
21 21 import re
22 22 import stat
23 23 import socket
24 24 import sys
25 25 from signal import signal, SIGINT, SIGABRT, SIGTERM
26 26 try:
27 27 from signal import SIGKILL
28 28 except ImportError:
29 29 SIGKILL=None
30 30
31 31 try:
32 32 import cPickle
33 33 pickle = cPickle
34 34 except ImportError:
35 35 cPickle = None
36 36 import pickle
37 37
38 38 # System library imports
39 39 import zmq
40 40 from zmq.log import handlers
41 41
42 42 from IPython.external.decorator import decorator
43 43
44 44 # IPython imports
45 45 from IPython.config.application import Application
46 46 from IPython.utils.localinterfaces import localhost, is_public_ip, public_ips
47 47 from IPython.utils.py3compat import string_types
48 48 from IPython.kernel.zmq.log import EnginePUBHandler
49 49 from IPython.kernel.zmq.serialize import (
50 50 unserialize_object, serialize_object, pack_apply_message, unpack_apply_message
51 51 )
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # Classes
55 55 #-----------------------------------------------------------------------------
56 56
57 57 class Namespace(dict):
58 58 """Subclass of dict for attribute access to keys."""
59 59
60 60 def __getattr__(self, key):
61 61 """getattr aliased to getitem"""
62 if key in self.iterkeys():
62 if key in self:
63 63 return self[key]
64 64 else:
65 65 raise NameError(key)
66 66
67 67 def __setattr__(self, key, value):
68 68 """setattr aliased to setitem, with strict"""
69 69 if hasattr(dict, key):
70 70 raise KeyError("Cannot override dict keys %r"%key)
71 71 self[key] = value
72 72
73 73
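For example:

    ns = Namespace(a=1)
    ns.b = 2                        # setattr aliased to setitem
    assert ns.a == 1 and ns['b'] == 2
    ns.keys = 3                     # raises KeyError: would shadow a dict method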
74 74 class ReverseDict(dict):
75 75 """simple double-keyed subset of dict methods."""
76 76
77 77 def __init__(self, *args, **kwargs):
78 78 dict.__init__(self, *args, **kwargs)
79 79 self._reverse = dict()
80 80 for key, value in self.items():
81 81 self._reverse[value] = key
82 82
83 83 def __getitem__(self, key):
84 84 try:
85 85 return dict.__getitem__(self, key)
86 86 except KeyError:
87 87 return self._reverse[key]
88 88
89 89 def __setitem__(self, key, value):
90 90 if key in self._reverse:
91 91 raise KeyError("Can't have key %r on both sides!"%key)
92 92 dict.__setitem__(self, key, value)
93 93 self._reverse[value] = key
94 94
95 95 def pop(self, key):
96 96 value = dict.pop(self, key)
97 97 self._reverse.pop(value)
98 98 return value
99 99
100 100 def get(self, key, default=None):
101 101 try:
102 102 return self[key]
103 103 except KeyError:
104 104 return default
105 105
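For example:

    rd = ReverseDict()
    rd['a'] = 1
    assert rd['a'] == 1 and rd[1] == 'a'   # lookup works from either side
    rd.pop('a')                            # removes both mappings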
106 106 #-----------------------------------------------------------------------------
107 107 # Functions
108 108 #-----------------------------------------------------------------------------
109 109
110 110 @decorator
111 111 def log_errors(f, self, *args, **kwargs):
112 112 """decorator to log unhandled exceptions raised in a method.
113 113
114 114 For use wrapping on_recv callbacks, so that exceptions
115 115 do not cause the stream to be closed.
116 116 """
117 117 try:
118 118 return f(self, *args, **kwargs)
119 119 except Exception:
120 120 self.log.error("Uncaught exception in %r" % f, exc_info=True)
121 121
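A sketch of the intended use (the `Handler` class and its `on_recv` callback are illustrative):

    class Handler(object):
        def __init__(self):
            self.log = logging.getLogger('handler')

        @log_errors
        def on_recv(self, msg):
            # an exception raised here is logged via self.log
            # instead of propagating and closing the stream
            raise ValueError(msg)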
122 122
123 123 def is_url(url):
124 124 """boolean check for whether a string is a zmq url"""
125 125 if '://' not in url:
126 126 return False
127 127 proto, addr = url.split('://', 1)
128 128 if proto.lower() not in ['tcp','pgm','epgm','ipc','inproc']:
129 129 return False
130 130 return True
131 131
132 132 def validate_url(url):
133 133 """validate a url for zeromq"""
134 134 if not isinstance(url, string_types):
135 135 raise TypeError("url must be a string, not %r"%type(url))
136 136 url = url.lower()
137 137
138 138 proto_addr = url.split('://')
139 139 assert len(proto_addr) == 2, 'Invalid url: %r'%url
140 140 proto, addr = proto_addr
141 141 assert proto in ['tcp','pgm','epgm','ipc','inproc'], "Invalid protocol: %r"%proto
142 142
143 143 # domain pattern adapted from http://www.regexlib.com/REDetails.aspx?regexp_id=391
144 144 # author: Remi Sabourin
145 145 pat = re.compile(r'^([\w\d]([\w\d\-]{0,61}[\w\d])?\.)*[\w\d]([\w\d\-]{0,61}[\w\d])?$')
146 146
147 147 if proto == 'tcp':
148 148 lis = addr.split(':')
149 149 assert len(lis) == 2, 'Invalid url: %r'%url
150 150 addr,s_port = lis
151 151 try:
152 152 port = int(s_port)
153 153 except ValueError:
154 154 raise AssertionError("Invalid port %r in url: %r"%(s_port, url))
155 155
156 156 assert addr == '*' or pat.match(addr) is not None, 'Invalid url: %r'%url
157 157
158 158 else:
159 159 # only validate tcp urls currently
160 160 pass
161 161
162 162 return True
163 163
164 164
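For example:

    assert is_url('tcp://127.0.0.1:10101')
    assert not is_url('/tmp/not-a-url')
    validate_url('tcp://localhost:5555')   # returns True
    validate_url('tcp://localhost')        # raises AssertionError: missing port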
165 165 def validate_url_container(container):
166 166 """validate a potentially nested collection of urls."""
167 167 if isinstance(container, string_types):
168 168 url = container
169 169 return validate_url(url)
170 170 elif isinstance(container, dict):
171 171 container = container.values()
172 172
173 173 for element in container:
174 174 validate_url_container(element)
175 175
176 176
177 177 def split_url(url):
178 178 """split a zmq url (tcp://ip:port) into ('tcp','ip','port')."""
179 179 proto_addr = url.split('://')
180 180 assert len(proto_addr) == 2, 'Invalid url: %r'%url
181 181 proto, addr = proto_addr
182 182 lis = addr.split(':')
183 183 assert len(lis) == 2, 'Invalid url: %r'%url
184 184 addr,s_port = lis
185 185 return proto,addr,s_port
186 186
187 187 def disambiguate_ip_address(ip, location=None):
188 188 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
189 189 ones, based on the location (default interpretation of location is localhost)."""
190 190 if ip in ('0.0.0.0', '*'):
191 191 if location is None or is_public_ip(location) or not public_ips():
192 192 # If location is unspecified or cannot be determined, assume local
193 193 ip = localhost()
194 194 elif location:
195 195 return location
196 196 return ip
197 197
198 198 def disambiguate_url(url, location=None):
199 199 """turn multi-ip interfaces '0.0.0.0' and '*' into connectable
200 200 ones, based on the location (default interpretation is localhost).
201 201
202 202 This is for zeromq urls, such as tcp://*:10101."""
203 203 try:
204 204 proto,ip,port = split_url(url)
205 205 except AssertionError:
206 206 # probably not tcp url; could be ipc, etc.
207 207 return url
208 208
209 209 ip = disambiguate_ip_address(ip,location)
210 210
211 211 return "%s://%s:%s"%(proto,ip,port)
212 212
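For example:

    disambiguate_url('tcp://*:10101')           # e.g. 'tcp://127.0.0.1:10101'
    disambiguate_url('ipc:///tmp/engine.sock')  # non-tcp urls pass through unchanged

Passing a `location` ip lets a remote client substitute the machine's real address for the wildcard instead of localhost.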
213 213
214 214 #--------------------------------------------------------------------------
215 215 # helpers for implementing old MEC API via view.apply
216 216 #--------------------------------------------------------------------------
217 217
218 218 def interactive(f):
219 219 """decorator for making functions appear as interactively defined.
220 220 This results in the function being linked to the user_ns as globals()
221 221 instead of the module globals().
222 222 """
223 223 f.__module__ = '__main__'
224 224 return f
225 225
226 226 @interactive
227 227 def _push(**ns):
228 228 """helper method for implementing `client.push` via `client.apply`"""
229 229 user_ns = globals()
230 230 tmp = '_IP_PUSH_TMP_'
231 231 while tmp in user_ns:
232 232 tmp = tmp + '_'
233 233 try:
234 234 for name, value in ns.items():
235 235 user_ns[tmp] = value
236 236 exec("%s = %s" % (name, tmp), user_ns)
237 237 finally:
238 238 user_ns.pop(tmp, None)
239 239
240 240 @interactive
241 241 def _pull(keys):
242 242 """helper method for implementing `client.pull` via `client.apply`"""
243 243 if isinstance(keys, (list,tuple, set)):
244 244 return [eval(key, globals()) for key in keys]
245 245 else:
246 246 return eval(keys, globals())
247 247
248 248 @interactive
249 249 def _execute(code):
250 250 """helper method for implementing `client.execute` via `client.apply`"""
251 251 exec(code, globals())
252 252
253 253 #--------------------------------------------------------------------------
254 254 # extra process management utilities
255 255 #--------------------------------------------------------------------------
256 256
257 257 _random_ports = set()
258 258
259 259 def select_random_ports(n):
260 260 """Selects and return n random ports that are available."""
261 261 ports = []
262 262 for i in range(n):
263 263 sock = socket.socket()
264 264 sock.bind(('', 0))
265 265 while sock.getsockname()[1] in _random_ports:
266 266 sock.close()
267 267 sock = socket.socket()
268 268 sock.bind(('', 0))
269 269 ports.append(sock)
270 270 for i, sock in enumerate(ports):
271 271 port = sock.getsockname()[1]
272 272 sock.close()
273 273 ports[i] = port
274 274 _random_ports.add(port)
275 275 return ports
276 276
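For example (port numbers are illustrative):

    ports = select_random_ports(3)   # e.g. [49211, 50214, 50037]
    assert len(set(ports)) == 3      # distinct, and never handed out before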
277 277 def signal_children(children):
278 278 """Relay interupt/term signals to children, for more solid process cleanup."""
279 279 def terminate_children(sig, frame):
280 280 log = Application.instance().log
281 281 log.critical("Got signal %i, terminating children..."%sig)
282 282 for child in children:
283 283 child.terminate()
284 284
285 285 sys.exit(sig != SIGINT)
286 286 # sys.exit(sig)
287 287 for sig in (SIGINT, SIGABRT, SIGTERM):
288 288 signal(sig, terminate_children)
289 289
290 290 def generate_exec_key(keyfile):
291 291 import uuid
292 292 newkey = str(uuid.uuid4())
293 293 with open(keyfile, 'w') as f:
294 294 # f.write('ipython-key ')
295 295 f.write(newkey+'\n')
296 296 # set user-only RW permissions (0600)
297 297 # this will have no effect on Windows
298 298 os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR)
299 299
300 300
301 301 def integer_loglevel(loglevel):
302 302 try:
303 303 loglevel = int(loglevel)
304 304 except ValueError:
305 305 if isinstance(loglevel, str):
306 306 loglevel = getattr(logging, loglevel)
307 307 return loglevel
308 308
309 309 def connect_logger(logname, context, iface, root="ip", loglevel=logging.DEBUG):
310 310 logger = logging.getLogger(logname)
311 311 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
312 312 # don't add a second PUBHandler
313 313 return
314 314 loglevel = integer_loglevel(loglevel)
315 315 lsock = context.socket(zmq.PUB)
316 316 lsock.connect(iface)
317 317 handler = handlers.PUBHandler(lsock)
318 318 handler.setLevel(loglevel)
319 319 handler.root_topic = root
320 320 logger.addHandler(handler)
321 321 logger.setLevel(loglevel)
322 322
323 323 def connect_engine_logger(context, iface, engine, loglevel=logging.DEBUG):
324 324 logger = logging.getLogger()
325 325 if any([isinstance(h, handlers.PUBHandler) for h in logger.handlers]):
326 326 # don't add a second PUBHandler
327 327 return
328 328 loglevel = integer_loglevel(loglevel)
329 329 lsock = context.socket(zmq.PUB)
330 330 lsock.connect(iface)
331 331 handler = EnginePUBHandler(engine, lsock)
332 332 handler.setLevel(loglevel)
333 333 logger.addHandler(handler)
334 334 logger.setLevel(loglevel)
335 335 return logger
336 336
337 337 def local_logger(logname, loglevel=logging.DEBUG):
338 338 loglevel = integer_loglevel(loglevel)
339 339 logger = logging.getLogger(logname)
340 340 if any([isinstance(h, logging.StreamHandler) for h in logger.handlers]):
341 341 # don't add a second StreamHandler
342 342 return
343 343 handler = logging.StreamHandler()
344 344 handler.setLevel(loglevel)
345 345 formatter = logging.Formatter("%(asctime)s.%(msecs).03d [%(name)s] %(message)s",
346 346 datefmt="%Y-%m-%d %H:%M:%S")
347 347 handler.setFormatter(formatter)
348 348
349 349 logger.addHandler(handler)
350 350 logger.setLevel(loglevel)
351 351 return logger
352 352
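For example, either a level name or an integer works, since `integer_loglevel` resolves names:

    log = local_logger('myapp', 'INFO')
    log.info('engine ready')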
353 353 def set_hwm(sock, hwm=0):
354 354 """set zmq High Water Mark on a socket
355 355
356 356 in a way that always works for various pyzmq / libzmq versions.
357 357 """
358 358 import zmq
359 359
360 360 for key in ('HWM', 'SNDHWM', 'RCVHWM'):
361 361 opt = getattr(zmq, key, None)
362 362 if opt is None:
363 363 continue
364 364 try:
365 365 sock.setsockopt(opt, hwm)
366 366 except zmq.ZMQError:
367 367 pass
368 368
369 369 No newline at end of file
@@ -1,229 +1,229 b''
1 1 """Utilities to manipulate JSON objects.
2 2 """
3 3 #-----------------------------------------------------------------------------
4 4 # Copyright (C) 2010-2011 The IPython Development Team
5 5 #
6 6 # Distributed under the terms of the BSD License. The full license is in
7 7 # the file COPYING.txt, distributed as part of this software.
8 8 #-----------------------------------------------------------------------------
9 9
10 10 #-----------------------------------------------------------------------------
11 11 # Imports
12 12 #-----------------------------------------------------------------------------
13 13 # stdlib
14 14 import math
15 15 import re
16 16 import types
17 17 from datetime import datetime
18 18
19 19 try:
20 20 # base64.encodestring is deprecated in Python 3.x
21 21 from base64 import encodebytes
22 22 except ImportError:
23 23 # Python 2.x
24 24 from base64 import encodestring as encodebytes
25 25
26 26 from IPython.utils import py3compat
27 27 from IPython.utils.py3compat import string_types, unicode_type
28 28 from IPython.utils.encoding import DEFAULT_ENCODING
29 29 next_attr_name = '__next__' if py3compat.PY3 else 'next'
30 30
31 31 #-----------------------------------------------------------------------------
32 32 # Globals and constants
33 33 #-----------------------------------------------------------------------------
34 34
35 35 # timestamp formats
36 36 ISO8601="%Y-%m-%dT%H:%M:%S.%f"
37 37 ISO8601_PAT=re.compile(r"^(\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+)Z?([\+\-]\d{2}:?\d{2})?$")
38 38
39 39 #-----------------------------------------------------------------------------
40 40 # Classes and functions
41 41 #-----------------------------------------------------------------------------
42 42
43 43 def rekey(dikt):
44 44 """Rekey a dict that has been forced to use str keys where there should be
45 45 ints by json."""
46 for k in dikt.iterkeys():
46 for k in list(dikt): # copy keys: the loop body mutates dikt
47 47 if isinstance(k, string_types):
48 48 ik=fk=None
49 49 try:
50 50 ik = int(k)
51 51 except ValueError:
52 52 try:
53 53 fk = float(k)
54 54 except ValueError:
55 55 continue
56 56 if ik is not None:
57 57 nk = ik
58 58 else:
59 59 nk = fk
60 60 if nk in dikt:
61 61 raise KeyError("already have key %r"%nk)
62 62 dikt[nk] = dikt.pop(k)
63 63 return dikt
64 64
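For example, round-tripping a dict with int keys through JSON:

    import json
    d = {0: 'engine-a', 1: 'engine-b'}
    loaded = json.loads(json.dumps(d))   # keys are now '0' and '1'
    assert rekey(loaded) == d            # keys are ints again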
65 65
66 66 def extract_dates(obj):
67 67 """extract ISO8601 dates from unpacked JSON"""
68 68 if isinstance(obj, dict):
69 69 obj = dict(obj) # don't clobber
70 70 for k,v in obj.items():
71 71 obj[k] = extract_dates(v)
72 72 elif isinstance(obj, (list, tuple)):
73 73 obj = [ extract_dates(o) for o in obj ]
74 74 elif isinstance(obj, string_types):
75 75 m = ISO8601_PAT.match(obj)
76 76 if m:
77 77 # FIXME: add actual timezone support
78 78 # this just drops the timezone info
79 79 notz = m.groups()[0]
80 80 obj = datetime.strptime(notz, ISO8601)
81 81 return obj
82 82
83 83 def squash_dates(obj):
84 84 """squash datetime objects into ISO8601 strings"""
85 85 if isinstance(obj, dict):
86 86 obj = dict(obj) # don't clobber
87 87 for k,v in obj.items():
88 88 obj[k] = squash_dates(v)
89 89 elif isinstance(obj, (list, tuple)):
90 90 obj = [ squash_dates(o) for o in obj ]
91 91 elif isinstance(obj, datetime):
92 92 obj = obj.isoformat()
93 93 return obj
94 94
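For example, the two functions are inverses for datetime values:

    from datetime import datetime
    d = {'started': datetime(2013, 5, 1, 12, 30, 0, 1000)}
    flat = squash_dates(d)            # {'started': '2013-05-01T12:30:00.001000'}
    assert extract_dates(flat) == d   # ISO8601 strings become datetimes again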
95 95 def date_default(obj):
96 96 """default function for packing datetime objects in JSON."""
97 97 if isinstance(obj, datetime):
98 98 return obj.isoformat()
99 99 else:
100 100 raise TypeError("%r is not JSON serializable"%obj)
101 101
102 102
103 103 # constants for identifying png/jpeg data
104 104 PNG = b'\x89PNG\r\n\x1a\n'
105 105 # front of PNG base64-encoded
106 106 PNG64 = b'iVBORw0KG'
107 107 JPEG = b'\xff\xd8'
108 108 # front of JPEG base64-encoded
109 109 JPEG64 = b'/9'
110 110
111 111 def encode_images(format_dict):
112 112 """b64-encodes images in a displaypub format dict
113 113
114 114 Perhaps this should be handled in json_clean itself?
115 115
116 116 Parameters
117 117 ----------
118 118
119 119 format_dict : dict
120 120 A dictionary of display data keyed by mime-type
121 121
122 122 Returns
123 123 -------
124 124
125 125 format_dict : dict
126 126 A copy of the same dictionary,
127 127 but binary image data ('image/png' or 'image/jpeg')
128 128 is base64-encoded.
129 129
130 130 """
131 131 encoded = format_dict.copy()
132 132
133 133 pngdata = format_dict.get('image/png')
134 134 if isinstance(pngdata, bytes):
135 135 # make sure we don't double-encode
136 136 if not pngdata.startswith(PNG64):
137 137 pngdata = encodebytes(pngdata)
138 138 encoded['image/png'] = pngdata.decode('ascii')
139 139
140 140 jpegdata = format_dict.get('image/jpeg')
141 141 if isinstance(jpegdata, bytes):
142 142 # make sure we don't double-encode
143 143 if not jpegdata.startswith(JPEG64):
144 144 jpegdata = encodebytes(jpegdata)
145 145 encoded['image/jpeg'] = jpegdata.decode('ascii')
146 146
147 147 return encoded
148 148
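For example (a sketch with a fake payload; `PNG` is the magic-bytes constant defined above):

    raw = PNG + b'\x00' * 8
    fmt = {'image/png': raw, 'text/plain': '<figure>'}
    out = encode_images(fmt)
    assert out['image/png'].startswith('iVBORw0KG')   # binary data is now base64 text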
149 149
150 150 def json_clean(obj):
151 151 """Clean an object to ensure it's safe to encode in JSON.
152 152
153 153 Atomic, immutable objects are returned unmodified. Sets and tuples are
154 154 converted to lists, lists are copied and dicts are also copied.
155 155
156 156 Note: dicts whose keys could cause collisions upon encoding (such as a dict
157 157 with both the number 1 and the string '1' as keys) will cause a ValueError
158 158 to be raised.
159 159
160 160 Parameters
161 161 ----------
162 162 obj : any python object
163 163
164 164 Returns
165 165 -------
166 166 out : object
167 167
168 168 A version of the input which will not cause an encoding error when
169 169 encoded as JSON. Note that this function does not *encode* its inputs,
170 170 it simply sanitizes it so that there will be no encoding errors later.
171 171
172 172 Examples
173 173 --------
174 174 >>> json_clean(4)
175 175 4
176 176 >>> json_clean(range(10))
177 177 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
178 178 >>> sorted(json_clean(dict(x=1, y=2)).items())
179 179 [('x', 1), ('y', 2)]
180 180 >>> sorted(json_clean(dict(x=1, y=2, z=[1,2,3])).items())
181 181 [('x', 1), ('y', 2), ('z', [1, 2, 3])]
182 182 >>> json_clean(True)
183 183 True
184 184 """
185 185 # types that are 'atomic' and ok in json as-is. bool doesn't need to be
186 186 # listed explicitly because bools pass as int instances
187 187 atomic_ok = (unicode_type, int, type(None)) # types.NoneType does not exist on Python 3
188 188
189 189 # containers that we need to convert into lists
190 190 container_to_list = (tuple, set, types.GeneratorType)
191 191
192 192 if isinstance(obj, float):
193 193 # cast out-of-range floats to their reprs
194 194 if math.isnan(obj) or math.isinf(obj):
195 195 return repr(obj)
196 196 return obj
197 197
198 198 if isinstance(obj, atomic_ok):
199 199 return obj
200 200
201 201 if isinstance(obj, bytes):
202 202 return obj.decode(DEFAULT_ENCODING, 'replace')
203 203
204 204 if isinstance(obj, container_to_list) or (
205 205 hasattr(obj, '__iter__') and hasattr(obj, next_attr_name)):
206 206 obj = list(obj)
207 207
208 208 if isinstance(obj, list):
209 209 return [json_clean(x) for x in obj]
210 210
211 211 if isinstance(obj, dict):
212 212 # First, validate that the dict won't lose data in conversion due to
213 213 # key collisions after stringification. This can happen with keys like
214 214 # True and 'true' or 1 and '1', which collide in JSON.
215 215 nkeys = len(obj)
216 216 nkeys_collapsed = len(set(map(str, obj)))
217 217 if nkeys != nkeys_collapsed:
218 218 raise ValueError('dict can not be safely converted to JSON: '
219 219 'key collision would lead to dropped values')
220 220 # If all OK, proceed by making the new dict that will be json-safe
221 221 out = {}
222 222 for k,v in obj.items():
223 223 out[str(k)] = json_clean(v)
224 224 return out
225 225
226 226 # If we get here, we don't know how to handle the object, so we just get
227 227 # its repr and return that. This will catch lambdas, open sockets, class
228 228 # objects, and any other complicated contraption that json can't encode
229 229 return repr(obj)