##// END OF EJS Templates
add signature_scheme to Session and connection files...
MinRK -
Show More
@@ -1,390 +1,392 b''
1 1 """ A minimal application base mixin for all ZMQ based IPython frontends.
2 2
3 3 This is not a complete console app, as subprocess will not be able to receive
4 4 input, there is no real readline support, among other limitations. This is a
5 5 refactoring of what used to be the IPython/qt/console/qtconsoleapp.py
6 6
7 7 Authors:
8 8
9 9 * Evan Patterson
10 10 * Min RK
11 11 * Erik Tollerud
12 12 * Fernando Perez
13 13 * Bussonnier Matthias
14 14 * Thomas Kluyver
15 15 * Paul Ivanov
16 16
17 17 """
18 18
19 19 #-----------------------------------------------------------------------------
20 20 # Imports
21 21 #-----------------------------------------------------------------------------
22 22
23 23 # stdlib imports
24 24 import atexit
25 25 import json
26 26 import os
27 27 import shutil
28 28 import signal
29 29 import sys
30 30 import uuid
31 31
32 32
33 33 # Local imports
34 34 from IPython.config.application import boolean_flag
35 35 from IPython.config.configurable import Configurable
36 36 from IPython.core.profiledir import ProfileDir
37 37 from IPython.kernel.blocking import BlockingKernelClient
38 38 from IPython.kernel import KernelManager
39 39 from IPython.kernel import tunnel_to_kernel, find_connection_file, swallow_argv
40 40 from IPython.utils.path import filefind
41 41 from IPython.utils.py3compat import str_to_bytes
42 42 from IPython.utils.traitlets import (
43 43 Dict, List, Unicode, CUnicode, Int, CBool, Any, CaselessStrEnum
44 44 )
45 45 from IPython.kernel.zmq.kernelapp import (
46 46 kernel_flags,
47 47 kernel_aliases,
48 48 IPKernelApp
49 49 )
50 50 from IPython.kernel.zmq.session import Session, default_secure
51 51 from IPython.kernel.zmq.zmqshell import ZMQInteractiveShell
52 52
53 53 #-----------------------------------------------------------------------------
54 54 # Network Constants
55 55 #-----------------------------------------------------------------------------
56 56
57 57 from IPython.utils.localinterfaces import LOCALHOST, LOCAL_IPS
58 58
59 59 #-----------------------------------------------------------------------------
60 60 # Globals
61 61 #-----------------------------------------------------------------------------
62 62
63 63
64 64 #-----------------------------------------------------------------------------
65 65 # Aliases and Flags
66 66 #-----------------------------------------------------------------------------
67 67
68 68 flags = dict(kernel_flags)
69 69
70 70 # the flags that are specific to the frontend
71 71 # these must be scrubbed before being passed to the kernel,
72 72 # or it will raise an error on unrecognized flags
73 73 app_flags = {
74 74 'existing' : ({'IPythonConsoleApp' : {'existing' : 'kernel*.json'}},
75 75 "Connect to an existing kernel. If no argument specified, guess most recent"),
76 76 }
77 77 app_flags.update(boolean_flag(
78 78 'confirm-exit', 'IPythonConsoleApp.confirm_exit',
79 79 """Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
80 80 to force a direct exit without any confirmation.
81 81 """,
82 82 """Don't prompt the user when exiting. This will terminate the kernel
83 83 if it is owned by the frontend, and leave it alive if it is external.
84 84 """
85 85 ))
86 86 flags.update(app_flags)
87 87
88 88 aliases = dict(kernel_aliases)
89 89
90 90 # also scrub aliases from the frontend
91 91 app_aliases = dict(
92 92 ip = 'KernelManager.ip',
93 93 transport = 'KernelManager.transport',
94 94 hb = 'IPythonConsoleApp.hb_port',
95 95 shell = 'IPythonConsoleApp.shell_port',
96 96 iopub = 'IPythonConsoleApp.iopub_port',
97 97 stdin = 'IPythonConsoleApp.stdin_port',
98 98 existing = 'IPythonConsoleApp.existing',
99 99 f = 'IPythonConsoleApp.connection_file',
100 100
101 101
102 102 ssh = 'IPythonConsoleApp.sshserver',
103 103 )
104 104 aliases.update(app_aliases)
105 105
106 106 #-----------------------------------------------------------------------------
107 107 # Classes
108 108 #-----------------------------------------------------------------------------
109 109
110 110 #-----------------------------------------------------------------------------
111 111 # IPythonConsole
112 112 #-----------------------------------------------------------------------------
113 113
114 114 classes = [IPKernelApp, ZMQInteractiveShell, KernelManager, ProfileDir, Session]
115 115
116 116 try:
117 117 from IPython.kernel.zmq.pylab.backend_inline import InlineBackend
118 118 except ImportError:
119 119 pass
120 120 else:
121 121 classes.append(InlineBackend)
122 122
123 123 class IPythonConsoleApp(Configurable):
124 124 name = 'ipython-console-mixin'
125 125
126 126 description = """
127 127 The IPython Mixin Console.
128 128
129 129 This class contains the common portions of console client (QtConsole,
130 130 ZMQ-based terminal console, etc). It is not a full console, in that
131 131 launched terminal subprocesses will not be able to accept input.
132 132
133 133 The Console using this mixing supports various extra features beyond
134 134 the single-process Terminal IPython shell, such as connecting to
135 135 existing kernel, via:
136 136
137 137 ipython <appname> --existing
138 138
139 139 as well as tunnel via SSH
140 140
141 141 """
142 142
143 143 classes = classes
144 144 flags = Dict(flags)
145 145 aliases = Dict(aliases)
146 146 kernel_manager_class = KernelManager
147 147 kernel_client_class = BlockingKernelClient
148 148
149 149 kernel_argv = List(Unicode)
150 150 # frontend flags&aliases to be stripped when building kernel_argv
151 151 frontend_flags = Any(app_flags)
152 152 frontend_aliases = Any(app_aliases)
153 153
154 154 # create requested profiles by default, if they don't exist:
155 155 auto_create = CBool(True)
156 156 # connection info:
157 157
158 158 sshserver = Unicode('', config=True,
159 159 help="""The SSH server to use to connect to the kernel.""")
160 160 sshkey = Unicode('', config=True,
161 161 help="""Path to the ssh key to use for logging in to the ssh server.""")
162 162
163 163 hb_port = Int(0, config=True,
164 164 help="set the heartbeat port [default: random]")
165 165 shell_port = Int(0, config=True,
166 166 help="set the shell (ROUTER) port [default: random]")
167 167 iopub_port = Int(0, config=True,
168 168 help="set the iopub (PUB) port [default: random]")
169 169 stdin_port = Int(0, config=True,
170 170 help="set the stdin (DEALER) port [default: random]")
171 171 connection_file = Unicode('', config=True,
172 172 help="""JSON file in which to store connection info [default: kernel-<pid>.json]
173 173
174 174 This file will contain the IP, ports, and authentication key needed to connect
175 175 clients to this kernel. By default, this file will be created in the security-dir
176 176 of the current profile, but can be specified by absolute path.
177 177 """)
178 178 def _connection_file_default(self):
179 179 return 'kernel-%i.json' % os.getpid()
180 180
181 181 existing = CUnicode('', config=True,
182 182 help="""Connect to an already running kernel""")
183 183
184 184 confirm_exit = CBool(True, config=True,
185 185 help="""
186 186 Set to display confirmation dialog on exit. You can always use 'exit' or 'quit',
187 187 to force a direct exit without any confirmation.""",
188 188 )
189 189
190 190
191 191 def build_kernel_argv(self, argv=None):
192 192 """build argv to be passed to kernel subprocess"""
193 193 if argv is None:
194 194 argv = sys.argv[1:]
195 195 self.kernel_argv = swallow_argv(argv, self.frontend_aliases, self.frontend_flags)
196 196 # kernel should inherit default config file from frontend
197 197 self.kernel_argv.append("--IPKernelApp.parent_appname='%s'" % self.name)
198 198
199 199 def init_connection_file(self):
200 200 """find the connection file, and load the info if found.
201 201
202 202 The current working directory and the current profile's security
203 203 directory will be searched for the file if it is not given by
204 204 absolute path.
205 205
206 206 When attempting to connect to an existing kernel and the `--existing`
207 207 argument does not match an existing file, it will be interpreted as a
208 208 fileglob, and the matching file in the current profile's security dir
209 209 with the latest access time will be used.
210 210
211 211 After this method is called, self.connection_file contains the *full path*
212 212 to the connection file, never just its name.
213 213 """
214 214 if self.existing:
215 215 try:
216 216 cf = find_connection_file(self.existing)
217 217 except Exception:
218 218 self.log.critical("Could not find existing kernel connection file %s", self.existing)
219 219 self.exit(1)
220 220 self.log.info("Connecting to existing kernel: %s" % cf)
221 221 self.connection_file = cf
222 222 else:
223 223 # not existing, check if we are going to write the file
224 224 # and ensure that self.connection_file is a full path, not just the shortname
225 225 try:
226 226 cf = find_connection_file(self.connection_file)
227 227 except Exception:
228 228 # file might not exist
229 229 if self.connection_file == os.path.basename(self.connection_file):
230 230 # just shortname, put it in security dir
231 231 cf = os.path.join(self.profile_dir.security_dir, self.connection_file)
232 232 else:
233 233 cf = self.connection_file
234 234 self.connection_file = cf
235 235
236 236 # should load_connection_file only be used for existing?
237 237 # as it is now, this allows reusing ports if an existing
238 238 # file is requested
239 239 try:
240 240 self.load_connection_file()
241 241 except Exception:
242 242 self.log.error("Failed to load connection file: %r", self.connection_file, exc_info=True)
243 243 self.exit(1)
244 244
245 245 def load_connection_file(self):
246 246 """load ip/port/hmac config from JSON connection file"""
247 247 # this is identical to IPKernelApp.load_connection_file
248 248 # perhaps it can be centralized somewhere?
249 249 try:
250 250 fname = filefind(self.connection_file, ['.', self.profile_dir.security_dir])
251 251 except IOError:
252 252 self.log.debug("Connection File not found: %s", self.connection_file)
253 253 return
254 254 self.log.debug(u"Loading connection file %s", fname)
255 255 with open(fname) as f:
256 256 cfg = json.load(f)
257 257
258 258 self.config.KernelManager.transport = cfg.get('transport', 'tcp')
259 259 self.config.KernelManager.ip = cfg.get('ip', LOCALHOST)
260 260
261 261 for channel in ('hb', 'shell', 'iopub', 'stdin'):
262 262 name = channel + '_port'
263 263 if getattr(self, name) == 0 and name in cfg:
264 264 # not overridden by config or cl_args
265 265 setattr(self, name, cfg[name])
266 266 if 'key' in cfg:
267 267 self.config.Session.key = str_to_bytes(cfg['key'])
268 if 'signature_scheme' in cfg:
269 self.config.Session.signature_scheme = cfg['signature_scheme']
268 270
269 271 def init_ssh(self):
270 272 """set up ssh tunnels, if needed."""
271 273 if not self.existing or (not self.sshserver and not self.sshkey):
272 274 return
273 275
274 276 self.load_connection_file()
275 277
276 278 transport = self.config.KernelManager.transport
277 279 ip = self.config.KernelManager.ip
278 280
279 281 if transport != 'tcp':
280 282 self.log.error("Can only use ssh tunnels with TCP sockets, not %s", transport)
281 283 sys.exit(-1)
282 284
283 285 if self.sshkey and not self.sshserver:
284 286 # specifying just the key implies that we are connecting directly
285 287 self.sshserver = ip
286 288 ip = LOCALHOST
287 289
288 290 # build connection dict for tunnels:
289 291 info = dict(ip=ip,
290 292 shell_port=self.shell_port,
291 293 iopub_port=self.iopub_port,
292 294 stdin_port=self.stdin_port,
293 295 hb_port=self.hb_port
294 296 )
295 297
296 298 self.log.info("Forwarding connections to %s via %s"%(ip, self.sshserver))
297 299
298 300 # tunnels return a new set of ports, which will be on localhost:
299 301 self.config.KernelManager.ip = LOCALHOST
300 302 try:
301 303 newports = tunnel_to_kernel(info, self.sshserver, self.sshkey)
302 304 except:
303 305 # even catch KeyboardInterrupt
304 306 self.log.error("Could not setup tunnels", exc_info=True)
305 307 self.exit(1)
306 308
307 309 self.shell_port, self.iopub_port, self.stdin_port, self.hb_port = newports
308 310
309 311 cf = self.connection_file
310 312 base,ext = os.path.splitext(cf)
311 313 base = os.path.basename(base)
312 314 self.connection_file = os.path.basename(base)+'-ssh'+ext
313 315 self.log.critical("To connect another client via this tunnel, use:")
314 316 self.log.critical("--existing %s" % self.connection_file)
315 317
316 318 def _new_connection_file(self):
317 319 cf = ''
318 320 while not cf:
319 321 # we don't need a 128b id to distinguish kernels, use more readable
320 322 # 48b node segment (12 hex chars). Users running more than 32k simultaneous
321 323 # kernels can subclass.
322 324 ident = str(uuid.uuid4()).split('-')[-1]
323 325 cf = os.path.join(self.profile_dir.security_dir, 'kernel-%s.json' % ident)
324 326 # only keep if it's actually new. Protect against unlikely collision
325 327 # in 48b random search space
326 328 cf = cf if not os.path.exists(cf) else ''
327 329 return cf
328 330
329 331 def init_kernel_manager(self):
330 332 # Don't let Qt or ZMQ swallow KeyboardInterupts.
331 333 if self.existing:
332 334 self.kernel_manager = None
333 335 return
334 336 signal.signal(signal.SIGINT, signal.SIG_DFL)
335 337
336 338 # Create a KernelManager and start a kernel.
337 339 self.kernel_manager = self.kernel_manager_class(
338 340 shell_port=self.shell_port,
339 341 iopub_port=self.iopub_port,
340 342 stdin_port=self.stdin_port,
341 343 hb_port=self.hb_port,
342 344 connection_file=self.connection_file,
343 345 parent=self,
344 346 )
345 347 self.kernel_manager.client_factory = self.kernel_client_class
346 348 self.kernel_manager.start_kernel(extra_arguments=self.kernel_argv)
347 349 atexit.register(self.kernel_manager.cleanup_ipc_files)
348 350
349 351 if self.sshserver:
350 352 # ssh, write new connection file
351 353 self.kernel_manager.write_connection_file()
352 354
353 355 # in case KM defaults / ssh writing changes things:
354 356 km = self.kernel_manager
355 357 self.shell_port=km.shell_port
356 358 self.iopub_port=km.iopub_port
357 359 self.stdin_port=km.stdin_port
358 360 self.hb_port=km.hb_port
359 361 self.connection_file = km.connection_file
360 362
361 363 atexit.register(self.kernel_manager.cleanup_connection_file)
362 364
363 365 def init_kernel_client(self):
364 366 if self.kernel_manager is not None:
365 367 self.kernel_client = self.kernel_manager.client()
366 368 else:
367 369 self.kernel_client = self.kernel_client_class(
368 370 shell_port=self.shell_port,
369 371 iopub_port=self.iopub_port,
370 372 stdin_port=self.stdin_port,
371 373 hb_port=self.hb_port,
372 374 connection_file=self.connection_file,
373 375 parent=self,
374 376 )
375 377
376 378 self.kernel_client.start_channels()
377 379
378 380
379 381
380 382 def initialize(self, argv=None):
381 383 """
382 384 Classes which mix this class in should call:
383 385 IPythonConsoleApp.initialize(self,argv)
384 386 """
385 387 self.init_connection_file()
386 388 default_secure(self.config)
387 389 self.init_ssh()
388 390 self.init_kernel_manager()
389 391 self.init_kernel_client()
390 392
@@ -1,540 +1,557 b''
1 1 """Utilities for connecting to kernels
2 2
3 3 Authors:
4 4
5 5 * Min Ragan-Kelley
6 6
7 7 """
8 8
9 9 #-----------------------------------------------------------------------------
10 10 # Copyright (C) 2013 The IPython Development Team
11 11 #
12 12 # Distributed under the terms of the BSD License. The full license is in
13 13 # the file COPYING, distributed as part of this software.
14 14 #-----------------------------------------------------------------------------
15 15
16 16 #-----------------------------------------------------------------------------
17 17 # Imports
18 18 #-----------------------------------------------------------------------------
19 19
20 20 from __future__ import absolute_import
21 21
22 22 import glob
23 23 import json
24 24 import os
25 25 import socket
26 26 import sys
27 27 from getpass import getpass
28 28 from subprocess import Popen, PIPE
29 29 import tempfile
30 30
31 31 import zmq
32 32
33 33 # external imports
34 34 from IPython.external.ssh import tunnel
35 35
36 36 # IPython imports
37 37 # from IPython.config import Configurable
38 38 from IPython.core.profiledir import ProfileDir
39 39 from IPython.utils.localinterfaces import LOCALHOST
40 40 from IPython.utils.path import filefind, get_ipython_dir
41 41 from IPython.utils.py3compat import str_to_bytes, bytes_to_str
42 42 from IPython.utils.traitlets import (
43 43 Bool, Integer, Unicode, CaselessStrEnum,
44 44 HasTraits,
45 45 )
46 46
47 47
48 48 #-----------------------------------------------------------------------------
49 49 # Working with Connection Files
50 50 #-----------------------------------------------------------------------------
51 51
52 52 def write_connection_file(fname=None, shell_port=0, iopub_port=0, stdin_port=0, hb_port=0,
53 control_port=0, ip=LOCALHOST, key=b'', transport='tcp'):
53 control_port=0, ip=LOCALHOST, key=b'', transport='tcp',
54 signature_scheme='hmac-sha256',
55 ):
54 56 """Generates a JSON config file, including the selection of random ports.
55 57
56 58 Parameters
57 59 ----------
58 60
59 61 fname : unicode
60 62 The path to the file to write
61 63
62 64 shell_port : int, optional
63 65 The port to use for ROUTER (shell) channel.
64 66
65 67 iopub_port : int, optional
66 68 The port to use for the SUB channel.
67 69
68 70 stdin_port : int, optional
69 71 The port to use for the ROUTER (raw input) channel.
70 72
71 73 control_port : int, optional
72 74 The port to use for the ROUTER (control) channel.
73 75
74 76 hb_port : int, optional
75 77 The port to use for the heartbeat REP channel.
76 78
77 79 ip : str, optional
78 80 The ip address the kernel will bind to.
79 81
80 82 key : str, optional
81 The Session key used for HMAC authentication.
83 The Session key used for message authentication.
84
85 signature_scheme : str, optional
86 The scheme used for message authentication.
87 This has the form 'digest-hash', where 'digest'
88 is the scheme used for digests, and 'hash' is the name of the hash function
89 used by the digest scheme.
90 Currently, 'hmac' is the only supported digest scheme,
91 and 'sha256' is the default hash function.
82 92
83 93 """
84 94 # default to temporary connector file
85 95 if not fname:
86 96 fname = tempfile.mktemp('.json')
87 97
88 98 # Find open ports as necessary.
89 99
90 100 ports = []
91 101 ports_needed = int(shell_port <= 0) + \
92 102 int(iopub_port <= 0) + \
93 103 int(stdin_port <= 0) + \
94 104 int(control_port <= 0) + \
95 105 int(hb_port <= 0)
96 106 if transport == 'tcp':
97 107 for i in range(ports_needed):
98 108 sock = socket.socket()
99 109 sock.bind(('', 0))
100 110 ports.append(sock)
101 111 for i, sock in enumerate(ports):
102 112 port = sock.getsockname()[1]
103 113 sock.close()
104 114 ports[i] = port
105 115 else:
106 116 N = 1
107 117 for i in range(ports_needed):
108 118 while os.path.exists("%s-%s" % (ip, str(N))):
109 119 N += 1
110 120 ports.append(N)
111 121 N += 1
112 122 if shell_port <= 0:
113 123 shell_port = ports.pop(0)
114 124 if iopub_port <= 0:
115 125 iopub_port = ports.pop(0)
116 126 if stdin_port <= 0:
117 127 stdin_port = ports.pop(0)
118 128 if control_port <= 0:
119 129 control_port = ports.pop(0)
120 130 if hb_port <= 0:
121 131 hb_port = ports.pop(0)
122 132
123 133 cfg = dict( shell_port=shell_port,
124 134 iopub_port=iopub_port,
125 135 stdin_port=stdin_port,
126 136 control_port=control_port,
127 137 hb_port=hb_port,
128 138 )
129 139 cfg['ip'] = ip
130 140 cfg['key'] = bytes_to_str(key)
131 141 cfg['transport'] = transport
142 cfg['signature_scheme'] = signature_scheme
132 143
133 144 with open(fname, 'w') as f:
134 145 f.write(json.dumps(cfg, indent=2))
135 146
136 147 return fname, cfg
137 148
138 149
139 150 def get_connection_file(app=None):
140 151 """Return the path to the connection file of an app
141 152
142 153 Parameters
143 154 ----------
144 155 app : IPKernelApp instance [optional]
145 156 If unspecified, the currently running app will be used
146 157 """
147 158 if app is None:
148 159 from IPython.kernel.zmq.kernelapp import IPKernelApp
149 160 if not IPKernelApp.initialized():
150 161 raise RuntimeError("app not specified, and not in a running Kernel")
151 162
152 163 app = IPKernelApp.instance()
153 164 return filefind(app.connection_file, ['.', app.profile_dir.security_dir])
154 165
155 166
156 167 def find_connection_file(filename, profile=None):
157 168 """find a connection file, and return its absolute path.
158 169
159 170 The current working directory and the profile's security
160 171 directory will be searched for the file if it is not given by
161 172 absolute path.
162 173
163 174 If profile is unspecified, then the current running application's
164 175 profile will be used, or 'default', if not run from IPython.
165 176
166 177 If the argument does not match an existing file, it will be interpreted as a
167 178 fileglob, and the matching file in the profile's security dir with
168 179 the latest access time will be used.
169 180
170 181 Parameters
171 182 ----------
172 183 filename : str
173 184 The connection file or fileglob to search for.
174 185 profile : str [optional]
175 186 The name of the profile to use when searching for the connection file,
176 187 if different from the current IPython session or 'default'.
177 188
178 189 Returns
179 190 -------
180 191 str : The absolute path of the connection file.
181 192 """
182 193 from IPython.core.application import BaseIPythonApplication as IPApp
183 194 try:
184 195 # quick check for absolute path, before going through logic
185 196 return filefind(filename)
186 197 except IOError:
187 198 pass
188 199
189 200 if profile is None:
190 201 # profile unspecified, check if running from an IPython app
191 202 if IPApp.initialized():
192 203 app = IPApp.instance()
193 204 profile_dir = app.profile_dir
194 205 else:
195 206 # not running in IPython, use default profile
196 207 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), 'default')
197 208 else:
198 209 # find profiledir by profile name:
199 210 profile_dir = ProfileDir.find_profile_dir_by_name(get_ipython_dir(), profile)
200 211 security_dir = profile_dir.security_dir
201 212
202 213 try:
203 214 # first, try explicit name
204 215 return filefind(filename, ['.', security_dir])
205 216 except IOError:
206 217 pass
207 218
208 219 # not found by full name
209 220
210 221 if '*' in filename:
211 222 # given as a glob already
212 223 pat = filename
213 224 else:
214 225 # accept any substring match
215 226 pat = '*%s*' % filename
216 227 matches = glob.glob( os.path.join(security_dir, pat) )
217 228 if not matches:
218 229 raise IOError("Could not find %r in %r" % (filename, security_dir))
219 230 elif len(matches) == 1:
220 231 return matches[0]
221 232 else:
222 233 # get most recent match, by access time:
223 234 return sorted(matches, key=lambda f: os.stat(f).st_atime)[-1]
224 235
225 236
226 237 def get_connection_info(connection_file=None, unpack=False, profile=None):
227 238 """Return the connection information for the current Kernel.
228 239
229 240 Parameters
230 241 ----------
231 242 connection_file : str [optional]
232 243 The connection file to be used. Can be given by absolute path, or
233 244 IPython will search in the security directory of a given profile.
234 245 If run from IPython,
235 246
236 247 If unspecified, the connection file for the currently running
237 248 IPython Kernel will be used, which is only allowed from inside a kernel.
238 249 unpack : bool [default: False]
239 250 if True, return the unpacked dict, otherwise just the string contents
240 251 of the file.
241 252 profile : str [optional]
242 253 The name of the profile to use when searching for the connection file,
243 254 if different from the current IPython session or 'default'.
244 255
245 256
246 257 Returns
247 258 -------
248 259 The connection dictionary of the current kernel, as string or dict,
249 260 depending on `unpack`.
250 261 """
251 262 if connection_file is None:
252 263 # get connection file from current kernel
253 264 cf = get_connection_file()
254 265 else:
255 266 # connection file specified, allow shortnames:
256 267 cf = find_connection_file(connection_file, profile=profile)
257 268
258 269 with open(cf) as f:
259 270 info = f.read()
260 271
261 272 if unpack:
262 273 info = json.loads(info)
263 274 # ensure key is bytes:
264 275 info['key'] = str_to_bytes(info.get('key', ''))
265 276 return info
266 277
267 278
268 279 def connect_qtconsole(connection_file=None, argv=None, profile=None):
269 280 """Connect a qtconsole to the current kernel.
270 281
271 282 This is useful for connecting a second qtconsole to a kernel, or to a
272 283 local notebook.
273 284
274 285 Parameters
275 286 ----------
276 287 connection_file : str [optional]
277 288 The connection file to be used. Can be given by absolute path, or
278 289 IPython will search in the security directory of a given profile.
279 290 If run from IPython,
280 291
281 292 If unspecified, the connection file for the currently running
282 293 IPython Kernel will be used, which is only allowed from inside a kernel.
283 294 argv : list [optional]
284 295 Any extra args to be passed to the console.
285 296 profile : str [optional]
286 297 The name of the profile to use when searching for the connection file,
287 298 if different from the current IPython session or 'default'.
288 299
289 300
290 301 Returns
291 302 -------
292 303 subprocess.Popen instance running the qtconsole frontend
293 304 """
294 305 argv = [] if argv is None else argv
295 306
296 307 if connection_file is None:
297 308 # get connection file from current kernel
298 309 cf = get_connection_file()
299 310 else:
300 311 cf = find_connection_file(connection_file, profile=profile)
301 312
302 313 cmd = ';'.join([
303 314 "from IPython.qt.console import qtconsoleapp",
304 315 "qtconsoleapp.main()"
305 316 ])
306 317
307 318 return Popen([sys.executable, '-c', cmd, '--existing', cf] + argv,
308 319 stdout=PIPE, stderr=PIPE, close_fds=True,
309 320 )
310 321
311 322
312 323 def tunnel_to_kernel(connection_info, sshserver, sshkey=None):
313 324 """tunnel connections to a kernel via ssh
314 325
315 326 This will open four SSH tunnels from localhost on this machine to the
316 327 ports associated with the kernel. They can be either direct
317 328 localhost-localhost tunnels, or if an intermediate server is necessary,
318 329 the kernel must be listening on a public IP.
319 330
320 331 Parameters
321 332 ----------
322 333 connection_info : dict or str (path)
323 334 Either a connection dict, or the path to a JSON connection file
324 335 sshserver : str
325 336 The ssh sever to use to tunnel to the kernel. Can be a full
326 337 `user@server:port` string. ssh config aliases are respected.
327 338 sshkey : str [optional]
328 339 Path to file containing ssh key to use for authentication.
329 340 Only necessary if your ssh config does not already associate
330 341 a keyfile with the host.
331 342
332 343 Returns
333 344 -------
334 345
335 346 (shell, iopub, stdin, hb) : ints
336 347 The four ports on localhost that have been forwarded to the kernel.
337 348 """
338 349 if isinstance(connection_info, basestring):
339 350 # it's a path, unpack it
340 351 with open(connection_info) as f:
341 352 connection_info = json.loads(f.read())
342 353
343 354 cf = connection_info
344 355
345 356 lports = tunnel.select_random_ports(4)
346 357 rports = cf['shell_port'], cf['iopub_port'], cf['stdin_port'], cf['hb_port']
347 358
348 359 remote_ip = cf['ip']
349 360
350 361 if tunnel.try_passwordless_ssh(sshserver, sshkey):
351 362 password=False
352 363 else:
353 364 password = getpass("SSH Password for %s: "%sshserver)
354 365
355 366 for lp,rp in zip(lports, rports):
356 367 tunnel.ssh_tunnel(lp, rp, sshserver, remote_ip, sshkey, password)
357 368
358 369 return tuple(lports)
359 370
360 371
361 372 #-----------------------------------------------------------------------------
362 373 # Mixin for classes that work with connection files
363 374 #-----------------------------------------------------------------------------
364 375
365 376 channel_socket_types = {
366 377 'hb' : zmq.REQ,
367 378 'shell' : zmq.DEALER,
368 379 'iopub' : zmq.SUB,
369 380 'stdin' : zmq.DEALER,
370 381 'control': zmq.DEALER,
371 382 }
372 383
373 384 port_names = [ "%s_port" % channel for channel in ('shell', 'stdin', 'iopub', 'hb', 'control')]
374 385
375 386 class ConnectionFileMixin(HasTraits):
376 387 """Mixin for configurable classes that work with connection files"""
377 388
378 389 # The addresses for the communication channels
379 390 connection_file = Unicode('')
380 391 _connection_file_written = Bool(False)
381 392
382 393 transport = CaselessStrEnum(['tcp', 'ipc'], default_value='tcp', config=True)
394 signature_scheme = Unicode('')
383 395
384 396 ip = Unicode(LOCALHOST, config=True,
385 397 help="""Set the kernel\'s IP address [default localhost].
386 398 If the IP address is something other than localhost, then
387 399 Consoles on other machines will be able to connect
388 400 to the Kernel, so be careful!"""
389 401 )
390 402
391 403 def _ip_default(self):
392 404 if self.transport == 'ipc':
393 405 if self.connection_file:
394 406 return os.path.splitext(self.connection_file)[0] + '-ipc'
395 407 else:
396 408 return 'kernel-ipc'
397 409 else:
398 410 return LOCALHOST
399 411
400 412 def _ip_changed(self, name, old, new):
401 413 if new == '*':
402 414 self.ip = '0.0.0.0'
403 415
404 416 # protected traits
405 417
406 418 shell_port = Integer(0)
407 419 iopub_port = Integer(0)
408 420 stdin_port = Integer(0)
409 421 control_port = Integer(0)
410 422 hb_port = Integer(0)
411 423
412 424 @property
413 425 def ports(self):
414 426 return [ getattr(self, name) for name in port_names ]
415 427
416 428 #--------------------------------------------------------------------------
417 429 # Connection and ipc file management
418 430 #--------------------------------------------------------------------------
419 431
420 432 def get_connection_info(self):
421 433 """return the connection info as a dict"""
422 434 return dict(
423 435 transport=self.transport,
424 436 ip=self.ip,
425 437 shell_port=self.shell_port,
426 438 iopub_port=self.iopub_port,
427 439 stdin_port=self.stdin_port,
428 440 hb_port=self.hb_port,
429 441 control_port=self.control_port,
442 signature_scheme=self.signature_scheme,
430 443 )
431 444
432 445 def cleanup_connection_file(self):
433 446 """Cleanup connection file *if we wrote it*
434 447
435 448 Will not raise if the connection file was already removed somehow.
436 449 """
437 450 if self._connection_file_written:
438 451 # cleanup connection files on full shutdown of kernel we started
439 452 self._connection_file_written = False
440 453 try:
441 454 os.remove(self.connection_file)
442 455 except (IOError, OSError, AttributeError):
443 456 pass
444 457
445 458 def cleanup_ipc_files(self):
446 459 """Cleanup ipc files if we wrote them."""
447 460 if self.transport != 'ipc':
448 461 return
449 462 for port in self.ports:
450 463 ipcfile = "%s-%i" % (self.ip, port)
451 464 try:
452 465 os.remove(ipcfile)
453 466 except (IOError, OSError):
454 467 pass
455 468
456 469 def write_connection_file(self):
457 470 """Write connection info to JSON dict in self.connection_file."""
458 471 if self._connection_file_written:
459 472 return
460 473
461 474 self.connection_file, cfg = write_connection_file(self.connection_file,
462 475 transport=self.transport, ip=self.ip, key=self.session.key,
463 476 stdin_port=self.stdin_port, iopub_port=self.iopub_port,
464 477 shell_port=self.shell_port, hb_port=self.hb_port,
465 478 control_port=self.control_port,
479 signature_scheme=self.signature_scheme,
466 480 )
467 481 # write_connection_file also sets default ports:
468 482 for name in port_names:
469 483 setattr(self, name, cfg[name])
470 484
471 485 self._connection_file_written = True
472 486
473 487 def load_connection_file(self):
474 488 """Load connection info from JSON dict in self.connection_file."""
475 489 with open(self.connection_file) as f:
476 490 cfg = json.loads(f.read())
477 491
478 492 self.transport = cfg.get('transport', 'tcp')
479 493 self.ip = cfg['ip']
480 494 for name in port_names:
481 495 setattr(self, name, cfg[name])
496 if 'key' in cfg:
482 497 self.session.key = str_to_bytes(cfg['key'])
498 if cfg.get('signature_scheme'):
499 self.session.signature_scheme = cfg['signature_scheme']
483 500
484 501 #--------------------------------------------------------------------------
485 502 # Creating connected sockets
486 503 #--------------------------------------------------------------------------
487 504
488 505 def _make_url(self, channel):
489 506 """Make a ZeroMQ URL for a given channel."""
490 507 transport = self.transport
491 508 ip = self.ip
492 509 port = getattr(self, '%s_port' % channel)
493 510
494 511 if transport == 'tcp':
495 512 return "tcp://%s:%i" % (ip, port)
496 513 else:
497 514 return "%s://%s-%s" % (transport, ip, port)
498 515
499 516 def _create_connected_socket(self, channel, identity=None):
500 517 """Create a zmq Socket and connect it to the kernel."""
501 518 url = self._make_url(channel)
502 519 socket_type = channel_socket_types[channel]
503 520 self.log.info("Connecting to: %s" % url)
504 521 sock = self.context.socket(socket_type)
505 522 if identity:
506 523 sock.identity = identity
507 524 sock.connect(url)
508 525 return sock
509 526
510 527 def connect_iopub(self, identity=None):
511 528 """return zmq Socket connected to the IOPub channel"""
512 529 sock = self._create_connected_socket('iopub', identity=identity)
513 530 sock.setsockopt(zmq.SUBSCRIBE, b'')
514 531 return sock
515 532
516 533 def connect_shell(self, identity=None):
517 534 """return zmq Socket connected to the Shell channel"""
518 535 return self._create_connected_socket('shell', identity=identity)
519 536
520 537 def connect_stdin(self, identity=None):
521 538 """return zmq Socket connected to the StdIn channel"""
522 539 return self._create_connected_socket('stdin', identity=identity)
523 540
524 541 def connect_hb(self, identity=None):
525 542 """return zmq Socket connected to the Heartbeat channel"""
526 543 return self._create_connected_socket('hb', identity=identity)
527 544
528 545 def connect_control(self, identity=None):
529 546 """return zmq Socket connected to the Heartbeat channel"""
530 547 return self._create_connected_socket('control', identity=identity)
531 548
532 549
533 550 __all__ = [
534 551 'write_connection_file',
535 552 'get_connection_file',
536 553 'find_connection_file',
537 554 'get_connection_info',
538 555 'connect_qtconsole',
539 556 'tunnel_to_kernel',
540 557 ]
@@ -1,806 +1,830 b''
1 1 """Session object for building, serializing, sending, and receiving messages in
2 2 IPython. The Session object supports serialization, HMAC signatures, and
3 3 metadata on messages.
4 4
5 5 Also defined here are utilities for working with Sessions:
6 6 * A SessionFactory to be used as a base class for configurables that work with
7 7 Sessions.
8 8 * A Message object for convenience that allows attribute-access to the msg dict.
9 9
10 10 Authors:
11 11
12 12 * Min RK
13 13 * Brian Granger
14 14 * Fernando Perez
15 15 """
16 16 #-----------------------------------------------------------------------------
17 17 # Copyright (C) 2010-2011 The IPython Development Team
18 18 #
19 19 # Distributed under the terms of the BSD License. The full license is in
20 20 # the file COPYING, distributed as part of this software.
21 21 #-----------------------------------------------------------------------------
22 22
23 23 #-----------------------------------------------------------------------------
24 24 # Imports
25 25 #-----------------------------------------------------------------------------
26 26
27 import hashlib
27 28 import hmac
28 29 import logging
29 30 import os
30 31 import pprint
31 32 import random
32 33 import uuid
33 34 from datetime import datetime
34 35
35 36 try:
36 37 import cPickle
37 38 pickle = cPickle
38 39 except:
39 40 cPickle = None
40 41 import pickle
41 42
42 43 import zmq
43 44 from zmq.utils import jsonapi
44 45 from zmq.eventloop.ioloop import IOLoop
45 46 from zmq.eventloop.zmqstream import ZMQStream
46 47
47 48 from IPython.config.configurable import Configurable, LoggingConfigurable
48 49 from IPython.utils import io
49 50 from IPython.utils.importstring import import_item
50 51 from IPython.utils.jsonutil import extract_dates, squash_dates, date_default
51 52 from IPython.utils.py3compat import str_to_bytes, str_to_unicode
52 53 from IPython.utils.traitlets import (CBytes, Unicode, Bool, Any, Instance, Set,
53 DottedObjectName, CUnicode, Dict, Integer)
54 DottedObjectName, CUnicode, Dict, Integer,
55 TraitError,
56 )
54 57 from IPython.kernel.zmq.serialize import MAX_ITEMS, MAX_BYTES
55 58
56 59 #-----------------------------------------------------------------------------
57 60 # utility functions
58 61 #-----------------------------------------------------------------------------
59 62
60 63 def squash_unicode(obj):
61 64 """coerce unicode back to bytestrings."""
62 65 if isinstance(obj,dict):
63 66 for key in obj.keys():
64 67 obj[key] = squash_unicode(obj[key])
65 68 if isinstance(key, unicode):
66 69 obj[squash_unicode(key)] = obj.pop(key)
67 70 elif isinstance(obj, list):
68 71 for i,v in enumerate(obj):
69 72 obj[i] = squash_unicode(v)
70 73 elif isinstance(obj, unicode):
71 74 obj = obj.encode('utf8')
72 75 return obj
73 76
74 77 #-----------------------------------------------------------------------------
75 78 # globals and defaults
76 79 #-----------------------------------------------------------------------------
77 80
78 81 # ISO8601-ify datetime objects
79 82 json_packer = lambda obj: jsonapi.dumps(obj, default=date_default)
80 83 json_unpacker = lambda s: extract_dates(jsonapi.loads(s))
81 84
82 85 pickle_packer = lambda o: pickle.dumps(o,-1)
83 86 pickle_unpacker = pickle.loads
84 87
85 88 default_packer = json_packer
86 89 default_unpacker = json_unpacker
87 90
88 91 DELIM = b"<IDS|MSG>"
89 92 # singleton dummy tracker, which will always report as done
90 93 DONE = zmq.MessageTracker()
91 94
92 95 #-----------------------------------------------------------------------------
93 96 # Mixin tools for apps that use Sessions
94 97 #-----------------------------------------------------------------------------
95 98
96 99 session_aliases = dict(
97 100 ident = 'Session.session',
98 101 user = 'Session.username',
99 102 keyfile = 'Session.keyfile',
100 103 )
101 104
102 105 session_flags = {
103 106 'secure' : ({'Session' : { 'key' : str_to_bytes(str(uuid.uuid4())),
104 107 'keyfile' : '' }},
105 108 """Use HMAC digests for authentication of messages.
106 109 Setting this flag will generate a new UUID to use as the HMAC key.
107 110 """),
108 111 'no-secure' : ({'Session' : { 'key' : b'', 'keyfile' : '' }},
109 112 """Don't authenticate messages."""),
110 113 }
111 114
112 115 def default_secure(cfg):
113 116 """Set the default behavior for a config environment to be secure.
114 117
115 118 If Session.key/keyfile have not been set, set Session.key to
116 119 a new random UUID.
117 120 """
118 121
119 122 if 'Session' in cfg:
120 123 if 'key' in cfg.Session or 'keyfile' in cfg.Session:
121 124 return
122 125 # key/keyfile not specified, generate new UUID:
123 126 cfg.Session.key = str_to_bytes(str(uuid.uuid4()))
124 127
125 128
126 129 #-----------------------------------------------------------------------------
127 130 # Classes
128 131 #-----------------------------------------------------------------------------
129 132
130 133 class SessionFactory(LoggingConfigurable):
131 134 """The Base class for configurables that have a Session, Context, logger,
132 135 and IOLoop.
133 136 """
134 137
135 138 logname = Unicode('')
136 139 def _logname_changed(self, name, old, new):
137 140 self.log = logging.getLogger(new)
138 141
139 142 # not configurable:
140 143 context = Instance('zmq.Context')
141 144 def _context_default(self):
142 145 return zmq.Context.instance()
143 146
144 147 session = Instance('IPython.kernel.zmq.session.Session')
145 148
146 149 loop = Instance('zmq.eventloop.ioloop.IOLoop', allow_none=False)
147 150 def _loop_default(self):
148 151 return IOLoop.instance()
149 152
150 153 def __init__(self, **kwargs):
151 154 super(SessionFactory, self).__init__(**kwargs)
152 155
153 156 if self.session is None:
154 157 # construct the session
155 158 self.session = Session(**kwargs)
156 159
157 160
158 161 class Message(object):
159 162 """A simple message object that maps dict keys to attributes.
160 163
161 164 A Message can be created from a dict and a dict from a Message instance
162 165 simply by calling dict(msg_obj)."""
163 166
164 167 def __init__(self, msg_dict):
165 168 dct = self.__dict__
166 169 for k, v in dict(msg_dict).iteritems():
167 170 if isinstance(v, dict):
168 171 v = Message(v)
169 172 dct[k] = v
170 173
171 174 # Having this iterator lets dict(msg_obj) work out of the box.
172 175 def __iter__(self):
173 176 return iter(self.__dict__.iteritems())
174 177
175 178 def __repr__(self):
176 179 return repr(self.__dict__)
177 180
178 181 def __str__(self):
179 182 return pprint.pformat(self.__dict__)
180 183
181 184 def __contains__(self, k):
182 185 return k in self.__dict__
183 186
184 187 def __getitem__(self, k):
185 188 return self.__dict__[k]
186 189
187 190
188 191 def msg_header(msg_id, msg_type, username, session):
189 192 date = datetime.now()
190 193 return locals()
191 194
192 195 def extract_header(msg_or_header):
193 196 """Given a message or header, return the header."""
194 197 if not msg_or_header:
195 198 return {}
196 199 try:
197 200 # See if msg_or_header is the entire message.
198 201 h = msg_or_header['header']
199 202 except KeyError:
200 203 try:
201 204 # See if msg_or_header is just the header
202 205 h = msg_or_header['msg_id']
203 206 except KeyError:
204 207 raise
205 208 else:
206 209 h = msg_or_header
207 210 if not isinstance(h, dict):
208 211 h = dict(h)
209 212 return h
210 213
211 214 class Session(Configurable):
212 215 """Object for handling serialization and sending of messages.
213 216
214 217 The Session object handles building messages and sending them
215 218 with ZMQ sockets or ZMQStream objects. Objects can communicate with each
216 219 other over the network via Session objects, and only need to work with the
217 220 dict-based IPython message spec. The Session will handle
218 221 serialization/deserialization, security, and metadata.
219 222
220 223 Sessions support configurable serialiization via packer/unpacker traits,
221 224 and signing with HMAC digests via the key/keyfile traits.
222 225
223 226 Parameters
224 227 ----------
225 228
226 229 debug : bool
227 230 whether to trigger extra debugging statements
228 231 packer/unpacker : str : 'json', 'pickle' or import_string
229 232 importstrings for methods to serialize message parts. If just
230 233 'json' or 'pickle', predefined JSON and pickle packers will be used.
231 234 Otherwise, the entire importstring must be used.
232 235
233 236 The functions must accept at least valid JSON input, and output *bytes*.
234 237
235 238 For example, to use msgpack:
236 239 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
237 240 pack/unpack : callables
238 241 You can also set the pack/unpack callables for serialization directly.
239 242 session : bytes
240 243 the ID of this Session object. The default is to generate a new UUID.
241 244 username : unicode
242 245 username added to message headers. The default is to ask the OS.
243 246 key : bytes
244 247 The key used to initialize an HMAC signature. If unset, messages
245 248 will not be signed or checked.
246 249 keyfile : filepath
247 250 The file containing a key. If this is set, `key` will be initialized
248 251 to the contents of the file.
249 252
250 253 """
251 254
252 255 debug=Bool(False, config=True, help="""Debug output in the Session""")
253 256
254 257 packer = DottedObjectName('json',config=True,
255 258 help="""The name of the packer for serializing messages.
256 259 Should be one of 'json', 'pickle', or an import name
257 260 for a custom callable serializer.""")
258 261 def _packer_changed(self, name, old, new):
259 262 if new.lower() == 'json':
260 263 self.pack = json_packer
261 264 self.unpack = json_unpacker
262 265 self.unpacker = new
263 266 elif new.lower() == 'pickle':
264 267 self.pack = pickle_packer
265 268 self.unpack = pickle_unpacker
266 269 self.unpacker = new
267 270 else:
268 271 self.pack = import_item(str(new))
269 272
270 273 unpacker = DottedObjectName('json', config=True,
271 274 help="""The name of the unpacker for unserializing messages.
272 275 Only used with custom functions for `packer`.""")
273 276 def _unpacker_changed(self, name, old, new):
274 277 if new.lower() == 'json':
275 278 self.pack = json_packer
276 279 self.unpack = json_unpacker
277 280 self.packer = new
278 281 elif new.lower() == 'pickle':
279 282 self.pack = pickle_packer
280 283 self.unpack = pickle_unpacker
281 284 self.packer = new
282 285 else:
283 286 self.unpack = import_item(str(new))
284 287
285 288 session = CUnicode(u'', config=True,
286 289 help="""The UUID identifying this session.""")
287 290 def _session_default(self):
288 291 u = unicode(uuid.uuid4())
289 292 self.bsession = u.encode('ascii')
290 293 return u
291 294
292 295 def _session_changed(self, name, old, new):
293 296 self.bsession = self.session.encode('ascii')
294 297
295 298 # bsession is the session as bytes
296 299 bsession = CBytes(b'')
297 300
298 301 username = Unicode(str_to_unicode(os.environ.get('USER', 'username')),
299 302 help="""Username for the Session. Default is your system username.""",
300 303 config=True)
301 304
302 305 metadata = Dict({}, config=True,
303 306 help="""Metadata dictionary, which serves as the default top-level metadata dict for each message.""")
304 307
305 308 # message signature related traits:
306 309
307 310 key = CBytes(b'', config=True,
308 311 help="""execution key, for extra authentication.""")
309 312 def _key_changed(self, name, old, new):
310 313 if new:
311 self.auth = hmac.HMAC(new)
314 self.auth = hmac.HMAC(new, digestmod=self.digest_mod)
312 315 else:
313 316 self.auth = None
314 317
318 signature_scheme = Unicode('hmac-sha256', config=True,
319 help="""The digest scheme used to construct the message signatures.
320 Must have the form 'hmac-HASH'.""")
321 def _signature_scheme_changed(self, name, old, new):
322 if not new.startswith('hmac-'):
323 raise TraitError("signature_scheme must start with 'hmac-', got %r" % new)
324 hash_name = new.split('-', 1)[1]
325 try:
326 self.digest_mod = getattr(hashlib, hash_name)
327 except AttributeError:
328 raise TraitError("hashlib has no such attribute: %s" % hash_name)
329
330 digest_mod = Any()
331 def _digest_mod_default(self):
332 return hashlib.sha256
333
315 334 auth = Instance(hmac.HMAC)
316 335
317 336 digest_history = Set()
318 337 digest_history_size = Integer(2**16, config=True,
319 338 help="""The maximum number of digests to remember.
320 339
321 340 The digest history will be culled when it exceeds this value.
322 341 """
323 342 )
324 343
325 344 keyfile = Unicode('', config=True,
326 345 help="""path to file containing execution key.""")
327 346 def _keyfile_changed(self, name, old, new):
328 347 with open(new, 'rb') as f:
329 348 self.key = f.read().strip()
330 349
331 350 # for protecting against sends from forks
332 351 pid = Integer()
333 352
334 353 # serialization traits:
335 354
336 355 pack = Any(default_packer) # the actual packer function
337 356 def _pack_changed(self, name, old, new):
338 357 if not callable(new):
339 358 raise TypeError("packer must be callable, not %s"%type(new))
340 359
341 360 unpack = Any(default_unpacker) # the actual packer function
342 361 def _unpack_changed(self, name, old, new):
343 362 # unpacker is not checked - it is assumed to be
344 363 if not callable(new):
345 364 raise TypeError("unpacker must be callable, not %s"%type(new))
346 365
347 366 # thresholds:
348 367 copy_threshold = Integer(2**16, config=True,
349 368 help="Threshold (in bytes) beyond which a buffer should be sent without copying.")
350 369 buffer_threshold = Integer(MAX_BYTES, config=True,
351 370 help="Threshold (in bytes) beyond which an object's buffer should be extracted to avoid pickling.")
352 371 item_threshold = Integer(MAX_ITEMS, config=True,
353 372 help="""The maximum number of items for a container to be introspected for custom serialization.
354 373 Containers larger than this are pickled outright.
355 374 """
356 375 )
357 376
358 377
359 378 def __init__(self, **kwargs):
360 379 """create a Session object
361 380
362 381 Parameters
363 382 ----------
364 383
365 384 debug : bool
366 385 whether to trigger extra debugging statements
367 386 packer/unpacker : str : 'json', 'pickle' or import_string
368 387 importstrings for methods to serialize message parts. If just
369 388 'json' or 'pickle', predefined JSON and pickle packers will be used.
370 389 Otherwise, the entire importstring must be used.
371 390
372 391 The functions must accept at least valid JSON input, and output
373 392 *bytes*.
374 393
375 394 For example, to use msgpack:
376 395 packer = 'msgpack.packb', unpacker='msgpack.unpackb'
377 396 pack/unpack : callables
378 397 You can also set the pack/unpack callables for serialization
379 398 directly.
380 399 session : unicode (must be ascii)
381 400 the ID of this Session object. The default is to generate a new
382 401 UUID.
383 402 bsession : bytes
384 403 The session as bytes
385 404 username : unicode
386 405 username added to message headers. The default is to ask the OS.
387 406 key : bytes
388 407 The key used to initialize an HMAC signature. If unset, messages
389 408 will not be signed or checked.
409 signature_scheme : str
410 The message digest scheme. Currently must be of the form 'hmac-HASH',
411 where 'HASH' is a hashing function available in Python's hashlib.
412 The default is 'hmac-sha256'.
413 This is ignored if 'key' is empty.
390 414 keyfile : filepath
391 415 The file containing a key. If this is set, `key` will be
392 416 initialized to the contents of the file.
393 417 """
394 418 super(Session, self).__init__(**kwargs)
395 419 self._check_packers()
396 420 self.none = self.pack({})
397 421 # ensure self._session_default() if necessary, so bsession is defined:
398 422 self.session
399 423 self.pid = os.getpid()
400 424
401 425 @property
402 426 def msg_id(self):
403 427 """always return new uuid"""
404 428 return str(uuid.uuid4())
405 429
406 430 def _check_packers(self):
407 431 """check packers for binary data and datetime support."""
408 432 pack = self.pack
409 433 unpack = self.unpack
410 434
411 435 # check simple serialization
412 436 msg = dict(a=[1,'hi'])
413 437 try:
414 438 packed = pack(msg)
415 439 except Exception:
416 440 raise ValueError("packer could not serialize a simple message")
417 441
418 442 # ensure packed message is bytes
419 443 if not isinstance(packed, bytes):
420 444 raise ValueError("message packed to %r, but bytes are required"%type(packed))
421 445
422 446 # check that unpack is pack's inverse
423 447 try:
424 448 unpacked = unpack(packed)
425 449 except Exception:
426 450 raise ValueError("unpacker could not handle the packer's output")
427 451
428 452 # check datetime support
429 453 msg = dict(t=datetime.now())
430 454 try:
431 455 unpacked = unpack(pack(msg))
432 456 except Exception:
433 457 self.pack = lambda o: pack(squash_dates(o))
434 458 self.unpack = lambda s: extract_dates(unpack(s))
435 459
436 460 def msg_header(self, msg_type):
437 461 return msg_header(self.msg_id, msg_type, self.username, self.session)
438 462
439 463 def msg(self, msg_type, content=None, parent=None, header=None, metadata=None):
440 464 """Return the nested message dict.
441 465
442 466 This format is different from what is sent over the wire. The
443 467 serialize/unserialize methods converts this nested message dict to the wire
444 468 format, which is a list of message parts.
445 469 """
446 470 msg = {}
447 471 header = self.msg_header(msg_type) if header is None else header
448 472 msg['header'] = header
449 473 msg['msg_id'] = header['msg_id']
450 474 msg['msg_type'] = header['msg_type']
451 475 msg['parent_header'] = {} if parent is None else extract_header(parent)
452 476 msg['content'] = {} if content is None else content
453 477 msg['metadata'] = self.metadata.copy()
454 478 if metadata is not None:
455 479 msg['metadata'].update(metadata)
456 480 return msg
457 481
458 482 def sign(self, msg_list):
459 483 """Sign a message with HMAC digest. If no auth, return b''.
460 484
461 485 Parameters
462 486 ----------
463 487 msg_list : list
464 488 The [p_header,p_parent,p_content] part of the message list.
465 489 """
466 490 if self.auth is None:
467 491 return b''
468 492 h = self.auth.copy()
469 493 for m in msg_list:
470 494 h.update(m)
471 495 return str_to_bytes(h.hexdigest())
472 496
473 497 def serialize(self, msg, ident=None):
474 498 """Serialize the message components to bytes.
475 499
476 500 This is roughly the inverse of unserialize. The serialize/unserialize
477 501 methods work with full message lists, whereas pack/unpack work with
478 502 the individual message parts in the message list.
479 503
480 504 Parameters
481 505 ----------
482 506 msg : dict or Message
483 507 The nexted message dict as returned by the self.msg method.
484 508
485 509 Returns
486 510 -------
487 511 msg_list : list
488 512 The list of bytes objects to be sent with the format:
489 513 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_metadata,p_content,
490 514 buffer1,buffer2,...]. In this list, the p_* entities are
491 515 the packed or serialized versions, so if JSON is used, these
492 516 are utf8 encoded JSON strings.
493 517 """
494 518 content = msg.get('content', {})
495 519 if content is None:
496 520 content = self.none
497 521 elif isinstance(content, dict):
498 522 content = self.pack(content)
499 523 elif isinstance(content, bytes):
500 524 # content is already packed, as in a relayed message
501 525 pass
502 526 elif isinstance(content, unicode):
503 527 # should be bytes, but JSON often spits out unicode
504 528 content = content.encode('utf8')
505 529 else:
506 530 raise TypeError("Content incorrect type: %s"%type(content))
507 531
508 532 real_message = [self.pack(msg['header']),
509 533 self.pack(msg['parent_header']),
510 534 self.pack(msg['metadata']),
511 535 content,
512 536 ]
513 537
514 538 to_send = []
515 539
516 540 if isinstance(ident, list):
517 541 # accept list of idents
518 542 to_send.extend(ident)
519 543 elif ident is not None:
520 544 to_send.append(ident)
521 545 to_send.append(DELIM)
522 546
523 547 signature = self.sign(real_message)
524 548 to_send.append(signature)
525 549
526 550 to_send.extend(real_message)
527 551
528 552 return to_send
529 553
530 554 def send(self, stream, msg_or_type, content=None, parent=None, ident=None,
531 555 buffers=None, track=False, header=None, metadata=None):
532 556 """Build and send a message via stream or socket.
533 557
534 558 The message format used by this function internally is as follows:
535 559
536 560 [ident1,ident2,...,DELIM,HMAC,p_header,p_parent,p_content,
537 561 buffer1,buffer2,...]
538 562
539 563 The serialize/unserialize methods convert the nested message dict into this
540 564 format.
541 565
542 566 Parameters
543 567 ----------
544 568
545 569 stream : zmq.Socket or ZMQStream
546 570 The socket-like object used to send the data.
547 571 msg_or_type : str or Message/dict
548 572 Normally, msg_or_type will be a msg_type unless a message is being
549 573 sent more than once. If a header is supplied, this can be set to
550 574 None and the msg_type will be pulled from the header.
551 575
552 576 content : dict or None
553 577 The content of the message (ignored if msg_or_type is a message).
554 578 header : dict or None
555 579 The header dict for the message (ignored if msg_to_type is a message).
556 580 parent : Message or dict or None
557 581 The parent or parent header describing the parent of this message
558 582 (ignored if msg_or_type is a message).
559 583 ident : bytes or list of bytes
560 584 The zmq.IDENTITY routing path.
561 585 metadata : dict or None
562 586 The metadata describing the message
563 587 buffers : list or None
564 588 The already-serialized buffers to be appended to the message.
565 589 track : bool
566 590 Whether to track. Only for use with Sockets, because ZMQStream
567 591 objects cannot track messages.
568 592
569 593
570 594 Returns
571 595 -------
572 596 msg : dict
573 597 The constructed message.
574 598 """
575 599 if not isinstance(stream, zmq.Socket):
576 600 # ZMQStreams and dummy sockets do not support tracking.
577 601 track = False
578 602
579 603 if isinstance(msg_or_type, (Message, dict)):
580 604 # We got a Message or message dict, not a msg_type so don't
581 605 # build a new Message.
582 606 msg = msg_or_type
583 607 else:
584 608 msg = self.msg(msg_or_type, content=content, parent=parent,
585 609 header=header, metadata=metadata)
586 610 if not os.getpid() == self.pid:
587 611 io.rprint("WARNING: attempted to send message from fork")
588 612 io.rprint(msg)
589 613 return
590 614 buffers = [] if buffers is None else buffers
591 615 to_send = self.serialize(msg, ident)
592 616 to_send.extend(buffers)
593 617 longest = max([ len(s) for s in to_send ])
594 618 copy = (longest < self.copy_threshold)
595 619
596 620 if buffers and track and not copy:
597 621 # only really track when we are doing zero-copy buffers
598 622 tracker = stream.send_multipart(to_send, copy=False, track=True)
599 623 else:
600 624 # use dummy tracker, which will be done immediately
601 625 tracker = DONE
602 626 stream.send_multipart(to_send, copy=copy)
603 627
604 628 if self.debug:
605 629 pprint.pprint(msg)
606 630 pprint.pprint(to_send)
607 631 pprint.pprint(buffers)
608 632
609 633 msg['tracker'] = tracker
610 634
611 635 return msg
612 636
613 637 def send_raw(self, stream, msg_list, flags=0, copy=True, ident=None):
614 638 """Send a raw message via ident path.
615 639
616 640 This method is used to send a already serialized message.
617 641
618 642 Parameters
619 643 ----------
620 644 stream : ZMQStream or Socket
621 645 The ZMQ stream or socket to use for sending the message.
622 646 msg_list : list
623 647 The serialized list of messages to send. This only includes the
624 648 [p_header,p_parent,p_metadata,p_content,buffer1,buffer2,...] portion of
625 649 the message.
626 650 ident : ident or list
627 651 A single ident or a list of idents to use in sending.
628 652 """
629 653 to_send = []
630 654 if isinstance(ident, bytes):
631 655 ident = [ident]
632 656 if ident is not None:
633 657 to_send.extend(ident)
634 658
635 659 to_send.append(DELIM)
636 660 to_send.append(self.sign(msg_list))
637 661 to_send.extend(msg_list)
638 662 stream.send_multipart(msg_list, flags, copy=copy)
639 663
640 664 def recv(self, socket, mode=zmq.NOBLOCK, content=True, copy=True):
641 665 """Receive and unpack a message.
642 666
643 667 Parameters
644 668 ----------
645 669 socket : ZMQStream or Socket
646 670 The socket or stream to use in receiving.
647 671
648 672 Returns
649 673 -------
650 674 [idents], msg
651 675 [idents] is a list of idents and msg is a nested message dict of
652 676 same format as self.msg returns.
653 677 """
654 678 if isinstance(socket, ZMQStream):
655 679 socket = socket.socket
656 680 try:
657 681 msg_list = socket.recv_multipart(mode, copy=copy)
658 682 except zmq.ZMQError as e:
659 683 if e.errno == zmq.EAGAIN:
660 684 # We can convert EAGAIN to None as we know in this case
661 685 # recv_multipart won't return None.
662 686 return None,None
663 687 else:
664 688 raise
665 689 # split multipart message into identity list and message dict
666 690 # invalid large messages can cause very expensive string comparisons
667 691 idents, msg_list = self.feed_identities(msg_list, copy)
668 692 try:
669 693 return idents, self.unserialize(msg_list, content=content, copy=copy)
670 694 except Exception as e:
671 695 # TODO: handle it
672 696 raise e
673 697
674 698 def feed_identities(self, msg_list, copy=True):
675 699 """Split the identities from the rest of the message.
676 700
677 701 Feed until DELIM is reached, then return the prefix as idents and
678 702 remainder as msg_list. This is easily broken by setting an IDENT to DELIM,
679 703 but that would be silly.
680 704
681 705 Parameters
682 706 ----------
683 707 msg_list : a list of Message or bytes objects
684 708 The message to be split.
685 709 copy : bool
686 710 flag determining whether the arguments are bytes or Messages
687 711
688 712 Returns
689 713 -------
690 714 (idents, msg_list) : two lists
691 715 idents will always be a list of bytes, each of which is a ZMQ
692 716 identity. msg_list will be a list of bytes or zmq.Messages of the
693 717 form [HMAC,p_header,p_parent,p_content,buffer1,buffer2,...] and
694 718 should be unpackable/unserializable via self.unserialize at this
695 719 point.
696 720 """
697 721 if copy:
698 722 idx = msg_list.index(DELIM)
699 723 return msg_list[:idx], msg_list[idx+1:]
700 724 else:
701 725 failed = True
702 726 for idx,m in enumerate(msg_list):
703 727 if m.bytes == DELIM:
704 728 failed = False
705 729 break
706 730 if failed:
707 731 raise ValueError("DELIM not in msg_list")
708 732 idents, msg_list = msg_list[:idx], msg_list[idx+1:]
709 733 return [m.bytes for m in idents], msg_list
710 734
711 735 def _add_digest(self, signature):
712 736 """add a digest to history to protect against replay attacks"""
713 737 if self.digest_history_size == 0:
714 738 # no history, never add digests
715 739 return
716 740
717 741 self.digest_history.add(signature)
718 742 if len(self.digest_history) > self.digest_history_size:
719 743 # threshold reached, cull 10%
720 744 self._cull_digest_history()
721 745
722 746 def _cull_digest_history(self):
723 747 """cull the digest history
724 748
725 749 Removes a randomly selected 10% of the digest history
726 750 """
727 751 current = len(self.digest_history)
728 752 n_to_cull = max(int(current // 10), current - self.digest_history_size)
729 753 if n_to_cull >= current:
730 754 self.digest_history = set()
731 755 return
732 756 to_cull = random.sample(self.digest_history, n_to_cull)
733 757 self.digest_history.difference_update(to_cull)
734 758
735 759 def unserialize(self, msg_list, content=True, copy=True):
736 760 """Unserialize a msg_list to a nested message dict.
737 761
738 762 This is roughly the inverse of serialize. The serialize/unserialize
739 763 methods work with full message lists, whereas pack/unpack work with
740 764 the individual message parts in the message list.
741 765
742 766 Parameters:
743 767 -----------
744 768 msg_list : list of bytes or Message objects
745 769 The list of message parts of the form [HMAC,p_header,p_parent,
746 770 p_metadata,p_content,buffer1,buffer2,...].
747 771 content : bool (True)
748 772 Whether to unpack the content dict (True), or leave it packed
749 773 (False).
750 774 copy : bool (True)
751 775 Whether to return the bytes (True), or the non-copying Message
752 776 object in each place (False).
753 777
754 778 Returns
755 779 -------
756 780 msg : dict
757 781 The nested message dict with top-level keys [header, parent_header,
758 782 content, buffers].
759 783 """
760 784 minlen = 5
761 785 message = {}
762 786 if not copy:
763 787 for i in range(minlen):
764 788 msg_list[i] = msg_list[i].bytes
765 789 if self.auth is not None:
766 790 signature = msg_list[0]
767 791 if not signature:
768 792 raise ValueError("Unsigned Message")
769 793 if signature in self.digest_history:
770 794 raise ValueError("Duplicate Signature: %r" % signature)
771 795 self._add_digest(signature)
772 796 check = self.sign(msg_list[1:5])
773 797 if not signature == check:
774 798 raise ValueError("Invalid Signature: %r" % signature)
775 799 if not len(msg_list) >= minlen:
776 800 raise TypeError("malformed message, must have at least %i elements"%minlen)
777 801 header = self.unpack(msg_list[1])
778 802 message['header'] = header
779 803 message['msg_id'] = header['msg_id']
780 804 message['msg_type'] = header['msg_type']
781 805 message['parent_header'] = self.unpack(msg_list[2])
782 806 message['metadata'] = self.unpack(msg_list[3])
783 807 if content:
784 808 message['content'] = self.unpack(msg_list[4])
785 809 else:
786 810 message['content'] = msg_list[4]
787 811
788 812 message['buffers'] = msg_list[5:]
789 813 return message
790 814
791 815 def test_msg2obj():
792 816 am = dict(x=1)
793 817 ao = Message(am)
794 818 assert ao.x == am['x']
795 819
796 820 am['y'] = dict(z=1)
797 821 ao = Message(am)
798 822 assert ao.y.z == am['y']['z']
799 823
800 824 k1, k2 = 'y', 'z'
801 825 assert ao[k1][k2] == am[k1][k2]
802 826
803 827 am2 = dict(ao)
804 828 assert am['x'] == am2['x']
805 829 assert am['y']['z'] == am2['y']['z']
806 830
General Comments 0
You need to be logged in to leave comments. Login now