Show More
@@ -12,6 +12,7 b'' | |||||
12 |
|
12 | |||
13 | from __future__ import print_function |
|
13 | from __future__ import print_function | |
14 |
|
14 | |||
|
15 | import os | |||
15 | import time |
|
16 | import time | |
16 | from pprint import pprint |
|
17 | from pprint import pprint | |
17 |
|
18 | |||
@@ -139,19 +140,30 b' class Client(object):' | |||||
139 | A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port' |
|
140 | A string of the form passed to ssh, i.e. 'server.tld' or 'user@server.tld:port' | |
140 | If keyfile or password is specified, and this is not, it will default to |
|
141 | If keyfile or password is specified, and this is not, it will default to | |
141 | the ip given in addr. |
|
142 | the ip given in addr. | |
142 |
key |
|
143 | sshkey : str; path to public ssh key file | |
143 | This specifies a key to be used in ssh login, default None. |
|
144 | This specifies a key to be used in ssh login, default None. | |
144 | Regular default ssh keys will be used without specifying this argument. |
|
145 | Regular default ssh keys will be used without specifying this argument. | |
145 | password : str; |
|
146 | password : str; | |
146 | Your ssh password to sshserver. Note that if this is left None, |
|
147 | Your ssh password to sshserver. Note that if this is left None, | |
147 | you will be prompted for it if passwordless key based login is unavailable. |
|
148 | you will be prompted for it if passwordless key based login is unavailable. | |
148 |
|
149 | |||
|
150 | #------- exec authentication args ------- | |||
|
151 | # If even localhost is untrusted, you can have some protection against | |||
|
152 | # unauthorized execution by using a key. Messages are still sent | |||
|
153 | # as cleartext, so if someone can snoop your loopback traffic this will | |||
|
154 | # not help anything. | |||
|
155 | ||||
|
156 | exec_key : str | |||
|
157 | an authentication key or file containing a key | |||
|
158 | default: None | |||
|
159 | ||||
|
160 | ||||
149 | Attributes |
|
161 | Attributes | |
150 | ---------- |
|
162 | ---------- | |
151 | ids : set of int engine IDs |
|
163 | ids : set of int engine IDs | |
152 | requesting the ids attribute always synchronizes |
|
164 | requesting the ids attribute always synchronizes | |
153 | the registration state. To request ids without synchronization, |
|
165 | the registration state. To request ids without synchronization, | |
154 | use semi-private _ids. |
|
166 | use semi-private _ids attributes. | |
155 |
|
167 | |||
156 | history : list of msg_ids |
|
168 | history : list of msg_ids | |
157 | a list of msg_ids, keeping track of all the execution |
|
169 | a list of msg_ids, keeping track of all the execution | |
@@ -175,7 +187,7 b' class Client(object):' | |||||
175 |
|
187 | |||
176 | barrier : wait on one or more msg_ids |
|
188 | barrier : wait on one or more msg_ids | |
177 |
|
189 | |||
178 |
execution methods: apply/apply_bound/apply_to/appl |
|
190 | execution methods: apply/apply_bound/apply_to/apply_bound | |
179 | legacy: execute, run |
|
191 | legacy: execute, run | |
180 |
|
192 | |||
181 | query methods: queue_status, get_result, purge |
|
193 | query methods: queue_status, get_result, purge | |
@@ -202,26 +214,32 b' class Client(object):' | |||||
202 | debug = False |
|
214 | debug = False | |
203 |
|
215 | |||
204 | def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False, |
|
216 | def __init__(self, addr='tcp://127.0.0.1:10101', context=None, username=None, debug=False, | |
205 |
sshserver=None, key |
|
217 | sshserver=None, sshkey=None, password=None, paramiko=None, | |
|
218 | exec_key=None,): | |||
206 | if context is None: |
|
219 | if context is None: | |
207 | context = zmq.Context() |
|
220 | context = zmq.Context() | |
208 | self.context = context |
|
221 | self.context = context | |
209 | self._addr = addr |
|
222 | self._addr = addr | |
210 |
self._ssh = bool(sshserver or key |
|
223 | self._ssh = bool(sshserver or sshkey or password) | |
211 | if self._ssh and sshserver is None: |
|
224 | if self._ssh and sshserver is None: | |
212 | # default to the same |
|
225 | # default to the same | |
213 | sshserver = addr.split('://')[1].split(':')[0] |
|
226 | sshserver = addr.split('://')[1].split(':')[0] | |
214 | if self._ssh and password is None: |
|
227 | if self._ssh and password is None: | |
215 |
if tunnel.try_passwordless_ssh(sshserver, key |
|
228 | if tunnel.try_passwordless_ssh(sshserver, sshkey, paramiko): | |
216 | password=False |
|
229 | password=False | |
217 | else: |
|
230 | else: | |
218 | password = getpass("SSH Password for %s: "%sshserver) |
|
231 | password = getpass("SSH Password for %s: "%sshserver) | |
219 |
ssh_kwargs = dict(keyfile=key |
|
232 | ssh_kwargs = dict(keyfile=sshkey, password=password, paramiko=paramiko) | |
220 |
|
233 | |||
|
234 | if os.path.isfile(exec_key): | |||
|
235 | arg = 'keyfile' | |||
|
236 | else: | |||
|
237 | arg = 'key' | |||
|
238 | key_arg = {arg:exec_key} | |||
221 | if username is None: |
|
239 | if username is None: | |
222 | self.session = ss.StreamSession() |
|
240 | self.session = ss.StreamSession(**key_arg) | |
223 | else: |
|
241 | else: | |
224 | self.session = ss.StreamSession(username) |
|
242 | self.session = ss.StreamSession(username, **key_arg) | |
225 | self._registration_socket = self.context.socket(zmq.XREQ) |
|
243 | self._registration_socket = self.context.socket(zmq.XREQ) | |
226 | self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session) |
|
244 | self._registration_socket.setsockopt(zmq.IDENTITY, self.session.session) | |
227 | if self._ssh: |
|
245 | if self._ssh: | |
@@ -536,11 +554,12 b' class Client(object):' | |||||
536 |
|
554 | |||
537 | @spinfirst |
|
555 | @spinfirst | |
538 | @defaultblock |
|
556 | @defaultblock | |
539 |
def |
|
557 | def shutdown(self, targets=None, restart=False, block=None): | |
540 | """Terminates one or more engine processes.""" |
|
558 | """Terminates one or more engine processes.""" | |
541 | targets = self._build_targets(targets)[0] |
|
559 | targets = self._build_targets(targets)[0] | |
542 | for t in targets: |
|
560 | for t in targets: | |
543 |
self.session.send(self._control_socket, ' |
|
561 | self.session.send(self._control_socket, 'shutdown_request', | |
|
562 | content={'restart':restart},ident=t) | |||
544 | error = False |
|
563 | error = False | |
545 | if self.block: |
|
564 | if self.block: | |
546 | for i in range(len(targets)): |
|
565 | for i in range(len(targets)): |
@@ -15,6 +15,7 b' and monitors traffic through the various queues.' | |||||
15 | #----------------------------------------------------------------------------- |
|
15 | #----------------------------------------------------------------------------- | |
16 | from __future__ import print_function |
|
16 | from __future__ import print_function | |
17 |
|
17 | |||
|
18 | import os | |||
18 | from datetime import datetime |
|
19 | from datetime import datetime | |
19 | import logging |
|
20 | import logging | |
20 |
|
21 | |||
@@ -28,7 +29,7 b' from IPython.zmq.entry_point import bind_port' | |||||
28 |
|
29 | |||
29 | from streamsession import Message, wrap_exception |
|
30 | from streamsession import Message, wrap_exception | |
30 | from entry_point import (make_base_argument_parser, select_random_ports, split_ports, |
|
31 | from entry_point import (make_base_argument_parser, select_random_ports, split_ports, | |
31 | connect_logger, parse_url, signal_children) |
|
32 | connect_logger, parse_url, signal_children, generate_exec_key) | |
32 |
|
33 | |||
33 | #----------------------------------------------------------------------------- |
|
34 | #----------------------------------------------------------------------------- | |
34 | # Code |
|
35 | # Code | |
@@ -283,13 +284,12 b' class Controller(object):' | |||||
283 | logger.debug("registration::dispatch_register_request(%s)"%msg) |
|
284 | logger.debug("registration::dispatch_register_request(%s)"%msg) | |
284 | idents,msg = self.session.feed_identities(msg) |
|
285 | idents,msg = self.session.feed_identities(msg) | |
285 | if not idents: |
|
286 | if not idents: | |
286 | logger.error("Bad Queue Message: %s"%msg) |
|
287 | logger.error("Bad Queue Message: %s"%msg, exc_info=True) | |
287 | return |
|
288 | return | |
288 | try: |
|
289 | try: | |
289 | msg = self.session.unpack_message(msg,content=True) |
|
290 | msg = self.session.unpack_message(msg,content=True) | |
290 |
except |
|
291 | except: | |
291 | logger.error("registration::got bad registration message: %s"%msg) |
|
292 | logger.error("registration::got bad registration message: %s"%msg, exc_info=True) | |
292 | raise e |
|
|||
293 | return |
|
293 | return | |
294 |
|
294 | |||
295 | msg_type = msg['msg_type'] |
|
295 | msg_type = msg['msg_type'] | |
@@ -326,7 +326,7 b' class Controller(object):' | |||||
326 | msg = self.session.unpack_message(msg, content=True) |
|
326 | msg = self.session.unpack_message(msg, content=True) | |
327 | except: |
|
327 | except: | |
328 | content = wrap_exception() |
|
328 | content = wrap_exception() | |
329 | logger.error("Bad Client Message: %s"%msg) |
|
329 | logger.error("Bad Client Message: %s"%msg, exc_info=True) | |
330 | self.session.send(self.clientele, "controller_error", ident=client_id, |
|
330 | self.session.send(self.clientele, "controller_error", ident=client_id, | |
331 | content=content) |
|
331 | content=content) | |
332 | return |
|
332 | return | |
@@ -340,7 +340,7 b' class Controller(object):' | |||||
340 | assert handler is not None, "Bad Message Type: %s"%msg_type |
|
340 | assert handler is not None, "Bad Message Type: %s"%msg_type | |
341 | except: |
|
341 | except: | |
342 | content = wrap_exception() |
|
342 | content = wrap_exception() | |
343 | logger.error("Bad Message Type: %s"%msg_type) |
|
343 | logger.error("Bad Message Type: %s"%msg_type, exc_info=True) | |
344 | self.session.send(self.clientele, "controller_error", ident=client_id, |
|
344 | self.session.send(self.clientele, "controller_error", ident=client_id, | |
345 | content=content) |
|
345 | content=content) | |
346 | return |
|
346 | return | |
@@ -390,7 +390,7 b' class Controller(object):' | |||||
390 | try: |
|
390 | try: | |
391 | msg = self.session.unpack_message(msg, content=False) |
|
391 | msg = self.session.unpack_message(msg, content=False) | |
392 | except: |
|
392 | except: | |
393 | logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg)) |
|
393 | logger.error("queue::client %r sent invalid message to %r: %s"%(client_id, queue_id, msg), exc_info=True) | |
394 | return |
|
394 | return | |
395 |
|
395 | |||
396 | eid = self.by_ident.get(queue_id, None) |
|
396 | eid = self.by_ident.get(queue_id, None) | |
@@ -417,7 +417,7 b' class Controller(object):' | |||||
417 | msg = self.session.unpack_message(msg, content=False) |
|
417 | msg = self.session.unpack_message(msg, content=False) | |
418 | except: |
|
418 | except: | |
419 | logger.error("queue::engine %r sent invalid message to %r: %s"%( |
|
419 | logger.error("queue::engine %r sent invalid message to %r: %s"%( | |
420 | queue_id,client_id, msg)) |
|
420 | queue_id,client_id, msg), exc_info=True) | |
421 | return |
|
421 | return | |
422 |
|
422 | |||
423 | eid = self.by_ident.get(queue_id, None) |
|
423 | eid = self.by_ident.get(queue_id, None) | |
@@ -448,7 +448,7 b' class Controller(object):' | |||||
448 | msg = self.session.unpack_message(msg, content=False) |
|
448 | msg = self.session.unpack_message(msg, content=False) | |
449 | except: |
|
449 | except: | |
450 | logger.error("task::client %r sent invalid task message: %s"%( |
|
450 | logger.error("task::client %r sent invalid task message: %s"%( | |
451 | client_id, msg)) |
|
451 | client_id, msg), exc_info=True) | |
452 | return |
|
452 | return | |
453 |
|
453 | |||
454 | header = msg['header'] |
|
454 | header = msg['header'] | |
@@ -871,7 +871,11 b' def main():' | |||||
871 | n = ZMQStream(ctx.socket(zmq.PUB), loop) |
|
871 | n = ZMQStream(ctx.socket(zmq.PUB), loop) | |
872 | nport = bind_port(n, args.ip, args.notice) |
|
872 | nport = bind_port(n, args.ip, args.notice) | |
873 |
|
873 | |||
874 | thesession = session.StreamSession(username=args.ident or "controller") |
|
874 | ### Key File ### | |
|
875 | if args.execkey and not os.path.isfile(args.execkey): | |||
|
876 | generate_exec_key(args.execkey) | |||
|
877 | ||||
|
878 | thesession = session.StreamSession(username=args.ident or "controller", keyfile=args.execkey) | |||
875 |
|
879 | |||
876 | ### build and launch the queues ### |
|
880 | ### build and launch the queues ### | |
877 |
|
881 |
@@ -40,7 +40,7 b' class Engine(object):' | |||||
40 | heart=None |
|
40 | heart=None | |
41 | kernel=None |
|
41 | kernel=None | |
42 |
|
42 | |||
43 |
def __init__(self, context, loop, session, registrar, client, ident=None |
|
43 | def __init__(self, context, loop, session, registrar, client=None, ident=None): | |
44 | self.context = context |
|
44 | self.context = context | |
45 | self.loop = loop |
|
45 | self.loop = loop | |
46 | self.session = session |
|
46 | self.session = session | |
@@ -53,6 +53,7 b' class Engine(object):' | |||||
53 |
|
53 | |||
54 | content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) |
|
54 | content = dict(queue=self.ident, heartbeat=self.ident, control=self.ident) | |
55 | self.registrar.on_recv(self.complete_registration) |
|
55 | self.registrar.on_recv(self.complete_registration) | |
|
56 | # print (self.session.key) | |||
56 | self.session.send(self.registrar, "registration_request",content=content) |
|
57 | self.session.send(self.registrar, "registration_request",content=content) | |
57 |
|
58 | |||
58 | def complete_registration(self, msg): |
|
59 | def complete_registration(self, msg): | |
@@ -77,9 +78,8 b' class Engine(object):' | |||||
77 | sub.on_recv(lambda *a: None) |
|
78 | sub.on_recv(lambda *a: None) | |
78 | port = sub.bind_to_random_port("tcp://%s"%LOCALHOST) |
|
79 | port = sub.bind_to_random_port("tcp://%s"%LOCALHOST) | |
79 | iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345) |
|
80 | iopub_addr = "tcp://%s:%i"%(LOCALHOST,12345) | |
80 |
|
||||
81 | make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs, |
|
81 | make_kernel(self.ident, control_addr, shell_addrs, iopub_addr, hb_addrs, | |
82 | client_addr=None, loop=self.loop, context=self.context) |
|
82 | client_addr=None, loop=self.loop, context=self.context, key=self.session.key) | |
83 |
|
83 | |||
84 | else: |
|
84 | else: | |
85 | # logger.error("Registration Failed: %s"%msg) |
|
85 | # logger.error("Registration Failed: %s"%msg) | |
@@ -111,7 +111,8 b' def main():' | |||||
111 | iface="%s://%s"%(args.transport,args.ip)+':%i' |
|
111 | iface="%s://%s"%(args.transport,args.ip)+':%i' | |
112 |
|
112 | |||
113 | loop = ioloop.IOLoop.instance() |
|
113 | loop = ioloop.IOLoop.instance() | |
114 | session = StreamSession() |
|
114 | session = StreamSession(keyfile=args.execkey) | |
|
115 | # print (session.key) | |||
115 | ctx = zmq.Context() |
|
116 | ctx = zmq.Context() | |
116 |
|
117 | |||
117 | # setup logging |
|
118 | # setup logging | |
@@ -124,7 +125,7 b' def main():' | |||||
124 | reg = ctx.socket(zmq.PAIR) |
|
125 | reg = ctx.socket(zmq.PAIR) | |
125 | reg.connect(reg_conn) |
|
126 | reg.connect(reg_conn) | |
126 | reg = zmqstream.ZMQStream(reg, loop) |
|
127 | reg = zmqstream.ZMQStream(reg, loop) | |
127 |
client = |
|
128 | client = None | |
128 |
|
129 | |||
129 | e = Engine(ctx, loop, session, reg, client, args.ident) |
|
130 | e = Engine(ctx, loop, session, reg, client, args.ident) | |
130 | dc = ioloop.DelayedCallback(e.start, 100, loop) |
|
131 | dc = ioloop.DelayedCallback(e.start, 100, loop) |
@@ -7,6 +7,7 b' import logging' | |||||
7 | import atexit |
|
7 | import atexit | |
8 | import sys |
|
8 | import sys | |
9 | import os |
|
9 | import os | |
|
10 | import stat | |||
10 | import socket |
|
11 | import socket | |
11 | from subprocess import Popen, PIPE |
|
12 | from subprocess import Popen, PIPE | |
12 | from signal import signal, SIGINT, SIGABRT, SIGTERM |
|
13 | from signal import signal, SIGINT, SIGABRT, SIGTERM | |
@@ -33,7 +34,7 b' def split_ports(s, n):' | |||||
33 | return ports |
|
34 | return ports | |
34 |
|
35 | |||
35 | def select_random_ports(n): |
|
36 | def select_random_ports(n): | |
36 |
"""Selects and return n random ports that are |
|
37 | """Selects and return n random ports that are available.""" | |
37 | ports = [] |
|
38 | ports = [] | |
38 | for i in xrange(n): |
|
39 | for i in xrange(n): | |
39 | sock = socket.socket() |
|
40 | sock = socket.socket() | |
@@ -46,6 +47,7 b' def select_random_ports(n):' | |||||
46 | return ports |
|
47 | return ports | |
47 |
|
48 | |||
48 | def parse_url(args): |
|
49 | def parse_url(args): | |
|
50 | """Ensure args.url contains full transport://interface:port""" | |||
49 | if args.url: |
|
51 | if args.url: | |
50 | iface = args.url.split('://',1) |
|
52 | iface = args.url.split('://',1) | |
51 | if len(args) == 2: |
|
53 | if len(args) == 2: | |
@@ -57,6 +59,7 b' def parse_url(args):' | |||||
57 | args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport) |
|
59 | args.url = "%s://%s:%i"%(args.transport, args.ip,args.regport) | |
58 |
|
60 | |||
59 | def signal_children(children): |
|
61 | def signal_children(children): | |
|
62 | """Relay interupt/term signals to children, for more solid process cleanup.""" | |||
60 | def terminate_children(sig, frame): |
|
63 | def terminate_children(sig, frame): | |
61 | for child in children: |
|
64 | for child in children: | |
62 | child.terminate() |
|
65 | child.terminate() | |
@@ -64,6 +67,17 b' def signal_children(children):' | |||||
64 | for sig in (SIGINT, SIGABRT, SIGTERM): |
|
67 | for sig in (SIGINT, SIGABRT, SIGTERM): | |
65 | signal(sig, terminate_children) |
|
68 | signal(sig, terminate_children) | |
66 |
|
69 | |||
|
70 | def generate_exec_key(keyfile): | |||
|
71 | import uuid | |||
|
72 | newkey = str(uuid.uuid4()) | |||
|
73 | with open(keyfile, 'w') as f: | |||
|
74 | # f.write('ipython-key ') | |||
|
75 | f.write(newkey) | |||
|
76 | # set user-only RW permissions (0600) | |||
|
77 | # this will have no effect on Windows | |||
|
78 | os.chmod(keyfile, stat.S_IRUSR|stat.S_IWUSR) | |||
|
79 | ||||
|
80 | ||||
67 | def make_base_argument_parser(): |
|
81 | def make_base_argument_parser(): | |
68 | """ Creates an ArgumentParser for the generic arguments supported by all |
|
82 | """ Creates an ArgumentParser for the generic arguments supported by all | |
69 | ipcluster entry points. |
|
83 | ipcluster entry points. | |
@@ -86,6 +100,8 b' def make_base_argument_parser():' | |||||
86 | help='set the message format method [default: json]') |
|
100 | help='set the message format method [default: json]') | |
87 | parser.add_argument('--url', type=str, |
|
101 | parser.add_argument('--url', type=str, | |
88 | help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101') |
|
102 | help='set transport,ip,regport in one arg, e.g. tcp://127.0.0.1:10101') | |
|
103 | parser.add_argument('--execkey', type=str, | |||
|
104 | help="File containing key for authenticating requests.") | |||
89 |
|
105 | |||
90 | return parser |
|
106 | return parser | |
91 |
|
107 |
@@ -65,7 +65,7 b' def main():' | |||||
65 |
|
65 | |||
66 | controller_args = strip_args([('--n','-n')]) |
|
66 | controller_args = strip_args([('--n','-n')]) | |
67 | engine_args = filter_args(['--url', '--regport', '--logport', '--ip', |
|
67 | engine_args = filter_args(['--url', '--regport', '--logport', '--ip', | |
68 | '--transport','--loglevel','--packer'])+['--ident'] |
|
68 | '--transport','--loglevel','--packer', '--execkey'])+['--ident'] | |
69 |
|
69 | |||
70 | controller = launch_process('controller', controller_args) |
|
70 | controller = launch_process('controller', controller_args) | |
71 | for i in range(10): |
|
71 | for i in range(10): |
@@ -127,17 +127,21 b' class Kernel(HasTraits):' | |||||
127 | """kill ourself. This should really be handled in an external process""" |
|
127 | """kill ourself. This should really be handled in an external process""" | |
128 | self.abort_queues() |
|
128 | self.abort_queues() | |
129 | content = dict(parent['content']) |
|
129 | content = dict(parent['content']) | |
130 |
msg = self.session.send(s |
|
130 | msg = self.session.send(stream, 'shutdown_reply', | |
131 | content, parent, ident) |
|
131 | content=content, parent=parent, ident=ident) | |
132 | msg = self.session.send(self.pub_socket, 'shutdown_reply', |
|
132 | # msg = self.session.send(self.pub_socket, 'shutdown_reply', | |
133 | content, parent, ident) |
|
133 | # content, parent, ident) | |
134 | # print >> sys.__stdout__, msg |
|
134 | # print >> sys.__stdout__, msg | |
135 | time.sleep(0.1) |
|
135 | time.sleep(0.1) | |
136 | sys.exit(0) |
|
136 | sys.exit(0) | |
137 |
|
137 | |||
138 | def dispatch_control(self, msg): |
|
138 | def dispatch_control(self, msg): | |
139 | idents,msg = self.session.feed_identities(msg, copy=False) |
|
139 | idents,msg = self.session.feed_identities(msg, copy=False) | |
140 | msg = self.session.unpack_message(msg, content=True, copy=False) |
|
140 | try: | |
|
141 | msg = self.session.unpack_message(msg, content=True, copy=False) | |||
|
142 | except: | |||
|
143 | logger.error("Invalid Message", exc_info=True) | |||
|
144 | return | |||
141 |
|
145 | |||
142 | header = msg['header'] |
|
146 | header = msg['header'] | |
143 | msg_id = header['msg_id'] |
|
147 | msg_id = header['msg_id'] | |
@@ -313,7 +317,12 b' class Kernel(HasTraits):' | |||||
313 | def dispatch_queue(self, stream, msg): |
|
317 | def dispatch_queue(self, stream, msg): | |
314 | self.control_stream.flush() |
|
318 | self.control_stream.flush() | |
315 | idents,msg = self.session.feed_identities(msg, copy=False) |
|
319 | idents,msg = self.session.feed_identities(msg, copy=False) | |
316 | msg = self.session.unpack_message(msg, content=True, copy=False) |
|
320 | try: | |
|
321 | msg = self.session.unpack_message(msg, content=True, copy=False) | |||
|
322 | except: | |||
|
323 | logger.error("Invalid Message", exc_info=True) | |||
|
324 | return | |||
|
325 | ||||
317 |
|
326 | |||
318 | header = msg['header'] |
|
327 | header = msg['header'] | |
319 | msg_id = header['msg_id'] |
|
328 | msg_id = header['msg_id'] | |
@@ -367,14 +376,15 b' class Kernel(HasTraits):' | |||||
367 | # time.sleep(1e-3) |
|
376 | # time.sleep(1e-3) | |
368 |
|
377 | |||
369 | def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs, |
|
378 | def make_kernel(identity, control_addr, shell_addrs, iopub_addr, hb_addrs, | |
370 | client_addr=None, loop=None, context=None): |
|
379 | client_addr=None, loop=None, context=None, key=None): | |
371 | # create loop, context, and session: |
|
380 | # create loop, context, and session: | |
372 | if loop is None: |
|
381 | if loop is None: | |
373 | loop = ioloop.IOLoop.instance() |
|
382 | loop = ioloop.IOLoop.instance() | |
374 | if context is None: |
|
383 | if context is None: | |
375 | context = zmq.Context() |
|
384 | context = zmq.Context() | |
376 | c = context |
|
385 | c = context | |
377 | session = StreamSession() |
|
386 | session = StreamSession(key=key) | |
|
387 | # print (session.key) | |||
378 | print (control_addr, shell_addrs, iopub_addr, hb_addrs) |
|
388 | print (control_addr, shell_addrs, iopub_addr, hb_addrs) | |
379 |
|
389 | |||
380 | # create Control Stream |
|
390 | # create Control Stream |
@@ -277,7 +277,9 b' def unpack_apply_message(bufs, g=None, copy=True):' | |||||
277 | class StreamSession(object): |
|
277 | class StreamSession(object): | |
278 | """tweaked version of IPython.zmq.session.Session, for development in Parallel""" |
|
278 | """tweaked version of IPython.zmq.session.Session, for development in Parallel""" | |
279 | debug=False |
|
279 | debug=False | |
280 | def __init__(self, username=None, session=None, packer=None, unpacker=None): |
|
280 | key=None | |
|
281 | ||||
|
282 | def __init__(self, username=None, session=None, packer=None, unpacker=None, key=None, keyfile=None): | |||
281 | if username is None: |
|
283 | if username is None: | |
282 | username = os.environ.get('USER','username') |
|
284 | username = os.environ.get('USER','username') | |
283 | self.username = username |
|
285 | self.username = username | |
@@ -300,6 +302,14 b' class StreamSession(object):' | |||||
300 | raise TypeError("unpacker must be callable, not %s"%type(unpacker)) |
|
302 | raise TypeError("unpacker must be callable, not %s"%type(unpacker)) | |
301 | self.unpack = unpacker |
|
303 | self.unpack = unpacker | |
302 |
|
304 | |||
|
305 | if key is not None and keyfile is not None: | |||
|
306 | raise TypeError("Must specify key OR keyfile, not both") | |||
|
307 | if keyfile is not None: | |||
|
308 | with open(keyfile) as f: | |||
|
309 | self.key = f.read().strip() | |||
|
310 | else: | |||
|
311 | self.key = key | |||
|
312 | # print key, keyfile, self.key | |||
303 | self.none = self.pack({}) |
|
313 | self.none = self.pack({}) | |
304 |
|
314 | |||
305 | def msg_header(self, msg_type): |
|
315 | def msg_header(self, msg_type): | |
@@ -318,6 +328,14 b' class StreamSession(object):' | |||||
318 | msg['header'].update(sub) |
|
328 | msg['header'].update(sub) | |
319 | return msg |
|
329 | return msg | |
320 |
|
330 | |||
|
331 | def check_key(self, msg_or_header): | |||
|
332 | """Check that a message's header has the right key""" | |||
|
333 | if self.key is None: | |||
|
334 | return True | |||
|
335 | header = extract_header(msg_or_header) | |||
|
336 | return header.get('key', None) == self.key | |||
|
337 | ||||
|
338 | ||||
321 | def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None): |
|
339 | def send(self, stream, msg_type, content=None, buffers=None, parent=None, subheader=None, ident=None): | |
322 | """Build and send a message via stream or socket. |
|
340 | """Build and send a message via stream or socket. | |
323 |
|
341 | |||
@@ -353,6 +371,8 b' class StreamSession(object):' | |||||
353 | elif ident is not None: |
|
371 | elif ident is not None: | |
354 | to_send.append(ident) |
|
372 | to_send.append(ident) | |
355 | to_send.append(DELIM) |
|
373 | to_send.append(DELIM) | |
|
374 | if self.key is not None: | |||
|
375 | to_send.append(self.key) | |||
356 | to_send.append(self.pack(msg['header'])) |
|
376 | to_send.append(self.pack(msg['header'])) | |
357 | to_send.append(self.pack(msg['parent_header'])) |
|
377 | to_send.append(self.pack(msg['parent_header'])) | |
358 |
|
378 | |||
@@ -393,6 +413,8 b' class StreamSession(object):' | |||||
393 | if ident is not None: |
|
413 | if ident is not None: | |
394 | to_send.extend(ident) |
|
414 | to_send.extend(ident) | |
395 | to_send.append(DELIM) |
|
415 | to_send.append(DELIM) | |
|
416 | if self.key is not None: | |||
|
417 | to_send.append(self.key) | |||
396 | to_send.extend(msg) |
|
418 | to_send.extend(msg) | |
397 | stream.send_multipart(msg, flags, copy=copy) |
|
419 | stream.send_multipart(msg, flags, copy=copy) | |
398 |
|
420 | |||
@@ -457,19 +479,24 b' class StreamSession(object):' | |||||
457 | or the non-copying Message object in each place (False) |
|
479 | or the non-copying Message object in each place (False) | |
458 |
|
480 | |||
459 | """ |
|
481 | """ | |
460 | if not len(msg) >= 3: |
|
482 | ikey = int(self.key is not None) | |
461 | raise TypeError("malformed message, must have at least 3 elements") |
|
483 | minlen = 3 + ikey | |
|
484 | if not len(msg) >= minlen: | |||
|
485 | raise TypeError("malformed message, must have at least %i elements"%minlen) | |||
462 | message = {} |
|
486 | message = {} | |
463 | if not copy: |
|
487 | if not copy: | |
464 |
for i in range( |
|
488 | for i in range(minlen): | |
465 | msg[i] = msg[i].bytes |
|
489 | msg[i] = msg[i].bytes | |
466 | message['header'] = self.unpack(msg[0]) |
|
490 | if ikey: | |
|
491 | if not self.key == msg[0]: | |||
|
492 | raise KeyError("Invalid Session Key: %s"%msg[0]) | |||
|
493 | message['header'] = self.unpack(msg[ikey+0]) | |||
467 | message['msg_type'] = message['header']['msg_type'] |
|
494 | message['msg_type'] = message['header']['msg_type'] | |
468 | message['parent_header'] = self.unpack(msg[1]) |
|
495 | message['parent_header'] = self.unpack(msg[ikey+1]) | |
469 | if content: |
|
496 | if content: | |
470 | message['content'] = self.unpack(msg[2]) |
|
497 | message['content'] = self.unpack(msg[ikey+2]) | |
471 | else: |
|
498 | else: | |
472 | message['content'] = msg[2] |
|
499 | message['content'] = msg[ikey+2] | |
473 |
|
500 | |||
474 | # message['buffers'] = msg[3:] |
|
501 | # message['buffers'] = msg[3:] | |
475 | # else: |
|
502 | # else: | |
@@ -481,7 +508,7 b' class StreamSession(object):' | |||||
481 | # else: |
|
508 | # else: | |
482 | # message['content'] = msg[2].bytes |
|
509 | # message['content'] = msg[2].bytes | |
483 |
|
510 | |||
484 | message['buffers'] = msg[3:]# [ m.buffer for m in msg[3:] ] |
|
511 | message['buffers'] = msg[ikey+3:]# [ m.buffer for m in msg[3:] ] | |
485 | return message |
|
512 | return message | |
486 |
|
513 | |||
487 |
|
514 |
General Comments 0
You need to be logged in to leave comments.
Login now