forward-port draft76 websockets...
Min RK
@@ -0,0 +1,312
1 """WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
2
3 The contents of this file are Copyright (c) Tornado
4 Used under the Apache 2.0 license
5 """
6
7
8 from __future__ import absolute_import, division, print_function, with_statement
9 # Author: Jacob Kristhammar, 2010
10
11 import functools
12 import hashlib
13 import struct
14 import time
15 import tornado.escape
16 import tornado.web
17
18 from tornado.log import gen_log, app_log
19 from tornado.util import bytes_type, unicode_type
20
21 from tornado.websocket import WebSocketHandler, WebSocketProtocol13
22
23 class AllowDraftWebSocketHandler(WebSocketHandler):
24 """Restore Draft76 support for tornado 4
25
26 Remove when we can run tests without phantomjs + qt4
27 """
28
29 # get is unmodified except between the BEGIN/END PATCH lines
30 @tornado.web.asynchronous
31 def get(self, *args, **kwargs):
32 self.open_args = args
33 self.open_kwargs = kwargs
34
35 # Upgrade header should be present and should be equal to WebSocket
36 if self.request.headers.get("Upgrade", "").lower() != 'websocket':
37 self.set_status(400)
38 self.finish("Can \"Upgrade\" only to \"WebSocket\".")
39 return
40
41 # Connection header should be upgrade. Some proxy servers/load balancers
42 # might mess with it.
43 headers = self.request.headers
44 connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
45 if 'upgrade' not in connection:
46 self.set_status(400)
47 self.finish("\"Connection\" must be \"Upgrade\".")
48 return
49
50 # Handle WebSocket Origin naming convention differences
51 # The difference between version 8 and 13 is that in 8 the
52 # client sends a "Sec-Websocket-Origin" header and in 13 it's
53 # simply "Origin".
54 if "Origin" in self.request.headers:
55 origin = self.request.headers.get("Origin")
56 else:
57 origin = self.request.headers.get("Sec-Websocket-Origin", None)
58
59
60 # If there was an origin header, check to make sure it matches
61 # according to check_origin. When the origin is None, we assume it
62 # did not come from a browser and that it can be passed on.
63 if origin is not None and not self.check_origin(origin):
64 self.set_status(403)
65 self.finish("Cross origin websockets not allowed")
66 return
67
68 self.stream = self.request.connection.detach()
69 self.stream.set_close_callback(self.on_connection_close)
70
71 if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
72 self.ws_connection = WebSocketProtocol13(
73 self, compression_options=self.get_compression_options())
74 self.ws_connection.accept_connection()
75 #--------------- BEGIN PATCH ----------------
76 elif (self.allow_draft76() and
77 "Sec-WebSocket-Version" not in self.request.headers):
78 self.ws_connection = WebSocketProtocol76(self)
79 self.ws_connection.accept_connection()
80 #--------------- END PATCH ----------------
81 else:
82 if not self.stream.closed():
83 self.stream.write(tornado.escape.utf8(
84 "HTTP/1.1 426 Upgrade Required\r\n"
85 "Sec-WebSocket-Version: 8\r\n\r\n"))
86 self.stream.close()
87
88 # 3.2 methods removed in 4.0:
89 def allow_draft76(self):
90 """Using this class allows draft76 connections by default"""
91 return True
92
93 def get_websocket_scheme(self):
94 """Return the url scheme used for this request, either "ws" or "wss".
95 This is normally decided by HTTPServer, but applications
96 may wish to override this if they are using an SSL proxy
97 that does not provide the X-Scheme header as understood
98 by HTTPServer.
99 Note that this is only used by the draft76 protocol.
100 """
101 return "wss" if self.request.protocol == "https" else "ws"
102
103
104
105 # No modifications from tornado-3.2.2 below this line
106
107 class WebSocketProtocol(object):
108 """Base class for WebSocket protocol versions.
109 """
110 def __init__(self, handler):
111 self.handler = handler
112 self.request = handler.request
113 self.stream = handler.stream
114 self.client_terminated = False
115 self.server_terminated = False
116
117 def async_callback(self, callback, *args, **kwargs):
118 """Wrap callbacks with this if they are used on asynchronous requests.
119
120 Catches exceptions properly and closes this WebSocket if an exception
121 is uncaught.
122 """
123 if args or kwargs:
124 callback = functools.partial(callback, *args, **kwargs)
125
126 def wrapper(*args, **kwargs):
127 try:
128 return callback(*args, **kwargs)
129 except Exception:
130 app_log.error("Uncaught exception in %s",
131 self.request.path, exc_info=True)
132 self._abort()
133 return wrapper
134
135 def on_connection_close(self):
136 self._abort()
137
138 def _abort(self):
139 """Instantly aborts the WebSocket connection by closing the socket"""
140 self.client_terminated = True
141 self.server_terminated = True
142 self.stream.close() # forcibly tear down the connection
143 self.close() # let the subclass cleanup
144
145
146 class WebSocketProtocol76(WebSocketProtocol):
147 """Implementation of the WebSockets protocol, version hixie-76.
148
149 This class provides basic functionality to process WebSockets requests as
150 specified in
151 http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
152 """
153 def __init__(self, handler):
154 WebSocketProtocol.__init__(self, handler)
155 self.challenge = None
156 self._waiting = None
157
158 def accept_connection(self):
159 try:
160 self._handle_websocket_headers()
161 except ValueError:
162 gen_log.debug("Malformed WebSocket request received")
163 self._abort()
164 return
165
166 scheme = self.handler.get_websocket_scheme()
167
168 # draft76 only allows a single subprotocol
169 subprotocol_header = ''
170 subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
171 if subprotocol:
172 selected = self.handler.select_subprotocol([subprotocol])
173 if selected:
174 assert selected == subprotocol
175 subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected
176
177 # Write the initial headers before attempting to read the challenge.
178 # This is necessary when using proxies (such as HAProxy), which
179 # need to see the Upgrade headers before passing through the
180 # non-HTTP traffic that follows.
181 self.stream.write(tornado.escape.utf8(
182 "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
183 "Upgrade: WebSocket\r\n"
184 "Connection: Upgrade\r\n"
185 "Server: TornadoServer/%(version)s\r\n"
186 "Sec-WebSocket-Origin: %(origin)s\r\n"
187 "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
188 "%(subprotocol)s"
189 "\r\n" % (dict(
190 version=tornado.version,
191 origin=self.request.headers["Origin"],
192 scheme=scheme,
193 host=self.request.host,
194 uri=self.request.uri,
195 subprotocol=subprotocol_header))))
196 self.stream.read_bytes(8, self._handle_challenge)
197
198 def challenge_response(self, challenge):
199 """Generates the challenge response that's needed in the handshake
200
201 The challenge parameter should be the raw bytes as sent from the
202 client.
203 """
204 key_1 = self.request.headers.get("Sec-Websocket-Key1")
205 key_2 = self.request.headers.get("Sec-Websocket-Key2")
206 try:
207 part_1 = self._calculate_part(key_1)
208 part_2 = self._calculate_part(key_2)
209 except ValueError:
210 raise ValueError("Invalid Keys/Challenge")
211 return self._generate_challenge_response(part_1, part_2, challenge)
212
213 def _handle_challenge(self, challenge):
214 try:
215 challenge_response = self.challenge_response(challenge)
216 except ValueError:
217 gen_log.debug("Malformed key data in WebSocket request")
218 self._abort()
219 return
220 self._write_response(challenge_response)
221
222 def _write_response(self, challenge):
223 self.stream.write(challenge)
224 self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
225 self._receive_message()
226
227 def _handle_websocket_headers(self):
228 """Verifies all invariant- and required headers
229
230         If a header is missing or has an incorrect value, ValueError will be
231 raised
232 """
233 fields = ("Origin", "Host", "Sec-Websocket-Key1",
234 "Sec-Websocket-Key2")
235 if not all(map(lambda f: self.request.headers.get(f), fields)):
236 raise ValueError("Missing/Invalid WebSocket headers")
237
238 def _calculate_part(self, key):
239 """Processes the key headers and calculates their key value.
240
241         Raises ValueError when fed an invalid key."""
242 # pyflakes complains about variable reuse if both of these lines use 'c'
243 number = int(''.join(c for c in key if c.isdigit()))
244 spaces = len([c2 for c2 in key if c2.isspace()])
245 try:
246 key_number = number // spaces
247 except (ValueError, ZeroDivisionError):
248 raise ValueError
249 return struct.pack(">I", key_number)
250
251 def _generate_challenge_response(self, part_1, part_2, part_3):
252 m = hashlib.md5()
253 m.update(part_1)
254 m.update(part_2)
255 m.update(part_3)
256 return m.digest()
257
258 def _receive_message(self):
259 self.stream.read_bytes(1, self._on_frame_type)
260
261 def _on_frame_type(self, byte):
262 frame_type = ord(byte)
263 if frame_type == 0x00:
264 self.stream.read_until(b"\xff", self._on_end_delimiter)
265 elif frame_type == 0xff:
266 self.stream.read_bytes(1, self._on_length_indicator)
267 else:
268 self._abort()
269
270 def _on_end_delimiter(self, frame):
271 if not self.client_terminated:
272 self.async_callback(self.handler.on_message)(
273 frame[:-1].decode("utf-8", "replace"))
274 if not self.client_terminated:
275 self._receive_message()
276
277 def _on_length_indicator(self, byte):
278 if ord(byte) != 0x00:
279 self._abort()
280 return
281 self.client_terminated = True
282 self.close()
283
284 def write_message(self, message, binary=False):
285 """Sends the given message to the client of this Web Socket."""
286 if binary:
287 raise ValueError(
288 "Binary messages not supported by this version of websockets")
289 if isinstance(message, unicode_type):
290 message = message.encode("utf-8")
291 assert isinstance(message, bytes_type)
292 self.stream.write(b"\x00" + message + b"\xff")
293
294 def write_ping(self, data):
295 """Send ping frame."""
296 raise ValueError("Ping messages not supported by this version of websockets")
297
298 def close(self):
299 """Closes the WebSocket connection."""
300 if not self.server_terminated:
301 if not self.stream.closed():
302             self.stream.write(b"\xff\x00")
303 self.server_terminated = True
304 if self.client_terminated:
305 if self._waiting is not None:
306 self.stream.io_loop.remove_timeout(self._waiting)
307 self._waiting = None
308 self.stream.close()
309 elif self._waiting is None:
310 self._waiting = self.stream.io_loop.add_timeout(
311 time.time() + 5, self._abort)
312
@@ -1,271 +1,272
1 1 # coding: utf-8
2 2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3 3
4 4 # Copyright (c) IPython Development Team.
5 5 # Distributed under the terms of the Modified BSD License.
6 6
7 import os
7 8 import json
8 9 import struct
10 import warnings
9 11
10 12 try:
11 13 from urllib.parse import urlparse # Py 3
12 14 except ImportError:
13 15 from urlparse import urlparse # Py 2
14 16
15 17 import tornado
16 from tornado import gen, ioloop, web, websocket
18 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
17 20
18 21 from IPython.kernel.zmq.session import Session
19 22 from IPython.utils.jsonutil import date_default, extract_dates
20 23 from IPython.utils.py3compat import cast_unicode
21 24
22 25 from .handlers import IPythonHandler
23 26
24
25 27 def serialize_binary_message(msg):
26 28 """serialize a message as a binary blob
27 29
28 30 Header:
29 31
30 32 4 bytes: number of msg parts (nbufs) as 32b int
31 33 4 * nbufs bytes: offset for each buffer as integer as 32b int
32 34
33 35 Offsets are from the start of the buffer, including the header.
34 36
35 37 Returns
36 38 -------
37 39
38 40 The message serialized to bytes.
39 41
40 42 """
41 43 # don't modify msg or buffer list in-place
42 44 msg = msg.copy()
43 45 buffers = list(msg.pop('buffers'))
44 46 bmsg = json.dumps(msg, default=date_default).encode('utf8')
45 47 buffers.insert(0, bmsg)
46 48 nbufs = len(buffers)
47 49 offsets = [4 * (nbufs + 1)]
48 50 for buf in buffers[:-1]:
49 51 offsets.append(offsets[-1] + len(buf))
50 52 offsets_buf = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
51 53 buffers.insert(0, offsets_buf)
52 54 return b''.join(buffers)
53 55
54 56
55 57 def deserialize_binary_message(bmsg):
56 58 """deserialize a message from a binary blog
57 59
58 60 Header:
59 61
60 62 4 bytes: number of msg parts (nbufs) as 32b int
61 63 4 * nbufs bytes: offset for each buffer as integer as 32b int
62 64
63 65 Offsets are from the start of the buffer, including the header.
64 66
65 67 Returns
66 68 -------
67 69
68 70 message dictionary
69 71 """
70 72 nbufs = struct.unpack('!i', bmsg[:4])[0]
71 73 offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4*(nbufs+1)]))
72 74 offsets.append(None)
73 75 bufs = []
74 76 for start, stop in zip(offsets[:-1], offsets[1:]):
75 77 bufs.append(bmsg[start:stop])
76 78 msg = json.loads(bufs[0].decode('utf8'))
77 79 msg['header'] = extract_dates(msg['header'])
78 80 msg['parent_header'] = extract_dates(msg['parent_header'])
79 81 msg['buffers'] = bufs[1:]
80 82 return msg
81 83
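
# A quick self-contained round trip (not part of the diff) of the framing that
# serialize_binary_message / deserialize_binary_message describe above, with
# IPython's date handling (date_default / extract_dates) omitted for brevity.
# The pack_/unpack_ helper names are illustrative only.
import json
import struct

def pack_binary_message(msg):
    msg = msg.copy()
    buffers = list(msg.pop('buffers'))
    bmsg = json.dumps(msg).encode('utf8')
    buffers.insert(0, bmsg)
    nbufs = len(buffers)
    offsets = [4 * (nbufs + 1)]   # header is 4 bytes for nbufs + 4 bytes per offset
    for buf in buffers[:-1]:
        offsets.append(offsets[-1] + len(buf))
    header = struct.pack('!' + 'I' * (nbufs + 1), nbufs, *offsets)
    return header + b''.join(buffers)

def unpack_binary_message(blob):
    nbufs = struct.unpack('!I', blob[:4])[0]
    offsets = list(struct.unpack('!' + 'I' * nbufs, blob[4:4 * (nbufs + 1)]))
    offsets.append(None)
    bufs = [blob[start:stop] for start, stop in zip(offsets[:-1], offsets[1:])]
    msg = json.loads(bufs[0].decode('utf8'))
    msg['buffers'] = bufs[1:]
    return msg

wire = pack_binary_message({'header': {'msg_type': 'test'},
                            'parent_header': {},
                            'buffers': [b'\x00\x01\x02']})
assert unpack_binary_message(wire)['buffers'] == [b'\x00\x01\x02']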
84 # ping interval for keeping websockets alive (30 seconds)
85 WS_PING_INTERVAL = 30000
86
87 if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
88 warnings.warn("""Allowing draft76 websocket connections!
89 This should only be done for testing with phantomjs!""")
90 from IPython.html import allow76
91 WebSocketHandler = allow76.AllowDraftWebSocketHandler
92 # draft 76 doesn't support ping
93 WS_PING_INTERVAL = 0
82 94
83 class ZMQStreamHandler(websocket.WebSocketHandler):
95 class ZMQStreamHandler(WebSocketHandler):
84 96
85 97 def check_origin(self, origin):
86 98 """Check Origin == Host or Access-Control-Allow-Origin.
87 99
88 100 Tornado >= 4 calls this method automatically, raising 403 if it returns False.
89 101 We call it explicitly in `open` on Tornado < 4.
90 102 """
91 103 if self.allow_origin == '*':
92 104 return True
93 105
94 106 host = self.request.headers.get("Host")
95 107
96 108 # If no header is provided, assume we can't verify origin
97 109 if origin is None:
98 110 self.log.warn("Missing Origin header, rejecting WebSocket connection.")
99 111 return False
100 112 if host is None:
101 113 self.log.warn("Missing Host header, rejecting WebSocket connection.")
102 114 return False
103 115
104 116 origin = origin.lower()
105 117 origin_host = urlparse(origin).netloc
106 118
107 119 # OK if origin matches host
108 120 if origin_host == host:
109 121 return True
110 122
111 123 # Check CORS headers
112 124 if self.allow_origin:
113 125 allow = self.allow_origin == origin
114 126 elif self.allow_origin_pat:
115 127 allow = bool(self.allow_origin_pat.match(origin))
116 128 else:
117 129 # No CORS headers deny the request
118 130 allow = False
119 131 if not allow:
120 132 self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
121 133 origin, host,
122 134 )
123 135 return allow
124 136
125 137 def clear_cookie(self, *args, **kwargs):
126 138 """meaningless for websockets"""
127 139 pass
128 140
129 141 def _reserialize_reply(self, msg_list):
130 142 """Reserialize a reply message using JSON.
131 143
132 144 This takes the msg list from the ZMQ socket, deserializes it using
133 145 self.session and then serializes the result using JSON. This method
134 146 should be used by self._on_zmq_reply to build messages that can
135 147 be sent back to the browser.
136 148 """
137 149 idents, msg_list = self.session.feed_identities(msg_list)
138 150 msg = self.session.deserialize(msg_list)
139 151 if msg['buffers']:
140 152 buf = serialize_binary_message(msg)
141 153 return buf
142 154 else:
143 155 smsg = json.dumps(msg, default=date_default)
144 156 return cast_unicode(smsg)
145 157
146 158 def _on_zmq_reply(self, msg_list):
147 159 # Sometimes this gets triggered when the on_close method is scheduled in the
148 160 # eventloop but hasn't been called.
149 161 if self.stream.closed(): return
150 162 try:
151 163 msg = self._reserialize_reply(msg_list)
152 164 except Exception:
153 165 self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
154 166 else:
155 167 self.write_message(msg, binary=isinstance(msg, bytes))
156 168
157 def allow_draft76(self):
158 """Allow draft 76, until browsers such as Safari update to RFC 6455.
159
160 This has been disabled by default in tornado in release 2.2.0, and
161 support will be removed in later versions.
162 """
163 return True
164
165 # ping interval for keeping websockets alive (30 seconds)
166 WS_PING_INTERVAL = 30000
167
168 169 class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
169 170 ping_callback = None
170 171 last_ping = 0
171 172 last_pong = 0
172 173
173 174 @property
174 175 def ping_interval(self):
175 176 """The interval for websocket keep-alive pings.
176 177
177 178 Set ws_ping_interval = 0 to disable pings.
178 179 """
179 180 return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)
180 181
181 182 @property
182 183 def ping_timeout(self):
183 184 """If no ping is received in this many milliseconds,
184 185 close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
185 186 Default is max of 3 pings or 30 seconds.
186 187 """
187 188 return self.settings.get('ws_ping_timeout',
188 189 max(3 * self.ping_interval, WS_PING_INTERVAL)
189 190 )
190 191
191 192 def set_default_headers(self):
192 193 """Undo the set_default_headers in IPythonHandler
193 194
194 195 which doesn't make sense for websockets
195 196 """
196 197 pass
197 198
198 199 def pre_get(self):
199 200 """Run before finishing the GET request
200 201
201 202 Extend this method to add logic that should fire before
202 203 the websocket finishes completing.
203 204 """
204 205 # Check to see that origin matches host directly, including ports
205 206 # Tornado 4 already does CORS checking
206 207 if tornado.version_info[0] < 4:
207 208 if not self.check_origin(self.get_origin()):
208 209 raise web.HTTPError(403)
209 210
210 211 # authenticate the request before opening the websocket
211 212 if self.get_current_user() is None:
212 213 self.log.warn("Couldn't authenticate WebSocket connection")
213 214 raise web.HTTPError(403)
214 215
215 216 if self.get_argument('session_id', False):
216 217 self.session.session = cast_unicode(self.get_argument('session_id'))
217 218 else:
218 219 self.log.warn("No session ID specified")
219 220
220 221 @gen.coroutine
221 222 def get(self, *args, **kwargs):
222 223 # pre_get can be a coroutine in subclasses
223 224 # assign and yield in two step to avoid tornado 3 issues
224 225 res = self.pre_get()
225 226 yield gen.maybe_future(res)
226 227 # FIXME: only do super get on tornado ≥ 4
227 228 # tornado 3 has no get, will raise 405
228 229 if tornado.version_info >= (4,):
229 230 super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)
230 231
231 232 def initialize(self):
232 233 self.log.debug("Initializing websocket connection %s", self.request.path)
233 234 self.session = Session(config=self.config)
234 235
235 236 def open(self, *args, **kwargs):
236 237 self.log.debug("Opening websocket %s", self.request.path)
237 238 if tornado.version_info < (4,):
238 239 try:
239 240 self.get(*self.open_args, **self.open_kwargs)
240 241 except web.HTTPError:
241 242 self.close()
242 243 raise
243 244
244 245 # start the pinging
245 246 if self.ping_interval > 0:
246 247 self.last_ping = ioloop.IOLoop.instance().time() # Remember time of last ping
247 248 self.last_pong = self.last_ping
248 249 self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
249 250 self.ping_callback.start()
250 251
251 252 def send_ping(self):
252 253 """send a ping to keep the websocket alive"""
253 254 if self.stream.closed() and self.ping_callback is not None:
254 255 self.ping_callback.stop()
255 256 return
256 257
257 258 # check for timeout on pong. Make sure that we really have sent a recent ping in
258 259 # case the machine with both server and client has been suspended since the last ping.
259 260 now = ioloop.IOLoop.instance().time()
260 261 since_last_pong = 1e3 * (now - self.last_pong)
261 262 since_last_ping = 1e3 * (now - self.last_ping)
262 263 if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
263 264 self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
264 265 self.close()
265 266 return
266 267
267 268 self.ping(b'')
268 269 self.last_ping = now
269 270
270 271 def on_pong(self, data):
271 272 self.last_pong = ioloop.IOLoop.instance().time()
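
As a hedged illustration of the keep-alive policy above (not code from the diff): send_ping only closes the connection when a recently sent ping went unanswered for longer than ping_timeout, so a machine suspended between pings does not trigger a false timeout. The ping_timed_out helper and the sample timestamps below are hypothetical.

def ping_timed_out(now, last_ping, last_pong, ping_interval, ping_timeout):
    """now/last_ping/last_pong are in seconds; ping_interval/ping_timeout in ms."""
    since_last_pong = 1e3 * (now - last_pong)
    since_last_ping = 1e3 * (now - last_ping)
    return since_last_ping < 2 * ping_interval and since_last_pong > ping_timeout

# with the defaults (30s interval, 90s timeout), a two-minute-old pong only
# closes the socket if a ping was actually sent within the last minute
assert ping_timed_out(now=200, last_ping=170, last_pong=80,
                      ping_interval=30000, ping_timeout=90000)
assert not ping_timed_out(now=200, last_ping=80, last_pong=80,
                          ping_interval=30000, ping_timeout=90000)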
@@ -1,705 +1,713
1 1 # -*- coding: utf-8 -*-
2 2 """IPython Test Process Controller
3 3
4 4 This module runs one or more subprocesses which will actually run the IPython
5 5 test suite.
6 6
7 7 """
8 8
9 9 # Copyright (c) IPython Development Team.
10 10 # Distributed under the terms of the Modified BSD License.
11 11
12 12 from __future__ import print_function
13 13
14 14 import argparse
15 15 import json
16 16 import multiprocessing.pool
17 17 import os
18 18 import re
19 19 import requests
20 20 import shutil
21 21 import signal
22 22 import sys
23 23 import subprocess
24 24 import time
25 25
26 26 from .iptest import (
27 27 have, test_group_names as py_test_group_names, test_sections, StreamCapturer,
28 28 test_for,
29 29 )
30 30 from IPython.utils.path import compress_user
31 31 from IPython.utils.py3compat import bytes_to_str
32 32 from IPython.utils.sysinfo import get_sys_info
33 33 from IPython.utils.tempdir import TemporaryDirectory
34 34 from IPython.utils.text import strip_ansi
35 35
36 36 try:
37 37 # Python >= 3.3
38 38 from subprocess import TimeoutExpired
39 39 def popen_wait(p, timeout):
40 40 return p.wait(timeout)
41 41 except ImportError:
42 42 class TimeoutExpired(Exception):
43 43 pass
44 44 def popen_wait(p, timeout):
45 45 """backport of Popen.wait from Python 3"""
46 46 for i in range(int(10 * timeout)):
47 47 if p.poll() is not None:
48 48 return
49 49 time.sleep(0.1)
50 50 if p.poll() is None:
51 51 raise TimeoutExpired
52 52
53 53 NOTEBOOK_SHUTDOWN_TIMEOUT = 10
54 54
55 55 class TestController(object):
56 56 """Run tests in a subprocess
57 57 """
58 58 #: str, IPython test suite to be executed.
59 59 section = None
60 60 #: list, command line arguments to be executed
61 61 cmd = None
62 62 #: dict, extra environment variables to set for the subprocess
63 63 env = None
64 64 #: list, TemporaryDirectory instances to clear up when the process finishes
65 65 dirs = None
66 66 #: subprocess.Popen instance
67 67 process = None
68 68 #: str, process stdout+stderr
69 69 stdout = None
70 70
71 71 def __init__(self):
72 72 self.cmd = []
73 73 self.env = {}
74 74 self.dirs = []
75 75
76 76 def setup(self):
77 77 """Create temporary directories etc.
78 78
79 79 This is only called when we know the test group will be run. Things
80 80 created here may be cleaned up by self.cleanup().
81 81 """
82 82 pass
83 83
84 84 def launch(self, buffer_output=False, capture_output=False):
85 85 # print('*** ENV:', self.env) # dbg
86 86 # print('*** CMD:', self.cmd) # dbg
87 87 env = os.environ.copy()
88 88 env.update(self.env)
89 89 if buffer_output:
90 90 capture_output = True
91 91 self.stdout_capturer = c = StreamCapturer(echo=not buffer_output)
92 92 c.start()
93 93 stdout = c.writefd if capture_output else None
94 94 stderr = subprocess.STDOUT if capture_output else None
95 95 self.process = subprocess.Popen(self.cmd, stdout=stdout,
96 96 stderr=stderr, env=env)
97 97
98 98 def wait(self):
99 99 self.process.wait()
100 100 self.stdout_capturer.halt()
101 101 self.stdout = self.stdout_capturer.get_buffer()
102 102 return self.process.returncode
103 103
104 104 def print_extra_info(self):
105 105 """Print extra information about this test run.
106 106
107 107 If we're running in parallel and showing the concise view, this is only
108 108 called if the test group fails. Otherwise, it's called before the test
109 109 group is started.
110 110
111 111 The base implementation does nothing, but it can be overridden by
112 112 subclasses.
113 113 """
114 114 return
115 115
116 116 def cleanup_process(self):
117 117 """Cleanup on exit by killing any leftover processes."""
118 118 subp = self.process
119 119 if subp is None or (subp.poll() is not None):
120 120 return # Process doesn't exist, or is already dead.
121 121
122 122 try:
123 123 print('Cleaning up stale PID: %d' % subp.pid)
124 124 subp.kill()
125 125 except: # (OSError, WindowsError) ?
126 126 # This is just a best effort, if we fail or the process was
127 127 # really gone, ignore it.
128 128 pass
129 129 else:
130 130 for i in range(10):
131 131 if subp.poll() is None:
132 132 time.sleep(0.1)
133 133 else:
134 134 break
135 135
136 136 if subp.poll() is None:
137 137 # The process did not die...
138 138 print('... failed. Manual cleanup may be required.')
139 139
140 140 def cleanup(self):
141 141 "Kill process if it's still alive, and clean up temporary directories"
142 142 self.cleanup_process()
143 143 for td in self.dirs:
144 144 td.cleanup()
145 145
146 146 __del__ = cleanup
147 147
148 148
149 149 class PyTestController(TestController):
150 150 """Run Python tests using IPython.testing.iptest"""
151 151 #: str, Python command to execute in subprocess
152 152 pycmd = None
153 153
154 154 def __init__(self, section, options):
155 155 """Create new test runner."""
156 156 TestController.__init__(self)
157 157 self.section = section
158 158 # pycmd is put into cmd[2] in PyTestController.launch()
159 159 self.cmd = [sys.executable, '-c', None, section]
160 160 self.pycmd = "from IPython.testing.iptest import run_iptest; run_iptest()"
161 161 self.options = options
162 162
163 163 def setup(self):
164 164 ipydir = TemporaryDirectory()
165 165 self.dirs.append(ipydir)
166 166 self.env['IPYTHONDIR'] = ipydir.name
167 167 self.workingdir = workingdir = TemporaryDirectory()
168 168 self.dirs.append(workingdir)
169 169 self.env['IPTEST_WORKING_DIR'] = workingdir.name
170 170 # This means we won't get odd effects from our own matplotlib config
171 171 self.env['MPLCONFIGDIR'] = workingdir.name
172 172
173 173 # From options:
174 174 if self.options.xunit:
175 175 self.add_xunit()
176 176 if self.options.coverage:
177 177 self.add_coverage()
178 178 self.env['IPTEST_SUBPROC_STREAMS'] = self.options.subproc_streams
179 179 self.cmd.extend(self.options.extra_args)
180 180
181 181 @property
182 182 def will_run(self):
183 183 try:
184 184 return test_sections[self.section].will_run
185 185 except KeyError:
186 186 return True
187 187
188 188 def add_xunit(self):
189 189 xunit_file = os.path.abspath(self.section + '.xunit.xml')
190 190 self.cmd.extend(['--with-xunit', '--xunit-file', xunit_file])
191 191
192 192 def add_coverage(self):
193 193 try:
194 194 sources = test_sections[self.section].includes
195 195 except KeyError:
196 196 sources = ['IPython']
197 197
198 198 coverage_rc = ("[run]\n"
199 199 "data_file = {data_file}\n"
200 200 "source =\n"
201 201 " {source}\n"
202 202 ).format(data_file=os.path.abspath('.coverage.'+self.section),
203 203 source="\n ".join(sources))
204 204 config_file = os.path.join(self.workingdir.name, '.coveragerc')
205 205 with open(config_file, 'w') as f:
206 206 f.write(coverage_rc)
207 207
208 208 self.env['COVERAGE_PROCESS_START'] = config_file
209 209 self.pycmd = "import coverage; coverage.process_startup(); " + self.pycmd
210 210
211 211 def launch(self, buffer_output=False):
212 212 self.cmd[2] = self.pycmd
213 213 super(PyTestController, self).launch(buffer_output=buffer_output)
214 214
215 215
216 216 js_prefix = 'js/'
217 217
218 218 def get_js_test_dir():
219 219 import IPython.html.tests as t
220 220 return os.path.join(os.path.dirname(t.__file__), '')
221 221
222 222 def all_js_groups():
223 223 import glob
224 224 test_dir = get_js_test_dir()
225 225 all_subdirs = glob.glob(test_dir + '[!_]*/')
226 226 return [js_prefix+os.path.relpath(x, test_dir) for x in all_subdirs]
227 227
228 228 class JSController(TestController):
229 229 """Run CasperJS tests """
230 230
231 231 requirements = ['zmq', 'tornado', 'jinja2', 'casperjs', 'sqlite3',
232 232 'jsonschema']
233 233
234 234 def __init__(self, section, xunit=True, engine='phantomjs', url=None):
235 235 """Create new test runner."""
236 236 TestController.__init__(self)
237 237 self.engine = engine
238 238 self.section = section
239 239 self.xunit = xunit
240 240 self.url = url
241 241 self.slimer_failure = re.compile('^FAIL.*', flags=re.MULTILINE)
242 242 js_test_dir = get_js_test_dir()
243 243 includes = '--includes=' + os.path.join(js_test_dir,'util.js')
244 244 test_cases = os.path.join(js_test_dir, self.section[len(js_prefix):])
245 245 self.cmd = ['casperjs', 'test', includes, test_cases, '--engine=%s' % self.engine]
246 246
247 247 def setup(self):
248 248 self.ipydir = TemporaryDirectory()
249 249 self.nbdir = TemporaryDirectory()
250 250 self.dirs.append(self.ipydir)
251 251 self.dirs.append(self.nbdir)
252 252 os.makedirs(os.path.join(self.nbdir.name, os.path.join(u'sub ∂ir1', u'sub ∂ir 1a')))
253 253 os.makedirs(os.path.join(self.nbdir.name, os.path.join(u'sub ∂ir2', u'sub ∂ir 1b')))
254 254
255 255 if self.xunit:
256 256 self.add_xunit()
257 257
258 258 # If a url was specified, use that for the testing.
259 259 if self.url:
260 260 try:
261 261 alive = requests.get(self.url).status_code == 200
262 262 except:
263 263 alive = False
264 264
265 265 if alive:
266 266 self.cmd.append("--url=%s" % self.url)
267 267 else:
268 268 raise Exception('Could not reach "%s".' % self.url)
269 269 else:
270 270 # start the ipython notebook, so we get the port number
271 271 self.server_port = 0
272 272 self._init_server()
273 273 if self.server_port:
274 274 self.cmd.append("--port=%i" % self.server_port)
275 275 else:
276 276 # don't launch tests if the server didn't start
277 277 self.cmd = [sys.executable, '-c', 'raise SystemExit(1)']
278 278
279 279 def add_xunit(self):
280 280 xunit_file = os.path.abspath(self.section.replace('/','.') + '.xunit.xml')
281 281 self.cmd.append('--xunit=%s' % xunit_file)
282 282
283 283 def launch(self, buffer_output):
284 284 # If the engine is SlimerJS, we need to buffer the output because
285 285 # SlimerJS does not support exit codes, so CasperJS always returns 0.
286 286 if self.engine == 'slimerjs' and not buffer_output:
287 287 return super(JSController, self).launch(capture_output=True)
288 288
289 289 else:
290 290 return super(JSController, self).launch(buffer_output=buffer_output)
291 291
292 292 def wait(self, *pargs, **kwargs):
293 293 """Wait for the JSController to finish"""
294 294 ret = super(JSController, self).wait(*pargs, **kwargs)
295 295 # If this is a SlimerJS controller, check the captured stdout for
296 296 # errors. Otherwise, just return the return code.
297 297 if self.engine == 'slimerjs':
298 298 stdout = bytes_to_str(self.stdout)
299 299 if ret != 0:
300 300 # This could still happen e.g. if it's stopped by SIGINT
301 301 return ret
302 302 return bool(self.slimer_failure.search(strip_ansi(stdout)))
303 303 else:
304 304 return ret
305 305
306 306 def print_extra_info(self):
307 307 print("Running tests with notebook directory %r" % self.nbdir.name)
308 308
309 309 @property
310 310 def will_run(self):
311 311 should_run = all(have[a] for a in self.requirements + [self.engine])
312 312 return should_run
313 313
314 314 def _init_server(self):
315 315 "Start the notebook server in a separate process"
316 316 self.server_command = command = [sys.executable,
317 317 '-m', 'IPython.html',
318 318 '--no-browser',
319 319 '--ipython-dir', self.ipydir.name,
320 320 '--notebook-dir', self.nbdir.name,
321 321 ]
322 322 # ipc doesn't work on Windows, and darwin has crazy-long temp paths,
323 323 # which run afoul of ipc's maximum path length.
324 324 if sys.platform.startswith('linux'):
325 325 command.append('--KernelManager.transport=ipc')
326 326 self.stream_capturer = c = StreamCapturer()
327 327 c.start()
328 self.server = subprocess.Popen(command, stdout=c.writefd, stderr=subprocess.STDOUT, cwd=self.nbdir.name)
328 env = os.environ.copy()
329 if self.engine == 'phantomjs':
330 env['IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS'] = '1'
331 self.server = subprocess.Popen(command,
332 stdout=c.writefd,
333 stderr=subprocess.STDOUT,
334 cwd=self.nbdir.name,
335 env=env,
336 )
329 337 self.server_info_file = os.path.join(self.ipydir.name,
330 338 'profile_default', 'security', 'nbserver-%i.json' % self.server.pid
331 339 )
332 340 self._wait_for_server()
333 341
334 342 def _wait_for_server(self):
335 343 """Wait 30 seconds for the notebook server to start"""
336 344 for i in range(300):
337 345 if self.server.poll() is not None:
338 346 return self._failed_to_start()
339 347 if os.path.exists(self.server_info_file):
340 348 try:
341 349 self._load_server_info()
342 350 except ValueError:
343 351 # If the server is halfway through writing the file, we may
344 352 # get invalid JSON; it should be ready next iteration.
345 353 pass
346 354 else:
347 355 return
348 356 time.sleep(0.1)
349 357 print("Notebook server-info file never arrived: %s" % self.server_info_file,
350 358 file=sys.stderr
351 359 )
352 360
353 361 def _failed_to_start(self):
354 362 """Notebook server exited prematurely"""
355 363 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
356 364 print("Notebook failed to start: ", file=sys.stderr)
357 365 print(self.server_command)
358 366 print(captured, file=sys.stderr)
359 367
360 368 def _load_server_info(self):
361 369 """Notebook server started, load connection info from JSON"""
362 370 with open(self.server_info_file) as f:
363 371 info = json.load(f)
364 372 self.server_port = info['port']
365 373
366 374 def cleanup(self):
367 375 try:
368 376 self.server.terminate()
369 377 except OSError:
370 378 # already dead
371 379 pass
372 380 # wait 10s for the server to shutdown
373 381 try:
374 382 popen_wait(self.server, NOTEBOOK_SHUTDOWN_TIMEOUT)
375 383 except TimeoutExpired:
376 384 # server didn't terminate, kill it
377 385 try:
378 386 print("Failed to terminate notebook server, killing it.",
379 387 file=sys.stderr
380 388 )
381 389 self.server.kill()
382 390 except OSError:
383 391 # already dead
384 392 pass
385 393 # wait another 10s
386 394 try:
387 395 popen_wait(self.server, NOTEBOOK_SHUTDOWN_TIMEOUT)
388 396 except TimeoutExpired:
389 397 print("Notebook server still running (%s)" % self.server_info_file,
390 398 file=sys.stderr
391 399 )
392 400
393 401 self.stream_capturer.halt()
394 402 TestController.cleanup(self)
395 403
396 404
397 405 def prepare_controllers(options):
398 406 """Returns two lists of TestController instances, those to run, and those
399 407 not to run."""
400 408 testgroups = options.testgroups
401 409 if testgroups:
402 410 if 'js' in testgroups:
403 411 js_testgroups = all_js_groups()
404 412 else:
405 413 js_testgroups = [g for g in testgroups if g.startswith(js_prefix)]
406 414
407 415 py_testgroups = [g for g in testgroups if g not in ['js'] + js_testgroups]
408 416 else:
409 417 py_testgroups = py_test_group_names
410 418 if not options.all:
411 419 js_testgroups = []
412 420 test_sections['parallel'].enabled = False
413 421 else:
414 422 js_testgroups = all_js_groups()
415 423
416 424 engine = 'slimerjs' if options.slimerjs else 'phantomjs'
417 425 c_js = [JSController(name, xunit=options.xunit, engine=engine, url=options.url) for name in js_testgroups]
418 426 c_py = [PyTestController(name, options) for name in py_testgroups]
419 427
420 428 controllers = c_py + c_js
421 429 to_run = [c for c in controllers if c.will_run]
422 430 not_run = [c for c in controllers if not c.will_run]
423 431 return to_run, not_run
424 432
425 433 def do_run(controller, buffer_output=True):
426 434 """Setup and run a test controller.
427 435
428 436 If buffer_output is True, no output is displayed, to avoid it appearing
429 437 interleaved. In this case, the caller is responsible for displaying test
430 438 output on failure.
431 439
432 440 Returns
433 441 -------
434 442 controller : TestController
435 443 The same controller as passed in, as a convenience for using map() type
436 444 APIs.
437 445 exitcode : int
438 446 The exit code of the test subprocess. Non-zero indicates failure.
439 447 """
440 448 try:
441 449 try:
442 450 controller.setup()
443 451 if not buffer_output:
444 452 controller.print_extra_info()
445 453 controller.launch(buffer_output=buffer_output)
446 454 except Exception:
447 455 import traceback
448 456 traceback.print_exc()
449 457 return controller, 1 # signal failure
450 458
451 459 exitcode = controller.wait()
452 460 return controller, exitcode
453 461
454 462 except KeyboardInterrupt:
455 463 return controller, -signal.SIGINT
456 464 finally:
457 465 controller.cleanup()
458 466
459 467 def report():
460 468 """Return a string with a summary report of test-related variables."""
461 469 inf = get_sys_info()
462 470 out = []
463 471 def _add(name, value):
464 472 out.append((name, value))
465 473
466 474 _add('IPython version', inf['ipython_version'])
467 475 _add('IPython commit', "{} ({})".format(inf['commit_hash'], inf['commit_source']))
468 476 _add('IPython package', compress_user(inf['ipython_path']))
469 477 _add('Python version', inf['sys_version'].replace('\n',''))
470 478 _add('sys.executable', compress_user(inf['sys_executable']))
471 479 _add('Platform', inf['platform'])
472 480
473 481 width = max(len(n) for (n,v) in out)
474 482 out = ["{:<{width}}: {}\n".format(n, v, width=width) for (n,v) in out]
475 483
476 484 avail = []
477 485 not_avail = []
478 486
479 487 for k, is_avail in have.items():
480 488 if is_avail:
481 489 avail.append(k)
482 490 else:
483 491 not_avail.append(k)
484 492
485 493 if avail:
486 494 out.append('\nTools and libraries available at test time:\n')
487 495 avail.sort()
488 496 out.append(' ' + ' '.join(avail)+'\n')
489 497
490 498 if not_avail:
491 499 out.append('\nTools and libraries NOT available at test time:\n')
492 500 not_avail.sort()
493 501 out.append(' ' + ' '.join(not_avail)+'\n')
494 502
495 503 return ''.join(out)
496 504
497 505 def run_iptestall(options):
498 506 """Run the entire IPython test suite by calling nose and trial.
499 507
500 508 This function constructs :class:`IPTester` instances for all IPython
501 509 modules and package and then runs each of them. This causes the modules
502 510 and packages of IPython to be tested each in their own subprocess using
503 511 nose.
504 512
505 513 Parameters
506 514 ----------
507 515
508 516 All parameters are passed as attributes of the options object.
509 517
510 518 testgroups : list of str
511 519 Run only these sections of the test suite. If empty, run all the available
512 520 sections.
513 521
514 522 fast : int or None
515 523 Run the test suite in parallel, using n simultaneous processes. If None
516 524 is passed, one process is used per CPU core. Default 1 (i.e. sequential)
517 525
518 526 inc_slow : bool
519 527 Include slow tests, like IPython.parallel. By default, these tests aren't
520 528 run.
521 529
522 530 slimerjs : bool
523 531 Use slimerjs if it's installed instead of phantomjs for casperjs tests.
524 532
525 533 url : unicode
526 534 Address:port to use when running the JS tests.
527 535
528 536 xunit : bool
529 537 Produce Xunit XML output. This is written to multiple foo.xunit.xml files.
530 538
531 539 coverage : bool or str
532 540 Measure code coverage from tests. True will store the raw coverage data,
533 541 or pass 'html' or 'xml' to get reports.
534 542
535 543 extra_args : list
536 544 Extra arguments to pass to the test subprocesses, e.g. '-v'
537 545 """
538 546 to_run, not_run = prepare_controllers(options)
539 547
540 548 def justify(ltext, rtext, width=70, fill='-'):
541 549 ltext += ' '
542 550 rtext = (' ' + rtext).rjust(width - len(ltext), fill)
543 551 return ltext + rtext
544 552
545 553 # Run all test runners, tracking execution time
546 554 failed = []
547 555 t_start = time.time()
548 556
549 557 print()
550 558 if options.fast == 1:
551 559 # This actually means sequential, i.e. with 1 job
552 560 for controller in to_run:
553 561 print('Test group:', controller.section)
554 562 sys.stdout.flush() # Show in correct order when output is piped
555 563 controller, res = do_run(controller, buffer_output=False)
556 564 if res:
557 565 failed.append(controller)
558 566 if res == -signal.SIGINT:
559 567 print("Interrupted")
560 568 break
561 569 print()
562 570
563 571 else:
564 572 # Run tests concurrently
565 573 try:
566 574 pool = multiprocessing.pool.ThreadPool(options.fast)
567 575 for (controller, res) in pool.imap_unordered(do_run, to_run):
568 576 res_string = 'OK' if res == 0 else 'FAILED'
569 577 print(justify('Test group: ' + controller.section, res_string))
570 578 if res:
571 579 controller.print_extra_info()
572 580 print(bytes_to_str(controller.stdout))
573 581 failed.append(controller)
574 582 if res == -signal.SIGINT:
575 583 print("Interrupted")
576 584 break
577 585 except KeyboardInterrupt:
578 586 return
579 587
580 588 for controller in not_run:
581 589 print(justify('Test group: ' + controller.section, 'NOT RUN'))
582 590
583 591 t_end = time.time()
584 592 t_tests = t_end - t_start
585 593 nrunners = len(to_run)
586 594 nfail = len(failed)
587 595 # summarize results
588 596 print('_'*70)
589 597 print('Test suite completed for system with the following information:')
590 598 print(report())
591 599 took = "Took %.3fs." % t_tests
592 600 print('Status: ', end='')
593 601 if not failed:
594 602 print('OK (%d test groups).' % nrunners, took)
595 603 else:
596 604 # If anything went wrong, point out what command to rerun manually to
597 605 # see the actual errors and individual summary
598 606 failed_sections = [c.section for c in failed]
599 607 print('ERROR - {} out of {} test groups failed ({}).'.format(nfail,
600 608 nrunners, ', '.join(failed_sections)), took)
601 609 print()
602 610 print('You may wish to rerun these, with:')
603 611 print(' iptest', *failed_sections)
604 612 print()
605 613
606 614 if options.coverage:
607 615 from coverage import coverage
608 616 cov = coverage(data_file='.coverage')
609 617 cov.combine()
610 618 cov.save()
611 619
612 620 # Coverage HTML report
613 621 if options.coverage == 'html':
614 622 html_dir = 'ipy_htmlcov'
615 623 shutil.rmtree(html_dir, ignore_errors=True)
616 624 print("Writing HTML coverage report to %s/ ... " % html_dir, end="")
617 625 sys.stdout.flush()
618 626
619 627 # Custom HTML reporter to clean up module names.
620 628 from coverage.html import HtmlReporter
621 629 class CustomHtmlReporter(HtmlReporter):
622 630 def find_code_units(self, morfs):
623 631 super(CustomHtmlReporter, self).find_code_units(morfs)
624 632 for cu in self.code_units:
625 633 nameparts = cu.name.split(os.sep)
626 634 if 'IPython' not in nameparts:
627 635 continue
628 636 ix = nameparts.index('IPython')
629 637 cu.name = '.'.join(nameparts[ix:])
630 638
631 639 # Reimplement the html_report method with our custom reporter
632 640 cov._harvest_data()
633 641 cov.config.from_args(omit='*{0}tests{0}*'.format(os.sep), html_dir=html_dir,
634 642 html_title='IPython test coverage',
635 643 )
636 644 reporter = CustomHtmlReporter(cov, cov.config)
637 645 reporter.report(None)
638 646 print('done.')
639 647
640 648 # Coverage XML report
641 649 elif options.coverage == 'xml':
642 650 cov.xml_report(outfile='ipy_coverage.xml')
643 651
644 652 if failed:
645 653 # Ensure that our exit code indicates failure
646 654 sys.exit(1)
647 655
648 656 argparser = argparse.ArgumentParser(description='Run IPython test suite')
649 657 argparser.add_argument('testgroups', nargs='*',
650 658 help='Run specified groups of tests. If omitted, run '
651 659 'all tests.')
652 660 argparser.add_argument('--all', action='store_true',
653 661 help='Include slow tests not run by default.')
654 662 argparser.add_argument('--slimerjs', action='store_true',
655 663 help="Use slimerjs if it's installed instead of phantomjs for casperjs tests.")
656 664 argparser.add_argument('--url', help="URL to use for the JS tests.")
657 665 argparser.add_argument('-j', '--fast', nargs='?', const=None, default=1, type=int,
658 666 help='Run test sections in parallel. This starts as many '
659 667 'processes as you have cores, or you can specify a number.')
660 668 argparser.add_argument('--xunit', action='store_true',
661 669 help='Produce Xunit XML results')
662 670 argparser.add_argument('--coverage', nargs='?', const=True, default=False,
663 671 help="Measure test coverage. Specify 'html' or "
664 672 "'xml' to get reports.")
665 673 argparser.add_argument('--subproc-streams', default='capture',
666 674 help="What to do with stdout/stderr from subprocesses. "
667 675 "'capture' (default), 'show' and 'discard' are the options.")
668 676
669 677 def default_options():
670 678 """Get an argparse Namespace object with the default arguments, to pass to
671 679 :func:`run_iptestall`.
672 680 """
673 681 options = argparser.parse_args([])
674 682 options.extra_args = []
675 683 return options
676 684
677 685 def main():
678 686 # iptest doesn't work correctly if the working directory is the
679 687 # root of the IPython source tree. Tell the user to avoid
680 688 # frustration.
681 689 if os.path.exists(os.path.join(os.getcwd(),
682 690 'IPython', 'testing', '__main__.py')):
683 691 print("Don't run iptest from the IPython source directory",
684 692 file=sys.stderr)
685 693 sys.exit(1)
686 694 # Arguments after -- should be passed through to nose. Argparse treats
687 695 # everything after -- as regular positional arguments, so we separate them
688 696 # first.
689 697 try:
690 698 ix = sys.argv.index('--')
691 699 except ValueError:
692 700 to_parse = sys.argv[1:]
693 701 extra_args = []
694 702 else:
695 703 to_parse = sys.argv[1:ix]
696 704 extra_args = sys.argv[ix+1:]
697 705
698 706 options = argparser.parse_args(to_parse)
699 707 options.extra_args = extra_args
700 708
701 709 run_iptestall(options)
702 710
703 711
704 712 if __name__ == '__main__':
705 713 main()
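
A minimal usage sketch (not part of the diff), assuming the js/notebook group name: running the notebook JS tests programmatically. With the patch above, the default phantomjs engine makes JSController export IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS=1 to the notebook server it launches, which in turn installs allow76.AllowDraftWebSocketHandler in zmqhandlers.

from IPython.testing.iptestcontroller import default_options, run_iptestall

options = default_options()           # argparse defaults; extra_args = []
options.testgroups = ['js/notebook']  # assumed group name; any js/* group works
run_iptestall(options)                # starts the notebook server and casperjs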