##// END OF EJS Templates
add signature_scheme to Session and connection files...
MinRK -
Show More
@@ -1,390 +1,392 b''
1 """ A minimal application base mixin for all ZMQ based IPython frontends.
1 """ A minimal application base mixin for all ZMQ based IPython frontends.
2
2
3 This is not a complete console app, as subprocess will not be able to receive
3 This is not a complete console app, as subprocess will not be able to receive
4 input, there is no real readline support, among other limitations. This is a
4 input, there is no real readline support, among other limitations. This is a
5 refactoring of what used to be the IPython/qt/console/qtconsoleapp.py
5 refactoring of what used to be the IPython/qt/console/qtconsoleapp.py
6
6
7 Authors:
7 Authors:
8
8
9 * Evan Patterson
9 * Evan Patterson
10 * Min RK
10 * Min RK
11 * Erik Tollerud
11 * Erik Tollerud
12 * Fernando Perez
12 * Fernando Perez
13 * Bussonnier Matthias
13 * Bussonnier Matthias
14 * Thomas Kluyver
14 * Thomas Kluyver
15 * Paul Ivanov
15 * Paul Ivanov
16
16
17 """
17 """
18
18
19 #-----------------------------------------------------------------------------
19 #-----------------------------------------------------------------------------
20 # Imports
20 # Imports
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 # stdlib imports
23 # stdlib imports
24 import atexit
24 import atexit
25 import json
25 import json
26 import os
26 import os
27 import shutil
27 import shutil
28 import signal
28 import signal
29 import sys
29 import sys
30 import uuid
30 import uuid
31
31
32
32
33 # Local imports
33 # Local imports
34 from IPython.config.application import boolean_flag
34 from IPython.config.application import boolean_flag
35 from IPython.config.configurable import Configurable
35 from IPython.config.configurable import Configurable
36 from IPython.core.profiledir import ProfileDir
36 from IPython.core.profiledir import ProfileDir
37 from IPython.kernel.blocking import BlockingKernelClient
37 from IPython.kernel.blocking import BlockingKernelClient
38 from IPython.kernel import KernelManager
38 from IPython.kernel import KernelManager
39 from IPython.kernel import tunnel_to_kernel, find_connection_file, swallow_argv
39 from IPython.kernel import tunnel_to_kernel, find_connection_file, swallow_argv
40 from IPython.utils.path import filefind
40 from IPython.utils.path import filefind
41 from IPython.utils.py3compat import str_to_bytes
41 from IPython.utils.py3compat import str_to_bytes
42 from IPython.utils.traitlets import (
42 from IPython.utils.traitlets import (
43 Dict, List, Unicode, CUnicode, Int, CBool, Any, CaselessStrEnum
43 Dict, List, Unicode, CUnicode, Int, CBool, Any, CaselessStrEnum
44 )
44 )
45 from IPython.kernel.zmq.kernelapp import (
45 from IPython.kernel.zmq.kernelapp import (
46 kernel_flags,
46 kernel_flags,
47 kernel_aliases,
47 kernel_aliases,
48 IPKernelApp
48 IPKernelApp
49 )
49 )
50 from IPython.kernel.zmq.session import Session, default_secure
50 from IPython.kernel.zmq.session import Session, default_secure
51 from IPython.kernel.zmq.zmqshell import ZMQInteractiveShell
51 from IPython.kernel.zmq.zmqshell import ZMQInteractiveShell
52
52
53 #-----------------------------------------------------------------------------
53 #-----------------------------------------------------------------------------
54 # Network Constants
54 # Network Constants
55 #-----------------------------------------------------------------------------
55 #-----------------------------------------------------------------------------
56
56
57 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
57 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
58
58
59 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
60 # Globals
60 # Globals
61 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
62
62
63
63
64 #-----------------------------------------------------------------------------
64 #-----------------------------------------------------------------------------
65 # Aliases and Flags
65 # Aliases and Flags
66 #-----------------------------------------------------------------------------
66 #-----------------------------------------------------------------------------
67
67
68 flags = dict(kernel_flags)
68 flags = dict(kernel_flags)
69
69
70 # the flags that are specific to the frontend
70 # the flags that are specific to the frontend
71 # these must be scrubbed before being passed to the kernel,
71 # these must be scrubbed before being passed to the kernel,
72 # or it will raise an error on unrecognized flags
72 # or it will raise an error on unrecognized flags
73 app_flags = {
73 app_flags = {
74 'existing' : ({'IPythonConsoleApp' : {'existing' : 'kernel*.json'}},
74 'existing' : ({'IPythonConsoleApp' : {'existing' : 'kernel*.json'}},
75 "Connect to an existing kernel. If no argument specified, guess most recent"),
75 "Connect to an existing kernel. If no argument specified, guess most recent"),
76 }
76 }
77 app_flags.update(boolean_flag(
77 app_flags.update(boolean_flag(
78 'confirm-exit', 'IPythonConsoleApp.confirm_exit',
78 'confirm-exit', 'IPythonConsoleApp.confirm_exit',
79 """Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
79 """Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
80 to force a direct exit without any confirmation.
80 to force a direct exit without any confirmation.
81 """,
81 """,
82 """Don't prompt the user when exiting. This will terminate the kernel
82 """Don't prompt the user when exiting. This will terminate the kernel
83 if it is owned by the frontend, and leave it alive if it is external.
83 if it is owned by the frontend, and leave it alive if it is external.
84 """
84 """
85 ))
85 ))
86 flags.update(app_flags)
86 flags.update(app_flags)
87
87
88 aliases = dict(kernel_aliases)
88 aliases = dict(kernel_aliases)
89
89
90 # also scrub aliases from the frontend
90 # also scrub aliases from the frontend
91 app_aliases = dict(
91 app_aliases = dict(
92 ip = 'KernelManager.ip',
92 ip = 'KernelManager.ip',
93 transport = 'KernelManager.transport',
93 transport = 'KernelManager.transport',
94 hb = 'IPythonConsoleApp.hb_port',
94 hb = 'IPythonConsoleApp.hb_port',
95 shell = 'IPythonConsoleApp.shell_port',
95 shell = 'IPythonConsoleApp.shell_port',
96 iopub = 'IPythonConsoleApp.iopub_port',
96 iopub = 'IPythonConsoleApp.iopub_port',
97 stdin = 'IPythonConsoleApp.stdin_port',
97 stdin = 'IPythonConsoleApp.stdin_port',
98 existing = 'IPythonConsoleApp.existing',
98 existing = 'IPythonConsoleApp.existing',
99 f = 'IPythonConsoleApp.connection_file',
99 f = 'IPythonConsoleApp.connection_file',
100
100
101
101
102 ssh = 'IPythonConsoleApp.sshserver',
102 ssh = 'IPythonConsoleApp.sshserver',
103 )
103 )
104 aliases.update(app_aliases)
104 aliases.update(app_aliases)
105
105
106 #-----------------------------------------------------------------------------
106 #-----------------------------------------------------------------------------
107 # Classes
107 # Classes
108 #-----------------------------------------------------------------------------
108 #-----------------------------------------------------------------------------
109
109
110 #-----------------------------------------------------------------------------
110 #-----------------------------------------------------------------------------
111 # IPythonConsole
111 # IPythonConsole
112 #-----------------------------------------------------------------------------
112 #-----------------------------------------------------------------------------
113
113
114 classes = [IPKernelApp, ZMQInteractiveShell, KernelManager, ProfileDir, Session]
114 classes = [IPKernelApp, ZMQInteractiveShell, KernelManager, ProfileDir, Session]
115
115
116 try:
116 try:
117 from IPython.kernel.zmq.pylab.backend_inline import InlineBackend
117 from IPython.kernel.zmq.pylab.backend_inline import InlineBackend
118 except ImportError:
118 except ImportError:
119 pass
119 pass
120 else:
120 else:
121 classes.append(InlineBackend)
121 classes.append(InlineBackend)
122
122
123 class IPythonConsoleApp(Configurable):
123 class IPythonConsoleApp(Configurable):
124 name = 'ipython-console-mixin'
124 name = 'ipython-console-mixin'
125
125
126 description = """
126 description = """
127 The IPython Mixin Console.
127 The IPython Mixin Console.
128
128
129 This class contains the common portions of console client (QtConsole,
129 This class contains the common portions of console client (QtConsole,
130 ZMQ-based terminal console, etc). It is not a full console, in that
130 ZMQ-based terminal console, etc). It is not a full console, in that
131 launched terminal subprocesses will not be able to accept input.
131 launched terminal subprocesses will not be able to accept input.
132
132
133 The Console using this mixing supports various extra features beyond
133 The Console using this mixing supports various extra features beyond
134 the single-process Terminal IPython shell, such as connecting to
134 the single-process Terminal IPython shell, such as connecting to
135 existing kernel, via:
135 existing kernel, via:
136
136
137 ipython <appname> --existing
137 ipython <appname> --existing
138
138
139 as well as tunnel via SSH
139 as well as tunnel via SSH
140
140
141 """
141 """
142
142
143 classes = classes
143 classes = classes
144 flags = Dict(flags)
144 flags = Dict(flags)
145 aliases = Dict(aliases)
145 aliases = Dict(aliases)
146 kernel_manager_class = KernelManager
146 kernel_manager_class = KernelManager
147 kernel_client_class = BlockingKernelClient
147 kernel_client_class = BlockingKernelClient
148
148
149 kernel_argv = List(Unicode)
149 kernel_argv = List(Unicode)
150 # frontend flags&aliases to be stripped when building kernel_argv
150 # frontend flags&aliases to be stripped when building kernel_argv
151 frontend_flags = Any(app_flags)
151 frontend_flags = Any(app_flags)
152 frontend_aliases = Any(app_aliases)
152 frontend_aliases = Any(app_aliases)
153
153
154 # create requested profiles by default, if they don't exist:
154 # create requested profiles by default, if they don't exist:
155 auto_create = CBool(True)
155 auto_create = CBool(True)
156 # connection info:
156 # connection info:
157
157
158 sshserver = Unicode('', config=True,
158 sshserver = Unicode('', config=True,
159 help="""The SSH server to use to connect to the kernel.""")
159 help="""The SSH server to use to connect to the kernel.""")
160 sshkey = Unicode('', config=True,
160 sshkey = Unicode('', config=True,
161 help="""Path to the ssh key to use for logging in to the ssh server.""")
161 help="""Path to the ssh key to use for logging in to the ssh server.""")
162
162
163 hb_port = Int(0, config=True,
163 hb_port = Int(0, config=True,
164 help="set the heartbeat port [default: random]")
164 help="set the heartbeat port [default: random]")
165 shell_port = Int(0, config=True,
165 shell_port = Int(0, config=True,
166 help="set the shell (ROUTER) port [default: random]")
166 help="set the shell (ROUTER) port [default: random]")
167 iopub_port = Int(0, config=True,
167 iopub_port = Int(0, config=True,
168 help="set the iopub (PUB) port [default: random]")
168 help="set the iopub (PUB) port [default: random]")
169 stdin_port = Int(0, config=True,
169 stdin_port = Int(0, config=True,
170 help="set the stdin (DEALER) port [default: random]")
170 help="set the stdin (DEALER) port [default: random]")
171 connection_file = Unicode('', config=True,
171 connection_file = Unicode('', config=True,
172 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
172 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
173
173
174 This file will contain the IP, ports, and authentication key needed to connect
174 This file will contain the IP, ports, and authentication key needed to connect
175 clients to this kernel. By default, this file will be created in the security-dir
175 clients to this kernel. By default, this file will be created in the security-dir
176 of the current profile, but can be specified by absolute path.
176 of the current profile, but can be specified by absolute path.
177 """)
177 """)
178 def _connection_file_default(self):
178 def _connection_file_default(self):
179 return 'kernel-%i.json' % os.getpid()
179 return 'kernel-%i.json' % os.getpid()
180
180
181 existing = CUnicode('', config=True,
181 existing = CUnicode('', config=True,
182 help="""Connect to an already running kernel""")
182 help="""Connect to an already running kernel""")
183
183
184 confirm_exit = CBool(True, config=True,
184 confirm_exit = CBool(True, config=True,
185 help="""
185 help="""
186 Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
186 Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
187 to force a direct exit without any confirmation.""",
187 to force a direct exit without any confirmation.""",
188 )
188 )
189
189
190
190
191 def build_kernel_argv(self, argv=None):
191 def build_kernel_argv(self, argv=None):
192 """build argv to be passed to kernel subprocess"""
192 """build argv to be passed to kernel subprocess"""
193 if argv is None:
193 if argv is None:
194 argv = sys.argv[1:]
194 argv = sys.argv[1:]
195 self.kernel_argv = swallow_argv(argv, self.frontend_aliases, self.frontend_flags)
195 self.kernel_argv = swallow_argv(argv, self.frontend_aliases, self.frontend_flags)
196 # kernel should inherit default config file from frontend
196 # kernel should inherit default config file from frontend
197 self.kernel_argv.append("--IPKernelApp.parent_appname='%s'" % self.name)
197 self.kernel_argv.append("--IPKernelApp.parent_appname='%s'" % self.name)
198
198
199 def init_connection_file(self):
199 def init_connection_file(self):
200 """find the connection file, and load the info if found.
200 """find the connection file, and load the info if found.
201
201
202 The current working directory and the current profile's security
202 The current working directory and the current profile's security
203 directory will be searched for the file if it is not given by
203 directory will be searched for the file if it is not given by
204 absolute path.
204 absolute path.
205
205
206 When attempting to connect to an existing kernel and the `--existing`
206 When attempting to connect to an existing kernel and the `--existing`
207 argument does not match an existing file, it will be interpreted as a
207 argument does not match an existing file, it will be interpreted as a
208 fileglob, and the matching file in the current profile's security dir
208 fileglob, and the matching file in the current profile's security dir
209 with the latest access time will be used.
209 with the latest access time will be used.
210
210
211 After this method is called, self.connection_file contains the *full path*
211 After this method is called, self.connection_file contains the *full path*
212 to the connection file, never just its name.
212 to the connection file, never just its name.
213 """
213 """
214 if self.existing:
214 if self.existing:
215 try:
215 try:
216 cf = find_connection_file(self.existing)
216 cf = find_connection_file(self.existing)
217 except Exception:
217 except Exception:
218 self.log.critical("Could not find existing kernel connection file %s", self.existing)
218 self.log.critical("Could not find existing kernel connection file %s", self.existing)
219 self.exit(1)
219 self.exit(1)
220 self.log.info("Connecting to existing kernel: %s" % cf)
220 self.log.info("Connecting to existing kernel: %s" % cf)
221 self.connection_file = cf
221 self.connection_file = cf
222 else:
222 else:
223 # not existing, check if we are going to write the file
223 # not existing, check if we are going to write the file
224 # and ensure that self.connection_file is a full path, not just the shortname
224 # and ensure that self.connection_file is a full path, not just the shortname
225 try:
225 try:
226 cf = find_connection_file(self.connection_file)
226 cf = find_connection_file(self.connection_file)
227 except Exception:
227 except Exception:
228 # file might not exist
228 # file might not exist
229 if self.connection_file == os.path.basename(self.connection_file):
229 if self.connection_file == os.path.basename(self.connection_file):
230 # just shortname, put it in security dir
230 # just shortname, put it in security dir
231 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
231 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
232 else:
232 else:
233 cf = self.connection_file
233 cf = self.connection_file
234 self.connection_file = cf
234 self.connection_file = cf
235
235
236 # should load_connection_file only be used for existing?
236 # should load_connection_file only be used for existing?
237 # as it is now, this allows reusing ports if an existing
237 # as it is now, this allows reusing ports if an existing
238 # file is requested
238 # file is requested
239 try:
239 try:
240 self.load_connection_file()
240 self.load_connection_file()
241 except Exception:
241 except Exception:
242 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
242 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
243 self.exit(1)
243 self.exit(1)
244
244
245 def load_connection_file(self):
245 def load_connection_file(self):
246 """load ip/port/hmac config from JSON connection file"""
246 """load ip/port/hmac config from JSON connection file"""
247 # this is identical to IPKernelApp.load_connection_file
247 # this is identical to IPKernelApp.load_connection_file
248 # perhaps it can be centralized somewhere?
248 # perhaps it can be centralized somewhere?
249 try:
249 try:
250 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
250 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
251 except IOError:
251 except IOError:
252 self.log.debug("Connection File not found: %s", self.connection_file)
252 self.log.debug("Connection File not found: %s", self.connection_file)
253 return
253 return
254 self.log.debug(u"Loading connection file %s", fname)
254 self.log.debug(u"Loading connection file %s", fname)
255 with open(fname) as f:
255 with open(fname) as f:
256 cfg = json.load(f)
256 cfg = json.load(f)
257
257
258 self.config.KernelManager.transport = cfg.get('transport', 'tcp')
258 self.config.KernelManager.transport = cfg.get('transport', 'tcp')
259 self.config.KernelManager.ip = cfg.get('ip', LOCALHOST)
259 self.config.KernelManager.ip = cfg.get('ip', LOCALHOST)
260
260
261 for channel in ('hb', 'shell', 'iopub', 'stdin'):
261 for channel in ('hb', 'shell', 'iopub', 'stdin'):
262 name = channel + '_port'
262 name = channel + '_port'
263 if getattr(self, name) == 0 and name in cfg:
263 if getattr(self, name) == 0 and name in cfg:
264 # not overridden by config or cl_args
264 # not overridden by config or cl_args
265 setattr(self, name, cfg[name])
265 setattr(self, name, cfg[name])
266 if 'key' in cfg:
266 if 'key' in cfg:
267 self.config.Session.key = str_to_bytes(cfg['key'])
267 self.config.Session.key = str_to_bytes(cfg['key'])
268 if 'signature_scheme' in cfg:
269 self.config.Session.signature_scheme = cfg['signature_scheme']
268
270
269 def init_ssh(self):
271 def init_ssh(self):
270 """set up ssh tunnels, if needed."""
272 """set up ssh tunnels, if needed."""
271 if not self.existing or (not self.sshserver and not self.sshkey):
273 if not self.existing or (not self.sshserver and not self.sshkey):
272 return
274 return
273
275
274 self.load_connection_file()
276 self.load_connection_file()
275
277
276 transport = self.config.KernelManager.transport
278 transport = self.config.KernelManager.transport
277 ip = self.config.KernelManager.ip
279 ip = self.config.KernelManager.ip
278
280
279 if transport != 'tcp':
281 if transport != 'tcp':
280 self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
282 self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
281 sys.exit(-1)
283 sys.exit(-1)
282
284
283 if self.sshkey and not self.sshserver:
285 if self.sshkey and not self.sshserver:
284 # specifying just the key implies that we are connecting directly
286 # specifying just the key implies that we are connecting directly
285 self.sshserver = ip
287 self.sshserver = ip
286 ip = LOCALHOST
288 ip = LOCALHOST
287
289
288 # build connection dict for tunnels:
290 # build connection dict for tunnels:
289 info = dict(ip=ip,
291 info = dict(ip=ip,
290 shell_port=self.shell_port,
292 shell_port=self.shell_port,
291 iopub_port=self.iopub_port,
293 iopub_port=self.iopub_port,
292 stdin_port=self.stdin_port,
294 stdin_port=self.stdin_port,
293 hb_port=self.hb_port
295 hb_port=self.hb_port
294 )
296 )
295
297
296 self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver))
298 self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver))
297
299
298 # tunnels return a new set of ports, which will be on localhost:
300 # tunnels return a new set of ports, which will be on localhost:
299 self.config.KernelManager.ip = LOCALHOST
301 self.config.KernelManager.ip = LOCALHOST
300 try:
302 try:
301 newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
303 newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
302 except:
304 except:
303 # even catch KeyboardInterrupt
305 # even catch KeyboardInterrupt
304 self.log.error("Could not setup tunnels", exc_info=True)
306 self.log.error("Could not setup tunnels", exc_info=True)
305 self.exit(1)
307 self.exit(1)
306
308
307 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports
309 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports
308
310
309 cf = self.connection_file
311 cf = self.connection_file
310 base,ext = os.path.splitext(cf)
312 base,ext = os.path.splitext(cf)
311 base = os.path.basename(base)
313 base = os.path.basename(base)
312 self.connection_file = os.path.basename(base)+'-ssh'+ext
314 self.connection_file = os.path.basename(base)+'-ssh'+ext
313 self.log.critical("To connect another client via this tunnel, use:")
315 self.log.critical("To connect another client via this tunnel, use:")
314 self.log.critical("--existing %s" % self.connection_file)
316 self.log.critical("--existing %s" % self.connection_file)
315
317
316 def _new_connection_file(self):
318 def _new_connection_file(self):
317 cf = ''
319 cf = ''
318 while not cf:
320 while not cf:
319 # we don't need a 128b id to distinguish kernels, use more readable
321 # we don't need a 128b id to distinguish kernels, use more readable
320 # 48b node segment (12 hex chars). Users running more than 32k simultaneous
322 # 48b node segment (12 hex chars). Users running more than 32k simultaneous
321 # kernels can subclass.
323 # kernels can subclass.
322 ident = str(uuid.uuid4()).split('-')[-1]
324 ident = str(uuid.uuid4()).split('-')[-1]
323 cf = os.path.join(self.profile_dir.security_dir, 'kernel-%s.json' % ident)
325 cf = os.path.join(self.profile_dir.security_dir, 'kernel-%s.json' % ident)
324 # only keep if it's actually new. Protect against unlikely collision
326 # only keep if it's actually new. Protect against unlikely collision
325 # in 48b random search space
327 # in 48b random search space
326 cf = cf if not os.path.exists(cf) else ''
328 cf = cf if not os.path.exists(cf) else ''
327 return cf
329 return cf
328
330
329 def init_kernel_manager(self):
331 def init_kernel_manager(self):
330 # Don't let Qt or ZMQ swallow KeyboardInterupts.
332 # Don't let Qt or ZMQ swallow KeyboardInterupts.
331 if self.existing:
333 if self.existing:
332 self.kernel_manager = None
334 self.kernel_manager = None
333 return
335 return
334 signal.signal(signal.SIGINT, signal.SIG_DFL)
336 signal.signal(signal.SIGINT, signal.SIG_DFL)
335
337
336 # Create a KernelManager and start a kernel.
338 # Create a KernelManager and start a kernel.
337 self.kernel_manager = self.kernel_manager_class(
339 self.kernel_manager = self.kernel_manager_class(
338 shell_port=self.shell_port,
340 shell_port=self.shell_port,
339 iopub_port=self.iopub_port,
341 iopub_port=self.iopub_port,
340 stdin_port=self.stdin_port,
342 stdin_port=self.stdin_port,
341 hb_port=self.hb_port,
343 hb_port=self.hb_port,
342 connection_file=self.connection_file,
344 connection_file=self.connection_file,
343 parent=self,
345 parent=self,
344 )
346 )
345 self.kernel_manager.client_factory = self.kernel_client_class
347 self.kernel_manager.client_factory = self.kernel_client_class
346 self.kernel_manager.start_kernel(extra_arguments=self.kernel_argv)
348 self.kernel_manager.start_kernel(extra_arguments=self.kernel_argv)
347 atexit.register(self.kernel_manager.cleanup_ipc_files)
349 atexit.register(self.kernel_manager.cleanup_ipc_files)
348
350
349 if self.sshserver:
351 if self.sshserver:
350 # ssh, write new connection file
352 # ssh, write new connection file
351 self.kernel_manager.write_connection_file()
353 self.kernel_manager.write_connection_file()
352
354
353 # in case KM defaults / ssh writing changes things:
355 # in case KM defaults / ssh writing changes things:
354 km = self.kernel_manager
356 km = self.kernel_manager
355 self.shell_port=km.shell_port
357 self.shell_port=km.shell_port
356 self.iopub_port=km.iopub_port
358 self.iopub_port=km.iopub_port
357 self.stdin_port=km.stdin_port
359 self.stdin_port=km.stdin_port
358 self.hb_port=km.hb_port
360 self.hb_port=km.hb_port
359 self.connection_file = km.connection_file
361 self.connection_file = km.connection_file
360
362
361 atexit.register(self.kernel_manager.cleanup_connection_file)
363 atexit.register(self.kernel_manager.cleanup_connection_file)
362
364
363 def init_kernel_client(self):
365 def init_kernel_client(self):
364 if self.kernel_manager is not None:
366 if self.kernel_manager is not None:
365 self.kernel_client = self.kernel_manager.client()
367 self.kernel_client = self.kernel_manager.client()
366 else:
368 else:
367 self.kernel_client = self.kernel_client_class(
369 self.kernel_client = self.kernel_client_class(
368 shell_port=self.shell_port,
370 shell_port=self.shell_port,
369 iopub_port=self.iopub_port,
371 iopub_port=self.iopub_port,
370 stdin_port=self.stdin_port,
372 stdin_port=self.stdin_port,
371 hb_port=self.hb_port,
373 hb_port=self.hb_port,
372 connection_file=self.connection_file,
374 connection_file=self.connection_file,
373 parent=self,
375 parent=self,
374 )
376 )
375
377
376 self.kernel_client.start_channels()
378 self.kernel_client.start_channels()
377
379
378
380
379
381
380 def initialize(self, argv=None):
382 def initialize(self, argv=None):
381 """
383 """
382 Classes which mix this class in should call:
384 Classes which mix this class in should call:
383 IPythonConsoleApp.initialize(self,argv)
385 IPythonConsoleApp.initialize(self,argv)
384 """
386 """
385 self.init_connection_file()
387 self.init_connection_file()
386 default_secure(self.config)
388 default_secure(self.config)
387 self.init_ssh()
389 self.init_ssh()
388 self.init_kernel_manager()
390 self.init_kernel_manager()
389 self.init_kernel_client()
391 self.init_kernel_client()
390
392
@@ -1,540 +1,557 b''
1 """Utilities for connecting to kernels
1 """Utilities for connecting to kernels
2
2
3 Authors:
3 Authors:
4
4
5 * Min Ragan-Kelley
5 * Min Ragan-Kelley
6
6
7 """
7 """
8
8
9 #-----------------------------------------------------------------------------
9 #-----------------------------------------------------------------------------
10 # Copyright (C) 2013 The IPython Development Team
10 # Copyright (C) 2013 The IPython Development Team
11 #
11 #
12 # Distributed under the terms of the BSD License. The full license is in
12 # Distributed under the terms of the BSD License. The full license is in
13 # the file COPYING, distributed as part of this software.
13 # the file COPYING, distributed as part of this software.
14 #-----------------------------------------------------------------------------
14 #-----------------------------------------------------------------------------
15
15
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Imports
17 # Imports
18 #-----------------------------------------------------------------------------
18 #-----------------------------------------------------------------------------
19
19
20 from __future__ import absolute_import
20 from __future__ import absolute_import
21
21
22 import glob
22 import glob
23 import json
23 import json
24 import os
24 import os
25 import socket
25 import socket
26 import sys
26 import sys
27 from getpass import getpass
27 from getpass import getpass
28 from subprocess import Popen, PIPE
28 from subprocess import Popen, PIPE
29 import tempfile
29 import tempfile
30
30
31 import zmq
31 import zmq
32
32
33 # external imports
33 # external imports
34 from IPython.external.ssh import tunnel
34 from IPython.external.ssh import tunnel
35
35
36 # IPython imports
36 # IPython imports
37 # from IPython.config import Configurable
37 # from IPython.config import Configurable
38 from IPython.core.profiledir import ProfileDir
38 from IPython.core.profiledir import ProfileDir
39 from IPython.utils.localinterfaces import LOCALHOST
39 from IPython.utils.localinterfaces import LOCALHOST
40 from IPython.utils.path import filefind, get_ipython_dir
40 from IPython.utils.path import filefind, get_ipython_dir
41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str
41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str
42 from IPython.utils.traitlets import (
42 from IPython.utils.traitlets import (
43 Bool, Integer, Unicode, CaselessStrEnum,
43 Bool, Integer, Unicode, CaselessStrEnum,
44 HasTraits,
44 HasTraits,
45 )
45 )
46
46
47
47
48 #-----------------------------------------------------------------------------
48 #-----------------------------------------------------------------------------
49 # Working with Connection Files
49 # Working with Connection Files
50 #-----------------------------------------------------------------------------
50 #-----------------------------------------------------------------------------
51
51
52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
53 control_port=0, ip=LOCALHOST, key=b'', transport='tcp'):
53 control_port=0, ip=LOCALHOST, key=b'', transport='tcp',
54 signature_scheme='hmac-sha256',
55 ):
54 """Generates a JSON config file, including the selection of random ports.
56 """Generates a JSON config file, including the selection of random ports.
55
57
56 Parameters
58 Parameters
57 ----------
59 ----------
58
60
59 fname : unicode
61 fname : unicode
60 The path to the file to write
62 The path to the file to write
61
63
62 shell_port : int, optional
64 shell_port : int, optional
63 The port to use for ROUTER (shell) channel.
65 The port to use for ROUTER (shell) channel.
64
66
65 iopub_port : int, optional
67 iopub_port : int, optional
66 The port to use for the SUB channel.
68 The port to use for the SUB channel.
67
69
68 stdin_port : int, optional
70 stdin_port : int, optional
69 The port to use for the ROUTER (raw input) channel.
71 The port to use for the ROUTER (raw input) channel.
70
72
71 control_port : int, optional
73 control_port : int, optional
72 The port to use for the ROUTER (control) channel.
74 The port to use for the ROUTER (control) channel.
73
75
74 hb_port : int, optional
76 hb_port : int, optional
75 The port to use for the heartbeat REP channel.
77 The port to use for the heartbeat REP channel.
76
78
77 ip : str, optional
79 ip : str, optional
78 The ip address the kernel will bind to.
80 The ip address the kernel will bind to.
79
81
80 key : str, optional
82 key : str, optional
81 The Session key used for HMAC authentication.
83 The Session key used for message authentication.
84
85 signature_scheme : str, optional
86 The scheme used for message authentication.
87 This has the form 'digest-hash', where 'digest'
88 is the scheme used for digests, and 'hash' is the name of the hash function
89 used by the digest scheme.
90 Currently, 'hmac' is the only supported digest scheme,
91 and 'sha256' is the default hash function.
82
92
83 """
93 """
84 # default to temporary connector file
94 # default to temporary connector file
85 if not fname:
95 if not fname:
86 fname = tempfile.mktemp('.json')
96 fname = tempfile.mktemp('.json')
87
97
88 # Find open ports as necessary.
98 # Find open ports as necessary.
89
99
90 ports = []
100 ports = []
91 ports_needed = int(shell_port <= 0) + \
101 ports_needed = int(shell_port <= 0) + \
92 int(iopub_port <= 0) + \
102 int(iopub_port <= 0) + \
93 int(stdin_port <= 0) + \
103 int(stdin_port <= 0) + \
94 int(control_port <= 0) + \
104 int(control_port <= 0) + \
95 int(hb_port <= 0)
105 int(hb_port <= 0)
96 if transport == 'tcp':
106 if transport == 'tcp':
97 for i in range(ports_needed):
107 for i in range(ports_needed):
98 sock = socket.socket()
108 sock = socket.socket()
99 sock.bind(('', 0))
109 sock.bind(('', 0))
100 ports.append(sock)
110 ports.append(sock)
101 for i, sock in enumerate(ports):
111 for i, sock in enumerate(ports):
102 port = sock.getsockname()[1]
112 port = sock.getsockname()[1]
103 sock.close()
113 sock.close()
104 ports[i] = port
114 ports[i] = port
105 else:
115 else:
106 N = 1
116 N = 1
107 for i in range(ports_needed):
117 for i in range(ports_needed):
108 while os.path.exists("%s-%s" % (ip, str(N))):
118 while os.path.exists("%s-%s" % (ip, str(N))):
109 N += 1
119 N += 1
110 ports.append(N)
120 ports.append(N)
111 N += 1
121 N += 1
112 if shell_port <= 0:
122 if shell_port <= 0:
113 shell_port = ports.pop(0)
123 shell_port = ports.pop(0)
114 if iopub_port <= 0:
124 if iopub_port <= 0:
115 iopub_port = ports.pop(0)
125 iopub_port = ports.pop(0)
116 if stdin_port <= 0:
126 if stdin_port <= 0:
117 stdin_port = ports.pop(0)
127 stdin_port = ports.pop(0)
118 if control_port <= 0:
128 if control_port <= 0:
119 control_port = ports.pop(0)
129 control_port = ports.pop(0)
120 if hb_port <= 0:
130 if hb_port <= 0:
121 hb_port = ports.pop(0)
131 hb_port = ports.pop(0)
122
132
123 cfg = dict( shell_port=shell_port,
133 cfg = dict( shell_port=shell_port,
124 iopub_port=iopub_port,
134 iopub_port=iopub_port,
125 stdin_port=stdin_port,
135 stdin_port=stdin_port,
126 control_port=control_port,
136 control_port=control_port,
127 hb_port=hb_port,
137 hb_port=hb_port,
128 )
138 )
129 cfg['ip'] = ip
139 cfg['ip'] = ip
130 cfg['key'] = bytes_to_str(key)
140 cfg['key'] = bytes_to_str(key)
131 cfg['transport'] = transport
141 cfg['transport'] = transport
142 cfg['signature_scheme'] = signature_scheme
132
143
133 with open(fname, 'w') as f:
144 with open(fname, 'w') as f:
134 f.write(json.dumps(cfg, indent=2))
145 f.write(json.dumps(cfg, indent=2))
135
146
136 return fname, cfg
147 return fname, cfg
137
148
138
149
139 def get_connection_file(app=None):
150 def get_connection_file(app=None):
140 """Return the path to the connection file of an app
151 """Return the path to the connection file of an app
141
152
142 Parameters
153 Parameters
143 ----------
154 ----------
144 app : IPKernelApp instance [optional]
155 app : IPKernelApp instance [optional]
145 If unspecified, the currently running app will be used
156 If unspecified, the currently running app will be used
146 """
157 """
147 if app is None:
158 if app is None:
148 from IPython.kernel.zmq.kernelapp import IPKernelApp
159 from IPython.kernel.zmq.kernelapp import IPKernelApp
149 if not IPKernelApp.initialized():
160 if not IPKernelApp.initialized():
150 raise RuntimeError("app not specified, and not in a running Kernel")
161 raise RuntimeError("app not specified, and not in a running Kernel")
151
162
152 app = IPKernelApp.instance()
163 app = IPKernelApp.instance()
153 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
164 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
154
165
155
166
156 def find_connection_file(filename, profile=None):
167 def find_connection_file(filename, profile=None):
157 """find a connection file, and return its absolute path.
168 """find a connection file, and return its absolute path.
158
169
159 The current working directory and the profile's security
170 The current working directory and the profile's security
160 directory will be searched for the file if it is not given by
171 directory will be searched for the file if it is not given by
161 absolute path.
172 absolute path.
162
173
163 If profile is unspecified, then the current running application's
174 If profile is unspecified, then the current running application's
164 profile will be used, or 'default', if not run from IPython.
175 profile will be used, or 'default', if not run from IPython.
165
176
166 If the argument does not match an existing file, it will be interpreted as a
177 If the argument does not match an existing file, it will be interpreted as a
167 fileglob, and the matching file in the profile's security dir with
178 fileglob, and the matching file in the profile's security dir with
168 the latest access time will be used.
179 the latest access time will be used.
169
180
170 Parameters
181 Parameters
171 ----------
182 ----------
172 filename : str
183 filename : str
173 The connection file or fileglob to search for.
184 The connection file or fileglob to search for.
174 profile : str [optional]
185 profile : str [optional]
175 The name of the profile to use when searching for the connection file,
186 The name of the profile to use when searching for the connection file,
176 if different from the current IPython session or 'default'.
187 if different from the current IPython session or 'default'.
177
188
178 Returns
189 Returns
179 -------
190 -------
180 str : The absolute path of the connection file.
191 str : The absolute path of the connection file.
181 """
192 """
182 from IPython.core.application import BaseIPythonApplication as IPApp
193 from IPython.core.application import BaseIPythonApplication as IPApp
183 try:
194 try:
184 # quick check for absolute path, before going through logic
195 # quick check for absolute path, before going through logic
185 return filefind(filename)
196 return filefind(filename)
186 except IOError:
197 except IOError:
187 pass
198 pass
188
199
189 if profile is None:
200 if profile is None:
190 # profile unspecified, check if running from an IPython app
201 # profile unspecified, check if running from an IPython app
191 if IPApp.initialized():
202 if IPApp.initialized():
192 app = IPApp.instance()
203 app = IPApp.instance()
193 profile_dir = app.profile_dir
204 profile_dir = app.profile_dir
194 else:
205 else:
195 # not running in IPython, use default profile
206 # not running in IPython, use default profile
196 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
207 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
197 else:
208 else:
198 # find profiledir by profile name:
209 # find profiledir by profile name:
199 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
200 security_dir = profile_dir.security_dir
211 security_dir = profile_dir.security_dir
201
212
202 try:
213 try:
203 # first, try explicit name
214 # first, try explicit name
204 return filefind(filename, ['.', security_dir])
215 return filefind(filename, ['.', security_dir])
205 except IOError:
216 except IOError:
206 pass
217 pass
207
218
208 # not found by full name
219 # not found by full name
209
220
210 if '*' in filename:
221 if '*' in filename:
211 # given as a glob already
222 # given as a glob already
212 pat = filename
223 pat = filename
213 else:
224 else:
214 # accept any substring match
225 # accept any substring match
215 pat = '*%s*' % filename
226 pat = '*%s*' % filename
216 matches = glob.glob( os.path.join(security_dir, pat) )
227 matches = glob.glob( os.path.join(security_dir, pat) )
217 if not matches:
228 if not matches:
218 raise IOError("Could not find %r in %r" % (filename, security_dir))
229 raise IOError("Could not find %r in %r" % (filename, security_dir))
219 elif len(matches) == 1:
230 elif len(matches) == 1:
220 return matches[0]
231 return matches[0]
221 else:
232 else:
222 # get most recent match, by access time:
233 # get most recent match, by access time:
223 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
234 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
224
235
225
236
226 def get_connection_info(connection_file=None, unpack=False, profile=None):
237 def get_connection_info(connection_file=None, unpack=False, profile=None):
227 """Return the connection information for the current Kernel.
238 """Return the connection information for the current Kernel.
228
239
229 Parameters
240 Parameters
230 ----------
241 ----------
231 connection_file : str [optional]
242 connection_file : str [optional]
232 The connection file to be used. Can be given by absolute path, or
243 The connection file to be used. Can be given by absolute path, or
233 IPython will search in the security directory of a given profile.
244 IPython will search in the security directory of a given profile.
234 If run from IPython,
245 If run from IPython,
235
246
236 If unspecified, the connection file for the currently running
247 If unspecified, the connection file for the currently running
237 IPython Kernel will be used, which is only allowed from inside a kernel.
248 IPython Kernel will be used, which is only allowed from inside a kernel.
238 unpack : bool [default: False]
249 unpack : bool [default: False]
239 if True, return the unpacked dict, otherwise just the string contents
250 if True, return the unpacked dict, otherwise just the string contents
240 of the file.
251 of the file.
241 profile : str [optional]
252 profile : str [optional]
242 The name of the profile to use when searching for the connection file,
253 The name of the profile to use when searching for the connection file,
243 if different from the current IPython session or 'default'.
254 if different from the current IPython session or 'default'.
244
255
245
256
246 Returns
257 Returns
247 -------
258 -------
248 The connection dictionary of the current kernel, as string or dict,
259 The connection dictionary of the current kernel, as string or dict,
249 depending on `unpack`.
260 depending on `unpack`.
250 """
261 """
251 if connection_file is None:
262 if connection_file is None:
252 # get connection file from current kernel
263 # get connection file from current kernel
253 cf = get_connection_file()
264 cf = get_connection_file()
254 else:
265 else:
255 # connection file specified, allow shortnames:
266 # connection file specified, allow shortnames:
256 cf = find_connection_file(connection_file, profile=profile)
267 cf = find_connection_file(connection_file, profile=profile)
257
268
258 with open(cf) as f:
269 with open(cf) as f:
259 info = f.read()
270 info = f.read()
260
271
261 if unpack:
272 if unpack:
262 info = json.loads(info)
273 info = json.loads(info)
263 # ensure key is bytes:
274 # ensure key is bytes:
264 info['key'] = str_to_bytes(info.get('key', ''))
275 info['key'] = str_to_bytes(info.get('key', ''))
265 return info
276 return info
266
277
267
278
268 def connect_qtconsole(connection_file=None, argv=None, profile=None):
279 def connect_qtconsole(connection_file=None, argv=None, profile=None):
269 """Connect a qtconsole to the current kernel.
280 """Connect a qtconsole to the current kernel.
270
281
271 This is useful for connecting a second qtconsole to a kernel, or to a
282 This is useful for connecting a second qtconsole to a kernel, or to a
272 local notebook.
283 local notebook.
273
284
274 Parameters
285 Parameters
275 ----------
286 ----------
276 connection_file : str [optional]
287 connection_file : str [optional]
277 The connection file to be used. Can be given by absolute path, or
288 The connection file to be used. Can be given by absolute path, or
278 IPython will search in the security directory of a given profile.
289 IPython will search in the security directory of a given profile.
279 If run from IPython,
290 If run from IPython,
280
291
281 If unspecified, the connection file for the currently running
292 If unspecified, the connection file for the currently running
282 IPython Kernel will be used, which is only allowed from inside a kernel.
293 IPython Kernel will be used, which is only allowed from inside a kernel.
283 argv : list [optional]
294 argv : list [optional]
284 Any extra args to be passed to the console.
295 Any extra args to be passed to the console.
285 profile : str [optional]
296 profile : str [optional]
286 The name of the profile to use when searching for the connection file,
297 The name of the profile to use when searching for the connection file,
287 if different from the current IPython session or 'default'.
298 if different from the current IPython session or 'default'.
288
299
289
300
290 Returns
301 Returns
291 -------
302 -------
292 subprocess.Popen instance running the qtconsole frontend
303 subprocess.Popen instance running the qtconsole frontend
293 """
304 """
294 argv = [] if argv is None else argv
305 argv = [] if argv is None else argv
295
306
296 if connection_file is None:
307 if connection_file is None:
297 # get connection file from current kernel
308 # get connection file from current kernel
298 cf = get_connection_file()
309 cf = get_connection_file()
299 else:
310 else:
300 cf = find_connection_file(connection_file, profile=profile)
311 cf = find_connection_file(connection_file, profile=profile)
301
312
302 cmd = ';'.join([
313 cmd = ';'.join([
303 "from IPython.qt.console import qtconsoleapp",
314 "from IPython.qt.console import qtconsoleapp",
304 "qtconsoleapp.main()"
315 "qtconsoleapp.main()"
305 ])
316 ])
306
317
307 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
318 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
308 stdout=PIPE, stderr=PIPE, close_fds=True,
319 stdout=PIPE, stderr=PIPE, close_fds=True,
309 )
320 )
310
321
311
322
312 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
323 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
313 """tunnel connections to a kernel via ssh
324 """tunnel connections to a kernel via ssh
314
325
315 This will open four SSH tunnels from localhost on this machine to the
326 This will open four SSH tunnels from localhost on this machine to the
316 ports associated with the kernel. They can be either direct
327 ports associated with the kernel. They can be either direct
317 localhost-localhost tunnels, or if an intermediate server is necessary,
328 localhost-localhost tunnels, or if an intermediate server is necessary,
318 the kernel must be listening on a public IP.
329 the kernel must be listening on a public IP.
319
330
320 Parameters
331 Parameters
321 ----------
332 ----------
322 connection_info : dict or str (path)
333 connection_info : dict or str (path)
323 Either a connection dict, or the path to a JSON connection file
334 Either a connection dict, or the path to a JSON connection file
324 sshserver : str
335 sshserver : str
325 The ssh sever to use to tunnel to the kernel. Can be a full
336 The ssh sever to use to tunnel to the kernel. Can be a full
326 `user@server:port` string. ssh config aliases are respected.
337 `user@server:port` string. ssh config aliases are respected.
327 sshkey : str [optional]
338 sshkey : str [optional]
328 Path to file containing ssh key to use for authentication.
339 Path to file containing ssh key to use for authentication.
329 Only necessary if your ssh config does not already associate
340 Only necessary if your ssh config does not already associate
330 a keyfile with the host.
341 a keyfile with the host.
331
342
332 Returns
343 Returns
333 -------
344 -------
334
345
335 (shell, iopub, stdin, hb) : ints
346 (shell, iopub, stdin, hb) : ints
336 The four ports on localhost that have been forwarded to the kernel.
347 The four ports on localhost that have been forwarded to the kernel.
337 """
348 """
338 if isinstance(connection_info, basestring):
349 if isinstance(connection_info, basestring):
339 # it's a path, unpack it
350 # it's a path, unpack it
340 with open(connection_info) as f:
351 with open(connection_info) as f:
341 connection_info = json.loads(f.read())
352 connection_info = json.loads(f.read())
342
353
343 cf = connection_info
354 cf = connection_info
344
355
345 lports = tunnel.select_random_ports(4)
356 lports = tunnel.select_random_ports(4)
346 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
357 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
347
358
348 remote_ip = cf['ip']
359 remote_ip = cf['ip']
349
360
350 if tunnel.try_passwordless_ssh(sshserver, sshkey):
361 if tunnel.try_passwordless_ssh(sshserver, sshkey):
351 password=False
362 password=False
352 else:
363 else:
353 password = getpass("SSH Password for %s: "%sshserver)
364 password = getpass("SSH Password for %s: "%sshserver)
354
365
355 for lp,rp in zip(lports, rports):
366 for lp,rp in zip(lports, rports):
356 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
367 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
357
368
358 return tuple(lports)
369 return tuple(lports)
359
370
360
371
361 #-----------------------------------------------------------------------------
372 #-----------------------------------------------------------------------------
362 # Mixin for classes that work with connection files
373 # Mixin for classes that work with connection files
363 #-----------------------------------------------------------------------------
374 #-----------------------------------------------------------------------------
364
375
365 channel_socket_types = {
376 channel_socket_types = {
366 'hb' : zmq.REQ,
377 'hb' : zmq.REQ,
367 'shell' : zmq.DEALER,
378 'shell' : zmq.DEALER,
368 'iopub' : zmq.SUB,
379 'iopub' : zmq.SUB,
369 'stdin' : zmq.DEALER,
380 'stdin' : zmq.DEALER,
370 'control': zmq.DEALER,
381 'control': zmq.DEALER,
371 }
382 }
372
383
373 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
384 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
374
385
375 class ConnectionFileMixin(HasTraits):
386 class ConnectionFileMixin(HasTraits):
376 """Mixin for configurable classes that work with connection files"""
387 """Mixin for configurable classes that work with connection files"""
377
388
378 # The addresses for the communication channels
389 # The addresses for the communication channels
379 connection_file = Unicode('')
390 connection_file = Unicode('')
380 _connection_file_written = Bool(False)
391 _connection_file_written = Bool(False)
381
392
382 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
393 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
394 signature_scheme = Unicode('')
383
395
384 ip = Unicode(LOCALHOST, config=True,
396 ip = Unicode(LOCALHOST, config=True,
385 help="""Set the kernel\'s IP address [default localhost].
397 help="""Set the kernel\'s IP address [default localhost].
386 If the IP address is something other than localhost, then
398 If the IP address is something other than localhost, then
387 Consoles on other machines will be able to connect
399 Consoles on other machines will be able to connect
388 to the Kernel, so be careful!"""
400 to the Kernel, so be careful!"""
389 )
401 )
390
402
391 def _ip_default(self):
403 def _ip_default(self):
392 if self.transport == 'ipc':
404 if self.transport == 'ipc':
393 if self.connection_file:
405 if self.connection_file:
394 return os.path.splitext(self.connection_file)[0] + '-ipc'
406 return os.path.splitext(self.connection_file)[0] + '-ipc'
395 else:
407 else:
396 return 'kernel-ipc'
408 return 'kernel-ipc'
397 else:
409 else:
398 return LOCALHOST
410 return LOCALHOST
399
411
400 def _ip_changed(self, name, old, new):
412 def _ip_changed(self, name, old, new):
401 if new == '*':
413 if new == '*':
402 self.ip = '0.0.0.0'
414 self.ip = '0.0.0.0'
403
415
404 # protected traits
416 # protected traits
405
417
406 shell_port = Integer(0)
418 shell_port = Integer(0)
407 iopub_port = Integer(0)
419 iopub_port = Integer(0)
408 stdin_port = Integer(0)
420 stdin_port = Integer(0)
409 control_port = Integer(0)
421 control_port = Integer(0)
410 hb_port = Integer(0)
422 hb_port = Integer(0)
411
423
412 @property
424 @property
413 def ports(self):
425 def ports(self):
414 return [ getattr(self, name) for name in port_names ]
426 return [ getattr(self, name) for name in port_names ]
415
427
416 #--------------------------------------------------------------------------
428 #--------------------------------------------------------------------------
417 # Connection and ipc file management
429 # Connection and ipc file management
418 #--------------------------------------------------------------------------
430 #--------------------------------------------------------------------------
419
431
420 def get_connection_info(self):
432 def get_connection_info(self):
421 """return the connection info as a dict"""
433 """return the connection info as a dict"""
422 return dict(
434 return dict(
423 transport=self.transport,
435 transport=self.transport,
424 ip=self.ip,
436 ip=self.ip,
425 shell_port=self.shell_port,
437 shell_port=self.shell_port,
426 iopub_port=self.iopub_port,
438 iopub_port=self.iopub_port,
427 stdin_port=self.stdin_port,
439 stdin_port=self.stdin_port,
428 hb_port=self.hb_port,
440 hb_port=self.hb_port,
429 control_port=self.control_port,
441 control_port=self.control_port,
442 signature_scheme=self.signature_scheme,
430 )
443 )
431
444
432 def cleanup_connection_file(self):
445 def cleanup_connection_file(self):
433 """Cleanup connection file *if we wrote it*
446 """Cleanup connection file *if we wrote it*
434
447
435 Will not raise if the connection file was already removed somehow.
448 Will not raise if the connection file was already removed somehow.
436 """
449 """
437 if self._connection_file_written:
450 if self._connection_file_written:
438 # cleanup connection files on full shutdown of kernel we started
451 # cleanup connection files on full shutdown of kernel we started
439 self._connection_file_written = False
452 self._connection_file_written = False
440 try:
453 try:
441 os.remove(self.connection_file)
454 os.remove(self.connection_file)
442 except (IOError, OSError, AttributeError):
455 except (IOError, OSError, AttributeError):
443 pass
456 pass
444
457
445 def cleanup_ipc_files(self):
458 def cleanup_ipc_files(self):
446 """Cleanup ipc files if we wrote them."""
459 """Cleanup ipc files if we wrote them."""
447 if self.transport != 'ipc':
460 if self.transport != 'ipc':
448 return
461 return
449 for port in self.ports:
462 for port in self.ports:
450 ipcfile = "%s-%i" % (self.ip, port)
463 ipcfile = "%s-%i" % (self.ip, port)
451 try:
464 try:
452 os.remove(ipcfile)
465 os.remove(ipcfile)
453 except (IOError, OSError):
466 except (IOError, OSError):
454 pass
467 pass
455
468
456 def write_connection_file(self):
469 def write_connection_file(self):
457 """Write connection info to JSON dict in self.connection_file."""
470 """Write connection info to JSON dict in self.connection_file."""
458 if self._connection_file_written:
471 if self._connection_file_written:
459 return
472 return
460
473
461 self.connection_file, cfg = write_connection_file(self.connection_file,
474 self.connection_file, cfg = write_connection_file(self.connection_file,
462 transport=self.transport, ip=self.ip, key=self.session.key,
475 transport=self.transport, ip=self.ip, key=self.session.key,
463 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
476 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
464 shell_port=self.shell_port, hb_port=self.hb_port,
477 shell_port=self.shell_port, hb_port=self.hb_port,
465 control_port=self.control_port,
478 control_port=self.control_port,
479 signature_scheme=self.signature_scheme,
466 )
480 )
467 # write_connection_file also sets default ports:
481 # write_connection_file also sets default ports:
468 for name in port_names:
482 for name in port_names:
469 setattr(self, name, cfg[name])
483 setattr(self, name, cfg[name])
470
484
471 self._connection_file_written = True
485 self._connection_file_written = True
472
486
473 def load_connection_file(self):
487 def load_connection_file(self):
474 """Load connection info from JSON dict in self.connection_file."""
488 """Load connection info from JSON dict in self.connection_file."""
475 with open(self.connection_file) as f:
489 with open(self.connection_file) as f:
476 cfg = json.loads(f.read())
490 cfg = json.loads(f.read())
477
491
478 self.transport = cfg.get('transport', 'tcp')
492 self.transport = cfg.get('transport', 'tcp')
479 self.ip = cfg['ip']
493 self.ip = cfg['ip']
480 for name in port_names:
494 for name in port_names:
481 setattr(self, name, cfg[name])
495 setattr(self, name, cfg[name])
482 self.session.key = str_to_bytes(cfg['key'])
496 if 'key' in cfg:
497 self.session.key = str_to_bytes(cfg['key'])
498 if cfg.get('signature_scheme'):
499 self.session.signature_scheme = cfg['signature_scheme']
483
500
484 #--------------------------------------------------------------------------
501 #--------------------------------------------------------------------------
485 # Creating connected sockets
502 # Creating connected sockets
486 #--------------------------------------------------------------------------
503 #--------------------------------------------------------------------------
487
504
488 def _make_url(self, channel):
505 def _make_url(self, channel):
489 """Make a ZeroMQ URL for a given channel."""
506 """Make a ZeroMQ URL for a given channel."""
490 transport = self.transport
507 transport = self.transport
491 ip = self.ip
508 ip = self.ip
492 port = getattr(self, '%s_port' % channel)
509 port = getattr(self, '%s_port' % channel)
493
510
494 if transport == 'tcp':
511 if transport == 'tcp':
495 return "tcp://%s:%i" % (ip, port)
512 return "tcp://%s:%i" % (ip, port)
496 else:
513 else:
497 return "%s://%s-%s" % (transport, ip, port)
514 return "%s://%s-%s" % (transport, ip, port)
498
515
499 def _create_connected_socket(self, channel, identity=None):
516 def _create_connected_socket(self, channel, identity=None):
500 """Create a zmq Socket and connect it to the kernel."""
517 """Create a zmq Socket and connect it to the kernel."""
501 url = self._make_url(channel)
518 url = self._make_url(channel)
502 socket_type = channel_socket_types[channel]
519 socket_type = channel_socket_types[channel]
503 self.log.info("Connecting to: %s" % url)
520 self.log.info("Connecting to: %s" % url)
504 sock = self.context.socket(socket_type)
521 sock = self.context.socket(socket_type)
505 if identity:
522 if identity:
506 sock.identity = identity
523 sock.identity = identity
507 sock.connect(url)
524 sock.connect(url)
508 return sock
525 return sock
509
526
510 def connect_iopub(self, identity=None):
527 def connect_iopub(self, identity=None):
511 """return zmq Socket connected to the IOPub channel"""
528 """return zmq Socket connected to the IOPub channel"""
512 sock = self._create_connected_socket('iopub', identity=identity)
529 sock = self._create_connected_socket('iopub', identity=identity)
513 sock.setsockopt(zmq.SUBSCRIBE, b'')
530 sock.setsockopt(zmq.SUBSCRIBE, b'')
514 return sock
531 return sock
515
532
516 def connect_shell(self, identity=None):
533 def connect_shell(self, identity=None):
517 """return zmq Socket connected to the Shell channel"""
534 """return zmq Socket connected to the Shell channel"""
518 return self._create_connected_socket('shell', identity=identity)
535 return self._create_connected_socket('shell', identity=identity)
519
536
520 def connect_stdin(self, identity=None):
537 def connect_stdin(self, identity=None):
521 """return zmq Socket connected to the StdIn channel"""
538 """return zmq Socket connected to the StdIn channel"""
522 return self._create_connected_socket('stdin', identity=identity)
539 return self._create_connected_socket('stdin', identity=identity)
523
540
524 def connect_hb(self, identity=None):
541 def connect_hb(self, identity=None):
525 """return zmq Socket connected to the Heartbeat channel"""
542 """return zmq Socket connected to the Heartbeat channel"""
526 return self._create_connected_socket('hb', identity=identity)
543 return self._create_connected_socket('hb', identity=identity)
527
544
528 def connect_control(self, identity=None):
545 def connect_control(self, identity=None):
529 """return zmq Socket connected to the Heartbeat channel"""
546 """return zmq Socket connected to the Heartbeat channel"""
530 return self._create_connected_socket('control', identity=identity)
547 return self._create_connected_socket('control', identity=identity)
531
548
532
549
533 __all__ = [
550 __all__ = [
534 'write_connection_file',
551 'write_connection_file',
535 'get_connection_file',
552 'get_connection_file',
536 'find_connection_file',
553 'find_connection_file',
537 'get_connection_info',
554 'get_connection_info',
538 'connect_qtconsole',
555 'connect_qtconsole',
539 'tunnel_to_kernel',
556 'tunnel_to_kernel',
540 ]
557 ]
@@ -1,806 +1,830 b''
1 """Session object for building, serializing, sending, and receiving messages in
1 """Session object for building, serializing, sending, and receiving messages in
2 IPython. The Session object supports serialization, HMAC signatures, and
2 IPython. The Session object supports serialization, HMAC signatures, and
3 metadata on messages.
3 metadata on messages.
4
4
5 Also defined here are utilities for working with Sessions:
5 Also defined here are utilities for working with Sessions:
6 * A SessionFactory to be used as a base class for configurables that work with
6 * A SessionFactory to be used as a base class for configurables that work with
7 Sessions.
7 Sessions.
8 * A Message object for convenience that allows attribute-access to the msg dict.
8 * A Message object for convenience that allows attribute-access to the msg dict.
9
9
10 Authors:
10 Authors:
11
11
12 * Min RK
12 * Min RK
13 * Brian Granger
13 * Brian Granger
14 * Fernando Perez
14 * Fernando Perez
15 """
15 """
16 #-----------------------------------------------------------------------------
16 #-----------------------------------------------------------------------------
17 # Copyright (C) 2010-2011 The IPython Development Team
17 # Copyright (C) 2010-2011 The IPython Development Team
18 #
18 #
19 # Distributed under the terms of the BSD License. The full license is in
19 # Distributed under the terms of the BSD License. The full license is in
20 # the file COPYING, distributed as part of this software.
20 # the file COPYING, distributed as part of this software.
21 #-----------------------------------------------------------------------------
21 #-----------------------------------------------------------------------------
22
22
23 #-----------------------------------------------------------------------------
23 #-----------------------------------------------------------------------------
24 # Imports
24 # Imports
25 #-----------------------------------------------------------------------------
25 #-----------------------------------------------------------------------------
26
26
27 import hashlib
27 import hmac
28 import hmac
28 import logging
29 import logging
29 import os
30 import os
30 import pprint
31 import pprint
31 import random
32 import random
32 import uuid
33 import uuid
33 from datetime import datetime
34 from datetime import datetime
34
35
35 try:
36 try:
36 import cPickle
37 import cPickle
37 pickle = cPickle
38 pickle = cPickle
38 except:
39 except:
39 cPickle = None
40 cPickle = None
40 import pickle
41 import pickle
41
42
42 import zmq
43 import zmq
43 from zmq.utils import jsonapi
44 from zmq.utils import jsonapi
44 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.ioloop import IOLoop
45 from zmq.eventloop.zmqstream import ZMQStream
46 from zmq.eventloop.zmqstream import ZMQStream
46
47
47 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 from IPython.utils import io
49 from IPython.utils import io
49 from IPython.utils.importstring import import_item
50 from IPython.utils.importstring import import_item
50 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 from IPython.utils.py3compat import str_to_bytes, str_to_unicode
52 from IPython.utils.py3compat import str_to_bytes, str_to_unicode
52 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
53 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
53 DottedObjectName, CUnicode, Dict, Integer)
54 DottedObjectName, CUnicode, Dict, Integer,
55 TraitError,
56 )
54 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
55
58
56 #-----------------------------------------------------------------------------
59 #-----------------------------------------------------------------------------
57 # utility functions
60 # utility functions
58 #-----------------------------------------------------------------------------
61 #-----------------------------------------------------------------------------
59
62
60 def squash_unicode(obj):
63 def squash_unicode(obj):
61 """coerce unicode back to bytestrings."""
64 """coerce unicode back to bytestrings."""
62 if isinstance(obj,dict):
65 if isinstance(obj,dict):
63 for key in obj.keys():
66 for key in obj.keys():
64 obj[key] = squash_unicode(obj[key])
67 obj[key] = squash_unicode(obj[key])
65 if isinstance(key, unicode):
68 if isinstance(key, unicode):
66 obj[squash_unicode(key)] = obj.pop(key)
69 obj[squash_unicode(key)] = obj.pop(key)
67 elif isinstance(obj, list):
70 elif isinstance(obj, list):
68 for i,v in enumerate(obj):
71 for i,v in enumerate(obj):
69 obj[i] = squash_unicode(v)
72 obj[i] = squash_unicode(v)
70 elif isinstance(obj, unicode):
73 elif isinstance(obj, unicode):
71 obj = obj.encode('utf8')
74 obj = obj.encode('utf8')
72 return obj
75 return obj
73
76
74 #-----------------------------------------------------------------------------
77 #-----------------------------------------------------------------------------
75 # globals and defaults
78 # globals and defaults
76 #-----------------------------------------------------------------------------
79 #-----------------------------------------------------------------------------
77
80
78 # ISO8601-ify datetime objects
81 # ISO8601-ify datetime objects
79 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
82 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
80 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
83 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
81
84
82 pickle_packer = lambda o: pickle.dumps(o,-1)
85 pickle_packer = lambda o: pickle.dumps(o,-1)
83 pickle_unpacker = pickle.loads
86 pickle_unpacker = pickle.loads
84
87
85 default_packer = json_packer
88 default_packer = json_packer
86 default_unpacker = json_unpacker
89 default_unpacker = json_unpacker
87
90
88 DELIM = b"<IDS|MSG>"
91 DELIM = b"<IDS|MSG>"
89 # singleton dummy tracker, which will always report as done
92 # singleton dummy tracker, which will always report as done
90 DONE = zmq.MessageTracker()
93 DONE = zmq.MessageTracker()
91
94
92 #-----------------------------------------------------------------------------
95 #-----------------------------------------------------------------------------
93 # Mixin tools for apps that use Sessions
96 # Mixin tools for apps that use Sessions
94 #-----------------------------------------------------------------------------
97 #-----------------------------------------------------------------------------
95
98
96 session_aliases = dict(
99 session_aliases = dict(
97 ident = 'Session.session',
100 ident = 'Session.session',
98 user = 'Session.username',
101 user = 'Session.username',
99 keyfile = 'Session.keyfile',
102 keyfile = 'Session.keyfile',
100 )
103 )
101
104
102 session_flags = {
105 session_flags = {
103 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
106 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
104 'keyfile' : '' }},
107 'keyfile' : '' }},
105 """Use HMAC digests for authentication of messages.
108 """Use HMAC digests for authentication of messages.
106 Setting this flag will generate a new UUID to use as the HMAC key.
109 Setting this flag will generate a new UUID to use as the HMAC key.
107 """),
110 """),
108 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
111 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
109 """Don't authenticate messages."""),
112 """Don't authenticate messages."""),
110 }
113 }
111
114
112 def default_secure(cfg):
115 def default_secure(cfg):
113 """Set the default behavior for a config environment to be secure.
116 """Set the default behavior for a config environment to be secure.
114
117
115 If Session.key/keyfile have not been set, set Session.key to
118 If Session.key/keyfile have not been set, set Session.key to
116 a new random UUID.
119 a new random UUID.
117 """
120 """
118
121
119 if 'Session' in cfg:
122 if 'Session' in cfg:
120 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
123 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
121 return
124 return
122 # key/keyfile not specified, generate new UUID:
125 # key/keyfile not specified, generate new UUID:
123 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
126 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
124
127
125
128
126 #-----------------------------------------------------------------------------
129 #-----------------------------------------------------------------------------
127 # Classes
130 # Classes
128 #-----------------------------------------------------------------------------
131 #-----------------------------------------------------------------------------
129
132
130 class SessionFactory(LoggingConfigurable):
133 class SessionFactory(LoggingConfigurable):
131 """The Base class for configurables that have a Session, Context, logger,
134 """The Base class for configurables that have a Session, Context, logger,
132 and IOLoop.
135 and IOLoop.
133 """
136 """
134
137
135 logname = Unicode('')
138 logname = Unicode('')
136 def _logname_changed(self, name, old, new):
139 def _logname_changed(self, name, old, new):
137 self.log = logging.getLogger(new)
140 self.log = logging.getLogger(new)
138
141
139 # not configurable:
142 # not configurable:
140 context = Instance('zmq.Context')
143 context = Instance('zmq.Context')
141 def _context_default(self):
144 def _context_default(self):
142 return zmq.Context.instance()
145 return zmq.Context.instance()
143
146
144 session = Instance('IPython.kernel.zmq.session.Session')
147 session = Instance('IPython.kernel.zmq.session.Session')
145
148
146 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
149 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
147 def _loop_default(self):
150 def _loop_default(self):
148 return IOLoop.instance()
151 return IOLoop.instance()
149
152
150 def __init__(self, **kwargs):
153 def __init__(self, **kwargs):
151 super(SessionFactory, self).__init__(**kwargs)
154 super(SessionFactory, self).__init__(**kwargs)
152
155
153 if self.session is None:
156 if self.session is None:
154 # construct the session
157 # construct the session
155 self.session = Session(**kwargs)
158 self.session = Session(**kwargs)
156
159
157
160
158 class Message(object):
161 class Message(object):
159 """A simple message object that maps dict keys to attributes.
162 """A simple message object that maps dict keys to attributes.
160
163
161 A Message can be created from a dict and a dict from a Message instance
164 A Message can be created from a dict and a dict from a Message instance
162 simply by calling dict(msg_obj)."""
165 simply by calling dict(msg_obj)."""
163
166
164 def __init__(self, msg_dict):
167 def __init__(self, msg_dict):
165 dct = self.__dict__
168 dct = self.__dict__
166 for k, v in dict(msg_dict).iteritems():
169 for k, v in dict(msg_dict).iteritems():
167 if isinstance(v, dict):
170 if isinstance(v, dict):
168 v = Message(v)
171 v = Message(v)
169 dct[k] = v
172 dct[k] = v
170
173
171 # Having this iterator lets dict(msg_obj) work out of the box.
174 # Having this iterator lets dict(msg_obj) work out of the box.
172 def __iter__(self):
175 def __iter__(self):
173 return iter(self.__dict__.iteritems())
176 return iter(self.__dict__.iteritems())
174
177
175 def __repr__(self):
178 def __repr__(self):
176 return repr(self.__dict__)
179 return repr(self.__dict__)
177
180
178 def __str__(self):
181 def __str__(self):
179 return pprint.pformat(self.__dict__)
182 return pprint.pformat(self.__dict__)
180
183
181 def __contains__(self, k):
184 def __contains__(self, k):
182 return k in self.__dict__
185 return k in self.__dict__
183
186
184 def __getitem__(self, k):
187 def __getitem__(self, k):
185 return self.__dict__[k]
188 return self.__dict__[k]
186
189
187
190
188 def msg_header(msg_id, msg_type, username, session):
191 def msg_header(msg_id, msg_type, username, session):
189 date = datetime.now()
192 date = datetime.now()
190 return locals()
193 return locals()
191
194
192 def extract_header(msg_or_header):
195 def extract_header(msg_or_header):
193 """Given a message or header, return the header."""
196 """Given a message or header, return the header."""
194 if not msg_or_header:
197 if not msg_or_header:
195 return {}
198 return {}
196 try:
199 try:
197 # See if msg_or_header is the entire message.
200 # See if msg_or_header is the entire message.
198 h = msg_or_header['header']
201 h = msg_or_header['header']
199 except KeyError:
202 except KeyError:
200 try:
203 try:
201 # See if msg_or_header is just the header
204 # See if msg_or_header is just the header
202 h = msg_or_header['msg_id']
205 h = msg_or_header['msg_id']
203 except KeyError:
206 except KeyError:
204 raise
207 raise
205 else:
208 else:
206 h = msg_or_header
209 h = msg_or_header
207 if not isinstance(h, dict):
210 if not isinstance(h, dict):
208 h = dict(h)
211 h = dict(h)
209 return h
212 return h
210
213
211 class Session(Configurable):
214 class Session(Configurable):
212 """Object for handling serialization and sending of messages.
215 """Object for handling serialization and sending of messages.
213
216
214 The Session object handles building messages and sending them
217 The Session object handles building messages and sending them
215 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
218 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
216 other over the network via Session objects, and only need to work with the
219 other over the network via Session objects, and only need to work with the
217 dict-based IPython message spec. The Session will handle
220 dict-based IPython message spec. The Session will handle
218 serialization/deserialization, security, and metadata.
221 serialization/deserialization, security, and metadata.
219
222
220 Sessions support configurable serialiization via packer/unpacker traits,
223 Sessions support configurable serialiization via packer/unpacker traits,
221 and signing with HMAC digests via the key/keyfile traits.
224 and signing with HMAC digests via the key/keyfile traits.
222
225
223 Parameters
226 Parameters
224 ----------
227 ----------
225
228
226 debug : bool
229 debug : bool
227 whether to trigger extra debugging statements
230 whether to trigger extra debugging statements
228 packer/unpacker : str : 'json', 'pickle' or import_string
231 packer/unpacker : str : 'json', 'pickle' or import_string
229 importstrings for methods to serialize message parts. If just
232 importstrings for methods to serialize message parts. If just
230 'json' or 'pickle', predefined JSON and pickle packers will be used.
233 'json' or 'pickle', predefined JSON and pickle packers will be used.
231 Otherwise, the entire importstring must be used.
234 Otherwise, the entire importstring must be used.
232
235
233 The functions must accept at least valid JSON input, and output *bytes*.
236 The functions must accept at least valid JSON input, and output *bytes*.
234
237
235 For example, to use msgpack:
238 For example, to use msgpack:
236 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
239 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
237 pack/unpack : callables
240 pack/unpack : callables
238 You can also set the pack/unpack callables for serialization directly.
241 You can also set the pack/unpack callables for serialization directly.
239 session : bytes
242 session : bytes
240 the ID of this Session object. The default is to generate a new UUID.
243 the ID of this Session object. The default is to generate a new UUID.
241 username : unicode
244 username : unicode
242 username added to message headers. The default is to ask the OS.
245 username added to message headers. The default is to ask the OS.
243 key : bytes
246 key : bytes
244 The key used to initialize an HMAC signature. If unset, messages
247 The key used to initialize an HMAC signature. If unset, messages
245 will not be signed or checked.
248 will not be signed or checked.
246 keyfile : filepath
249 keyfile : filepath
247 The file containing a key. If this is set, `key` will be initialized
250 The file containing a key. If this is set, `key` will be initialized
248 to the contents of the file.
251 to the contents of the file.
249
252
250 """
253 """
251
254
252 debug=Bool(False, config=True, help="""Debug output in the Session""")
255 debug=Bool(False, config=True, help="""Debug output in the Session""")
253
256
254 packer = DottedObjectName('json',config=True,
257 packer = DottedObjectName('json',config=True,
255 help="""The name of the packer for serializing messages.
258 help="""The name of the packer for serializing messages.
256 Should be one of 'json', 'pickle', or an import name
259 Should be one of 'json', 'pickle', or an import name
257 for a custom callable serializer.""")
260 for a custom callable serializer.""")
258 def _packer_changed(self, name, old, new):
261 def _packer_changed(self, name, old, new):
259 if new.lower() == 'json':
262 if new.lower() == 'json':
260 self.pack = json_packer
263 self.pack = json_packer
261 self.unpack = json_unpacker
264 self.unpack = json_unpacker
262 self.unpacker = new
265 self.unpacker = new
263 elif new.lower() == 'pickle':
266 elif new.lower() == 'pickle':
264 self.pack = pickle_packer
267 self.pack = pickle_packer
265 self.unpack = pickle_unpacker
268 self.unpack = pickle_unpacker
266 self.unpacker = new
269 self.unpacker = new
267 else:
270 else:
268 self.pack = import_item(str(new))
271 self.pack = import_item(str(new))
269
272
270 unpacker = DottedObjectName('json', config=True,
273 unpacker = DottedObjectName('json', config=True,
271 help="""The name of the unpacker for unserializing messages.
274 help="""The name of the unpacker for unserializing messages.
272 Only used with custom functions for `packer`.""")
275 Only used with custom functions for `packer`.""")
273 def _unpacker_changed(self, name, old, new):
276 def _unpacker_changed(self, name, old, new):
274 if new.lower() == 'json':
277 if new.lower() == 'json':
275 self.pack = json_packer
278 self.pack = json_packer
276 self.unpack = json_unpacker
279 self.unpack = json_unpacker
277 self.packer = new
280 self.packer = new
278 elif new.lower() == 'pickle':
281 elif new.lower() == 'pickle':
279 self.pack = pickle_packer
282 self.pack = pickle_packer
280 self.unpack = pickle_unpacker
283 self.unpack = pickle_unpacker
281 self.packer = new
284 self.packer = new
282 else:
285 else:
283 self.unpack = import_item(str(new))
286 self.unpack = import_item(str(new))
284
287
285 session = CUnicode(u'', config=True,
288 session = CUnicode(u'', config=True,
286 help="""The UUID identifying this session.""")
289 help="""The UUID identifying this session.""")
287 def _session_default(self):
290 def _session_default(self):
288 u = unicode(uuid.uuid4())
291 u = unicode(uuid.uuid4())
289 self.bsession = u.encode('ascii')
292 self.bsession = u.encode('ascii')
290 return u
293 return u
291
294
292 def _session_changed(self, name, old, new):
295 def _session_changed(self, name, old, new):
293 self.bsession = self.session.encode('ascii')
296 self.bsession = self.session.encode('ascii')
294
297
295 # bsession is the session as bytes
298 # bsession is the session as bytes
296 bsession = CBytes(b'')
299 bsession = CBytes(b'')
297
300
298 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
301 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
299 help="""Username for the Session. Default is your system username.""",
302 help="""Username for the Session. Default is your system username.""",
300 config=True)
303 config=True)
301
304
302 metadata = Dict({}, config=True,
305 metadata = Dict({}, config=True,
303 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
306 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
304
307
305 # message signature related traits:
308 # message signature related traits:
306
309
307 key = CBytes(b'', config=True,
310 key = CBytes(b'', config=True,
308 help="""execution key, for extra authentication.""")
311 help="""execution key, for extra authentication.""")
309 def _key_changed(self, name, old, new):
312 def _key_changed(self, name, old, new):
310 if new:
313 if new:
311 self.auth = hmac.HMAC(new)
314 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
312 else:
315 else:
313 self.auth = None
316 self.auth = None
314
317
318 signature_scheme = Unicode('hmac-sha256', config=True,
319 help="""The digest scheme used to construct the message signatures.
320 Must have the form 'hmac-HASH'.""")
321 def _signature_scheme_changed(self, name, old, new):
322 if not new.startswith('hmac-'):
323 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
324 hash_name = new.split('-', 1)[1]
325 try:
326 self.digest_mod = getattr(hashlib, hash_name)
327 except AttributeError:
328 raise TraitError("hashlib has no such attribute: %s" % hash_name)
329
330 digest_mod = Any()
331 def _digest_mod_default(self):
332 return hashlib.sha256
333
315 auth = Instance(hmac.HMAC)
334 auth = Instance(hmac.HMAC)
316
335
317 digest_history = Set()
336 digest_history = Set()
318 digest_history_size = Integer(2**16, config=True,
337 digest_history_size = Integer(2**16, config=True,
319 help="""The maximum number of digests to remember.
338 help="""The maximum number of digests to remember.
320
339
321 The digest history will be culled when it exceeds this value.
340 The digest history will be culled when it exceeds this value.
322 """
341 """
323 )
342 )
324
343
325 keyfile = Unicode('', config=True,
344 keyfile = Unicode('', config=True,
326 help="""path to file containing execution key.""")
345 help="""path to file containing execution key.""")
327 def _keyfile_changed(self, name, old, new):
346 def _keyfile_changed(self, name, old, new):
328 with open(new, 'rb') as f:
347 with open(new, 'rb') as f:
329 self.key = f.read().strip()
348 self.key = f.read().strip()
330
349
331 # for protecting against sends from forks
350 # for protecting against sends from forks
332 pid = Integer()
351 pid = Integer()
333
352
334 # serialization traits:
353 # serialization traits:
335
354
336 pack = Any(default_packer) # the actual packer function
355 pack = Any(default_packer) # the actual packer function
337 def _pack_changed(self, name, old, new):
356 def _pack_changed(self, name, old, new):
338 if not callable(new):
357 if not callable(new):
339 raise TypeError("packer must be callable, not %s"%type(new))
358 raise TypeError("packer must be callable, not %s"%type(new))
340
359
341 unpack = Any(default_unpacker) # the actual packer function
360 unpack = Any(default_unpacker) # the actual packer function
342 def _unpack_changed(self, name, old, new):
361 def _unpack_changed(self, name, old, new):
343 # unpacker is not checked - it is assumed to be
362 # unpacker is not checked - it is assumed to be
344 if not callable(new):
363 if not callable(new):
345 raise TypeError("unpacker must be callable, not %s"%type(new))
364 raise TypeError("unpacker must be callable, not %s"%type(new))
346
365
347 # thresholds:
366 # thresholds:
348 copy_threshold = Integer(2**16, config=True,
367 copy_threshold = Integer(2**16, config=True,
349 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
368 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
350 buffer_threshold = Integer(MAX_BYTES, config=True,
369 buffer_threshold = Integer(MAX_BYTES, config=True,
351 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
370 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
352 item_threshold = Integer(MAX_ITEMS, config=True,
371 item_threshold = Integer(MAX_ITEMS, config=True,
353 help="""The maximum number of items for a container to be introspected for custom serialization.
372 help="""The maximum number of items for a container to be introspected for custom serialization.
354 Containers larger than this are pickled outright.
373 Containers larger than this are pickled outright.
355 """
374 """
356 )
375 )
357
376
358
377
359 def __init__(self, **kwargs):
378 def __init__(self, **kwargs):
360 """create a Session object
379 """create a Session object
361
380
362 Parameters
381 Parameters
363 ----------
382 ----------
364
383
365 debug : bool
384 debug : bool
366 whether to trigger extra debugging statements
385 whether to trigger extra debugging statements
367 packer/unpacker : str : 'json', 'pickle' or import_string
386 packer/unpacker : str : 'json', 'pickle' or import_string
368 importstrings for methods to serialize message parts. If just
387 importstrings for methods to serialize message parts. If just
369 'json' or 'pickle', predefined JSON and pickle packers will be used.
388 'json' or 'pickle', predefined JSON and pickle packers will be used.
370 Otherwise, the entire importstring must be used.
389 Otherwise, the entire importstring must be used.
371
390
372 The functions must accept at least valid JSON input, and output
391 The functions must accept at least valid JSON input, and output
373 *bytes*.
392 *bytes*.
374
393
375 For example, to use msgpack:
394 For example, to use msgpack:
376 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
395 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
377 pack/unpack : callables
396 pack/unpack : callables
378 You can also set the pack/unpack callables for serialization
397 You can also set the pack/unpack callables for serialization
379 directly.
398 directly.
380 session : unicode (must be ascii)
399 session : unicode (must be ascii)
381 the ID of this Session object. The default is to generate a new
400 the ID of this Session object. The default is to generate a new
382 UUID.
401 UUID.
383 bsession : bytes
402 bsession : bytes
384 The session as bytes
403 The session as bytes
385 username : unicode
404 username : unicode
386 username added to message headers. The default is to ask the OS.
405 username added to message headers. The default is to ask the OS.
387 key : bytes
406 key : bytes
388 The key used to initialize an HMAC signature. If unset, messages
407 The key used to initialize an HMAC signature. If unset, messages
389 will not be signed or checked.
408 will not be signed or checked.
409 signature_scheme : str
410 The message digest scheme. Currently must be of the form 'hmac-HASH',
411 where 'HASH' is a hashing function available in Python's hashlib.
412 The default is 'hmac-sha256'.
413 This is ignored if 'key' is empty.
390 keyfile : filepath
414 keyfile : filepath
391 The file containing a key. If this is set, `key` will be
415 The file containing a key. If this is set, `key` will be
392 initialized to the contents of the file.
416 initialized to the contents of the file.
393 """
417 """
394 super(Session, self).__init__(**kwargs)
418 super(Session, self).__init__(**kwargs)
395 self._check_packers()
419 self._check_packers()
396 self.none = self.pack({})
420 self.none = self.pack({})
397 # ensure self._session_default() if necessary, so bsession is defined:
421 # ensure self._session_default() if necessary, so bsession is defined:
398 self.session
422 self.session
399 self.pid = os.getpid()
423 self.pid = os.getpid()
400
424
401 @property
425 @property
402 def msg_id(self):
426 def msg_id(self):
403 """always return new uuid"""
427 """always return new uuid"""
404 return str(uuid.uuid4())
428 return str(uuid.uuid4())
405
429
406 def _check_packers(self):
430 def _check_packers(self):
407 """check packers for binary data and datetime support."""
431 """check packers for binary data and datetime support."""
408 pack = self.pack
432 pack = self.pack
409 unpack = self.unpack
433 unpack = self.unpack
410
434
411 # check simple serialization
435 # check simple serialization
412 msg = dict(a=[1,'hi'])
436 msg = dict(a=[1,'hi'])
413 try:
437 try:
414 packed = pack(msg)
438 packed = pack(msg)
415 except Exception:
439 except Exception:
416 raise ValueError("packer could not serialize a simple message")
440 raise ValueError("packer could not serialize a simple message")
417
441
418 # ensure packed message is bytes
442 # ensure packed message is bytes
419 if not isinstance(packed, bytes):
443 if not isinstance(packed, bytes):
420 raise ValueError("message packed to %r, but bytes are required"%type(packed))
444 raise ValueError("message packed to %r, but bytes are required"%type(packed))
421
445
422 # check that unpack is pack's inverse
446 # check that unpack is pack's inverse
423 try:
447 try:
424 unpacked = unpack(packed)
448 unpacked = unpack(packed)
425 except Exception:
449 except Exception:
426 raise ValueError("unpacker could not handle the packer's output")
450 raise ValueError("unpacker could not handle the packer's output")
427
451
428 # check datetime support
452 # check datetime support
429 msg = dict(t=datetime.now())
453 msg = dict(t=datetime.now())
430 try:
454 try:
431 unpacked = unpack(pack(msg))
455 unpacked = unpack(pack(msg))
432 except Exception:
456 except Exception:
433 self.pack = lambda o: pack(squash_dates(o))
457 self.pack = lambda o: pack(squash_dates(o))
434 self.unpack = lambda s: extract_dates(unpack(s))
458 self.unpack = lambda s: extract_dates(unpack(s))
435
459
436 def msg_header(self, msg_type):
460 def msg_header(self, msg_type):
437 return msg_header(self.msg_id, msg_type, self.username, self.session)
461 return msg_header(self.msg_id, msg_type, self.username, self.session)
438
462
439 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
463 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
440 """Return the nested message dict.
464 """Return the nested message dict.
441
465
442 This format is different from what is sent over the wire. The
466 This format is different from what is sent over the wire. The
443 serialize/unserialize methods converts this nested message dict to the wire
467 serialize/unserialize methods converts this nested message dict to the wire
444 format, which is a list of message parts.
468 format, which is a list of message parts.
445 """
469 """
446 msg = {}
470 msg = {}
447 header = self.msg_header(msg_type) if header is None else header
471 header = self.msg_header(msg_type) if header is None else header
448 msg['header'] = header
472 msg['header'] = header
449 msg['msg_id'] = header['msg_id']
473 msg['msg_id'] = header['msg_id']
450 msg['msg_type'] = header['msg_type']
474 msg['msg_type'] = header['msg_type']
451 msg['parent_header'] = {} if parent is None else extract_header(parent)
475 msg['parent_header'] = {} if parent is None else extract_header(parent)
452 msg['content'] = {} if content is None else content
476 msg['content'] = {} if content is None else content
453 msg['metadata'] = self.metadata.copy()
477 msg['metadata'] = self.metadata.copy()
454 if metadata is not None:
478 if metadata is not None:
455 msg['metadata'].update(metadata)
479 msg['metadata'].update(metadata)
456 return msg
480 return msg
457
481
458 def sign(self, msg_list):
482 def sign(self, msg_list):
459 """Sign a message with HMAC digest. If no auth, return b''.
483 """Sign a message with HMAC digest. If no auth, return b''.
460
484
461 Parameters
485 Parameters
462 ----------
486 ----------
463 msg_list : list
487 msg_list : list
464 The [p_header,p_parent,p_content] part of the message list.
488 The [p_header,p_parent,p_content] part of the message list.
465 """
489 """
466 if self.auth is None:
490 if self.auth is None:
467 return b''
491 return b''
468 h = self.auth.copy()
492 h = self.auth.copy()
469 for m in msg_list:
493 for m in msg_list:
470 h.update(m)
494 h.update(m)
471 return str_to_bytes(h.hexdigest())
495 return str_to_bytes(h.hexdigest())
472
496
473 def serialize(self, msg, ident=None):
497 def serialize(self, msg, ident=None):
474 """Serialize the message components to bytes.
498 """Serialize the message components to bytes.
475
499
476 This is roughly the inverse of unserialize. The serialize/unserialize
500 This is roughly the inverse of unserialize. The serialize/unserialize
477 methods work with full message lists, whereas pack/unpack work with
501 methods work with full message lists, whereas pack/unpack work with
478 the individual message parts in the message list.
502 the individual message parts in the message list.
479
503
480 Parameters
504 Parameters
481 ----------
505 ----------
482 msg : dict or Message
506 msg : dict or Message
483 The nexted message dict as returned by the self.msg method.
507 The nexted message dict as returned by the self.msg method.
484
508
485 Returns
509 Returns
486 -------
510 -------
487 msg_list : list
511 msg_list : list
488 The list of bytes objects to be sent with the format:
512 The list of bytes objects to be sent with the format:
489 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
513 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
490 buffer1,buffer2,...]. In this list, the p_* entities are
514 buffer1,buffer2,...]. In this list, the p_* entities are
491 the packed or serialized versions, so if JSON is used, these
515 the packed or serialized versions, so if JSON is used, these
492 are utf8 encoded JSON strings.
516 are utf8 encoded JSON strings.
493 """
517 """
494 content = msg.get('content', {})
518 content = msg.get('content', {})
495 if content is None:
519 if content is None:
496 content = self.none
520 content = self.none
497 elif isinstance(content, dict):
521 elif isinstance(content, dict):
498 content = self.pack(content)
522 content = self.pack(content)
499 elif isinstance(content, bytes):
523 elif isinstance(content, bytes):
500 # content is already packed, as in a relayed message
524 # content is already packed, as in a relayed message
501 pass
525 pass
502 elif isinstance(content, unicode):
526 elif isinstance(content, unicode):
503 # should be bytes, but JSON often spits out unicode
527 # should be bytes, but JSON often spits out unicode
504 content = content.encode('utf8')
528 content = content.encode('utf8')
505 else:
529 else:
506 raise TypeError("Content incorrect type: %s"%type(content))
530 raise TypeError("Content incorrect type: %s"%type(content))
507
531
508 real_message = [self.pack(msg['header']),
532 real_message = [self.pack(msg['header']),
509 self.pack(msg['parent_header']),
533 self.pack(msg['parent_header']),
510 self.pack(msg['metadata']),
534 self.pack(msg['metadata']),
511 content,
535 content,
512 ]
536 ]
513
537
514 to_send = []
538 to_send = []
515
539
516 if isinstance(ident, list):
540 if isinstance(ident, list):
517 # accept list of idents
541 # accept list of idents
518 to_send.extend(ident)
542 to_send.extend(ident)
519 elif ident is not None:
543 elif ident is not None:
520 to_send.append(ident)
544 to_send.append(ident)
521 to_send.append(DELIM)
545 to_send.append(DELIM)
522
546
523 signature = self.sign(real_message)
547 signature = self.sign(real_message)
524 to_send.append(signature)
548 to_send.append(signature)
525
549
526 to_send.extend(real_message)
550 to_send.extend(real_message)
527
551
528 return to_send
552 return to_send
529
553
530 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
554 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
531 buffers=None, track=False, header=None, metadata=None):
555 buffers=None, track=False, header=None, metadata=None):
532 """Build and send a message via stream or socket.
556 """Build and send a message via stream or socket.
533
557
534 The message format used by this function internally is as follows:
558 The message format used by this function internally is as follows:
535
559
536 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
560 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
537 buffer1,buffer2,...]
561 buffer1,buffer2,...]
538
562
539 The serialize/unserialize methods convert the nested message dict into this
563 The serialize/unserialize methods convert the nested message dict into this
540 format.
564 format.
541
565
542 Parameters
566 Parameters
543 ----------
567 ----------
544
568
545 stream : zmq.Socket or ZMQStream
569 stream : zmq.Socket or ZMQStream
546 The socket-like object used to send the data.
570 The socket-like object used to send the data.
547 msg_or_type : str or Message/dict
571 msg_or_type : str or Message/dict
548 Normally, msg_or_type will be a msg_type unless a message is being
572 Normally, msg_or_type will be a msg_type unless a message is being
549 sent more than once. If a header is supplied, this can be set to
573 sent more than once. If a header is supplied, this can be set to
550 None and the msg_type will be pulled from the header.
574 None and the msg_type will be pulled from the header.
551
575
552 content : dict or None
576 content : dict or None
553 The content of the message (ignored if msg_or_type is a message).
577 The content of the message (ignored if msg_or_type is a message).
554 header : dict or None
578 header : dict or None
555 The header dict for the message (ignored if msg_to_type is a message).
579 The header dict for the message (ignored if msg_to_type is a message).
556 parent : Message or dict or None
580 parent : Message or dict or None
557 The parent or parent header describing the parent of this message
581 The parent or parent header describing the parent of this message
558 (ignored if msg_or_type is a message).
582 (ignored if msg_or_type is a message).
559 ident : bytes or list of bytes
583 ident : bytes or list of bytes
560 The zmq.IDENTITY routing path.
584 The zmq.IDENTITY routing path.
561 metadata : dict or None
585 metadata : dict or None
562 The metadata describing the message
586 The metadata describing the message
563 buffers : list or None
587 buffers : list or None
564 The already-serialized buffers to be appended to the message.
588 The already-serialized buffers to be appended to the message.
565 track : bool
589 track : bool
566 Whether to track. Only for use with Sockets, because ZMQStream
590 Whether to track. Only for use with Sockets, because ZMQStream
567 objects cannot track messages.
591 objects cannot track messages.
568
592
569
593
570 Returns
594 Returns
571 -------
595 -------
572 msg : dict
596 msg : dict
573 The constructed message.
597 The constructed message.
574 """
598 """
575 if not isinstance(stream, zmq.Socket):
599 if not isinstance(stream, zmq.Socket):
576 # ZMQStreams and dummy sockets do not support tracking.
600 # ZMQStreams and dummy sockets do not support tracking.
577 track = False
601 track = False
578
602
579 if isinstance(msg_or_type, (Message, dict)):
603 if isinstance(msg_or_type, (Message, dict)):
580 # We got a Message or message dict, not a msg_type so don't
604 # We got a Message or message dict, not a msg_type so don't
581 # build a new Message.
605 # build a new Message.
582 msg = msg_or_type
606 msg = msg_or_type
583 else:
607 else:
584 msg = self.msg(msg_or_type, content=content, parent=parent,
608 msg = self.msg(msg_or_type, content=content, parent=parent,
585 header=header, metadata=metadata)
609 header=header, metadata=metadata)
586 if not os.getpid() == self.pid:
610 if not os.getpid() == self.pid:
587 io.rprint("WARNING: attempted to send message from fork")
611 io.rprint("WARNING: attempted to send message from fork")
588 io.rprint(msg)
612 io.rprint(msg)
589 return
613 return
590 buffers = [] if buffers is None else buffers
614 buffers = [] if buffers is None else buffers
591 to_send = self.serialize(msg, ident)
615 to_send = self.serialize(msg, ident)
592 to_send.extend(buffers)
616 to_send.extend(buffers)
593 longest = max([ len(s) for s in to_send ])
617 longest = max([ len(s) for s in to_send ])
594 copy = (longest < self.copy_threshold)
618 copy = (longest < self.copy_threshold)
595
619
596 if buffers and track and not copy:
620 if buffers and track and not copy:
597 # only really track when we are doing zero-copy buffers
621 # only really track when we are doing zero-copy buffers
598 tracker = stream.send_multipart(to_send, copy=False, track=True)
622 tracker = stream.send_multipart(to_send, copy=False, track=True)
599 else:
623 else:
600 # use dummy tracker, which will be done immediately
624 # use dummy tracker, which will be done immediately
601 tracker = DONE
625 tracker = DONE
602 stream.send_multipart(to_send, copy=copy)
626 stream.send_multipart(to_send, copy=copy)
603
627
604 if self.debug:
628 if self.debug:
605 pprint.pprint(msg)
629 pprint.pprint(msg)
606 pprint.pprint(to_send)
630 pprint.pprint(to_send)
607 pprint.pprint(buffers)
631 pprint.pprint(buffers)
608
632
609 msg['tracker'] = tracker
633 msg['tracker'] = tracker
610
634
611 return msg
635 return msg
612
636
613 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
637 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
614 """Send a raw message via ident path.
638 """Send a raw message via ident path.
615
639
616 This method is used to send a already serialized message.
640 This method is used to send a already serialized message.
617
641
618 Parameters
642 Parameters
619 ----------
643 ----------
620 stream : ZMQStream or Socket
644 stream : ZMQStream or Socket
621 The ZMQ stream or socket to use for sending the message.
645 The ZMQ stream or socket to use for sending the message.
622 msg_list : list
646 msg_list : list
623 The serialized list of messages to send. This only includes the
647 The serialized list of messages to send. This only includes the
624 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
648 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
625 the message.
649 the message.
626 ident : ident or list
650 ident : ident or list
627 A single ident or a list of idents to use in sending.
651 A single ident or a list of idents to use in sending.
628 """
652 """
629 to_send = []
653 to_send = []
630 if isinstance(ident, bytes):
654 if isinstance(ident, bytes):
631 ident = [ident]
655 ident = [ident]
632 if ident is not None:
656 if ident is not None:
633 to_send.extend(ident)
657 to_send.extend(ident)
634
658
635 to_send.append(DELIM)
659 to_send.append(DELIM)
636 to_send.append(self.sign(msg_list))
660 to_send.append(self.sign(msg_list))
637 to_send.extend(msg_list)
661 to_send.extend(msg_list)
638 stream.send_multipart(msg_list, flags, copy=copy)
662 stream.send_multipart(msg_list, flags, copy=copy)
639
663
640 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
664 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
641 """Receive and unpack a message.
665 """Receive and unpack a message.
642
666
643 Parameters
667 Parameters
644 ----------
668 ----------
645 socket : ZMQStream or Socket
669 socket : ZMQStream or Socket
646 The socket or stream to use in receiving.
670 The socket or stream to use in receiving.
647
671
648 Returns
672 Returns
649 -------
673 -------
650 [idents], msg
674 [idents], msg
651 [idents] is a list of idents and msg is a nested message dict of
675 [idents] is a list of idents and msg is a nested message dict of
652 same format as self.msg returns.
676 same format as self.msg returns.
653 """
677 """
654 if isinstance(socket, ZMQStream):
678 if isinstance(socket, ZMQStream):
655 socket = socket.socket
679 socket = socket.socket
656 try:
680 try:
657 msg_list = socket.recv_multipart(mode, copy=copy)
681 msg_list = socket.recv_multipart(mode, copy=copy)
658 except zmq.ZMQError as e:
682 except zmq.ZMQError as e:
659 if e.errno == zmq.EAGAIN:
683 if e.errno == zmq.EAGAIN:
660 # We can convert EAGAIN to None as we know in this case
684 # We can convert EAGAIN to None as we know in this case
661 # recv_multipart won't return None.
685 # recv_multipart won't return None.
662 return None,None
686 return None,None
663 else:
687 else:
664 raise
688 raise
665 # split multipart message into identity list and message dict
689 # split multipart message into identity list and message dict
666 # invalid large messages can cause very expensive string comparisons
690 # invalid large messages can cause very expensive string comparisons
667 idents, msg_list = self.feed_identities(msg_list, copy)
691 idents, msg_list = self.feed_identities(msg_list, copy)
668 try:
692 try:
669 return idents, self.unserialize(msg_list, content=content, copy=copy)
693 return idents, self.unserialize(msg_list, content=content, copy=copy)
670 except Exception as e:
694 except Exception as e:
671 # TODO: handle it
695 # TODO: handle it
672 raise e
696 raise e
673
697
674 def feed_identities(self, msg_list, copy=True):
698 def feed_identities(self, msg_list, copy=True):
675 """Split the identities from the rest of the message.
699 """Split the identities from the rest of the message.
676
700
677 Feed until DELIM is reached, then return the prefix as idents and
701 Feed until DELIM is reached, then return the prefix as idents and
678 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
702 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
679 but that would be silly.
703 but that would be silly.
680
704
681 Parameters
705 Parameters
682 ----------
706 ----------
683 msg_list : a list of Message or bytes objects
707 msg_list : a list of Message or bytes objects
684 The message to be split.
708 The message to be split.
685 copy : bool
709 copy : bool
686 flag determining whether the arguments are bytes or Messages
710 flag determining whether the arguments are bytes or Messages
687
711
688 Returns
712 Returns
689 -------
713 -------
690 (idents, msg_list) : two lists
714 (idents, msg_list) : two lists
691 idents will always be a list of bytes, each of which is a ZMQ
715 idents will always be a list of bytes, each of which is a ZMQ
692 identity. msg_list will be a list of bytes or zmq.Messages of the
716 identity. msg_list will be a list of bytes or zmq.Messages of the
693 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
717 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
694 should be unpackable/unserializable via self.unserialize at this
718 should be unpackable/unserializable via self.unserialize at this
695 point.
719 point.
696 """
720 """
697 if copy:
721 if copy:
698 idx = msg_list.index(DELIM)
722 idx = msg_list.index(DELIM)
699 return msg_list[:idx], msg_list[idx+1:]
723 return msg_list[:idx], msg_list[idx+1:]
700 else:
724 else:
701 failed = True
725 failed = True
702 for idx,m in enumerate(msg_list):
726 for idx,m in enumerate(msg_list):
703 if m.bytes == DELIM:
727 if m.bytes == DELIM:
704 failed = False
728 failed = False
705 break
729 break
706 if failed:
730 if failed:
707 raise ValueError("DELIM not in msg_list")
731 raise ValueError("DELIM not in msg_list")
708 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
732 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
709 return [m.bytes for m in idents], msg_list
733 return [m.bytes for m in idents], msg_list
710
734
711 def _add_digest(self, signature):
735 def _add_digest(self, signature):
712 """add a digest to history to protect against replay attacks"""
736 """add a digest to history to protect against replay attacks"""
713 if self.digest_history_size == 0:
737 if self.digest_history_size == 0:
714 # no history, never add digests
738 # no history, never add digests
715 return
739 return
716
740
717 self.digest_history.add(signature)
741 self.digest_history.add(signature)
718 if len(self.digest_history) > self.digest_history_size:
742 if len(self.digest_history) > self.digest_history_size:
719 # threshold reached, cull 10%
743 # threshold reached, cull 10%
720 self._cull_digest_history()
744 self._cull_digest_history()
721
745
722 def _cull_digest_history(self):
746 def _cull_digest_history(self):
723 """cull the digest history
747 """cull the digest history
724
748
725 Removes a randomly selected 10% of the digest history
749 Removes a randomly selected 10% of the digest history
726 """
750 """
727 current = len(self.digest_history)
751 current = len(self.digest_history)
728 n_to_cull = max(int(current // 10), current - self.digest_history_size)
752 n_to_cull = max(int(current // 10), current - self.digest_history_size)
729 if n_to_cull >= current:
753 if n_to_cull >= current:
730 self.digest_history = set()
754 self.digest_history = set()
731 return
755 return
732 to_cull = random.sample(self.digest_history, n_to_cull)
756 to_cull = random.sample(self.digest_history, n_to_cull)
733 self.digest_history.difference_update(to_cull)
757 self.digest_history.difference_update(to_cull)
734
758
735 def unserialize(self, msg_list, content=True, copy=True):
759 def unserialize(self, msg_list, content=True, copy=True):
736 """Unserialize a msg_list to a nested message dict.
760 """Unserialize a msg_list to a nested message dict.
737
761
738 This is roughly the inverse of serialize. The serialize/unserialize
762 This is roughly the inverse of serialize. The serialize/unserialize
739 methods work with full message lists, whereas pack/unpack work with
763 methods work with full message lists, whereas pack/unpack work with
740 the individual message parts in the message list.
764 the individual message parts in the message list.
741
765
742 Parameters:
766 Parameters:
743 -----------
767 -----------
744 msg_list : list of bytes or Message objects
768 msg_list : list of bytes or Message objects
745 The list of message parts of the form [HMAC,p_header,p_parent,
769 The list of message parts of the form [HMAC,p_header,p_parent,
746 p_metadata,p_content,buffer1,buffer2,...].
770 p_metadata,p_content,buffer1,buffer2,...].
747 content : bool (True)
771 content : bool (True)
748 Whether to unpack the content dict (True), or leave it packed
772 Whether to unpack the content dict (True), or leave it packed
749 (False).
773 (False).
750 copy : bool (True)
774 copy : bool (True)
751 Whether to return the bytes (True), or the non-copying Message
775 Whether to return the bytes (True), or the non-copying Message
752 object in each place (False).
776 object in each place (False).
753
777
754 Returns
778 Returns
755 -------
779 -------
756 msg : dict
780 msg : dict
757 The nested message dict with top-level keys [header, parent_header,
781 The nested message dict with top-level keys [header, parent_header,
758 content, buffers].
782 content, buffers].
759 """
783 """
760 minlen = 5
784 minlen = 5
761 message = {}
785 message = {}
762 if not copy:
786 if not copy:
763 for i in range(minlen):
787 for i in range(minlen):
764 msg_list[i] = msg_list[i].bytes
788 msg_list[i] = msg_list[i].bytes
765 if self.auth is not None:
789 if self.auth is not None:
766 signature = msg_list[0]
790 signature = msg_list[0]
767 if not signature:
791 if not signature:
768 raise ValueError("Unsigned Message")
792 raise ValueError("Unsigned Message")
769 if signature in self.digest_history:
793 if signature in self.digest_history:
770 raise ValueError("Duplicate Signature: %r" % signature)
794 raise ValueError("Duplicate Signature: %r" % signature)
771 self._add_digest(signature)
795 self._add_digest(signature)
772 check = self.sign(msg_list[1:5])
796 check = self.sign(msg_list[1:5])
773 if not signature == check:
797 if not signature == check:
774 raise ValueError("Invalid Signature: %r" % signature)
798 raise ValueError("Invalid Signature: %r" % signature)
775 if not len(msg_list) >= minlen:
799 if not len(msg_list) >= minlen:
776 raise TypeError("malformed message, must have at least %i elements"%minlen)
800 raise TypeError("malformed message, must have at least %i elements"%minlen)
777 header = self.unpack(msg_list[1])
801 header = self.unpack(msg_list[1])
778 message['header'] = header
802 message['header'] = header
779 message['msg_id'] = header['msg_id']
803 message['msg_id'] = header['msg_id']
780 message['msg_type'] = header['msg_type']
804 message['msg_type'] = header['msg_type']
781 message['parent_header'] = self.unpack(msg_list[2])
805 message['parent_header'] = self.unpack(msg_list[2])
782 message['metadata'] = self.unpack(msg_list[3])
806 message['metadata'] = self.unpack(msg_list[3])
783 if content:
807 if content:
784 message['content'] = self.unpack(msg_list[4])
808 message['content'] = self.unpack(msg_list[4])
785 else:
809 else:
786 message['content'] = msg_list[4]
810 message['content'] = msg_list[4]
787
811
788 message['buffers'] = msg_list[5:]
812 message['buffers'] = msg_list[5:]
789 return message
813 return message
790
814
791 def test_msg2obj():
815 def test_msg2obj():
792 am = dict(x=1)
816 am = dict(x=1)
793 ao = Message(am)
817 ao = Message(am)
794 assert ao.x == am['x']
818 assert ao.x == am['x']
795
819
796 am['y'] = dict(z=1)
820 am['y'] = dict(z=1)
797 ao = Message(am)
821 ao = Message(am)
798 assert ao.y.z == am['y']['z']
822 assert ao.y.z == am['y']['z']
799
823
800 k1, k2 = 'y', 'z'
824 k1, k2 = 'y', 'z'
801 assert ao[k1][k2] == am[k1][k2]
825 assert ao[k1][k2] == am[k1][k2]
802
826
803 am2 = dict(ao)
827 am2 = dict(ao)
804 assert am['x'] == am2['x']
828 assert am['x'] == am2['x']
805 assert am['y']['z'] == am2['y']['z']
829 assert am['y']['z'] == am2['y']['z']
806
830
General Comments 0
You need to be logged in to leave comments. Login now