##// END OF EJS Templates
forward-port draft76 websockets...
Min RK -
Show More
@@ -0,0 +1,312
1 """WebsocketProtocol76 from tornado 3.2.2 for tornado >= 4.0
2
3 The contents of this file are Copyright (c) Tornado
4 Used under the Apache 2.0 license
5 """
6
7
8 from __future__ import absolute_import, division, print_function, with_statement
9 # Author: Jacob Kristhammar, 2010
10
11 import functools
12 import hashlib
13 import struct
14 import time
15 import tornado.escape
16 import tornado.web
17
18 from tornado.log import gen_log, app_log
19 from tornado.util import bytes_type, unicode_type
20
21 from tornado.websocket import WebSocketHandler, WebSocketProtocol13
22
class AllowDraftWebSocketHandler(WebSocketHandler):
    """Restore Draft76 support for tornado 4

    Remove when we can run tests without phantomjs + qt4
    """

    # get is unmodified except between the BEGIN/END PATCH lines
    @tornado.web.asynchronous
    def get(self, *args, **kwargs):
        # Stash the handler arguments so the protocol object can invoke
        # self.open(*open_args, **open_kwargs) once the handshake completes.
        self.open_args = args
        self.open_kwargs = kwargs

        # Upgrade header should be present and should be equal to WebSocket
        if self.request.headers.get("Upgrade", "").lower() != 'websocket':
            self.set_status(400)
            self.finish("Can \"Upgrade\" only to \"WebSocket\".")
            return

        # Connection header should be upgrade. Some proxy servers/load balancers
        # might mess with it.
        headers = self.request.headers
        # map() supports the membership test below on both py2 (list) and
        # py3 (iterator is consumed by `in`)
        connection = map(lambda s: s.strip().lower(), headers.get("Connection", "").split(","))
        if 'upgrade' not in connection:
            self.set_status(400)
            self.finish("\"Connection\" must be \"Upgrade\".")
            return

        # Handle WebSocket Origin naming convention differences
        # The difference between version 8 and 13 is that in 8 the
        # client sends a "Sec-Websocket-Origin" header and in 13 it's
        # simply "Origin".
        if "Origin" in self.request.headers:
            origin = self.request.headers.get("Origin")
        else:
            origin = self.request.headers.get("Sec-Websocket-Origin", None)

        # If there was an origin header, check to make sure it matches
        # according to check_origin. When the origin is None, we assume it
        # did not come from a browser and that it can be passed on.
        if origin is not None and not self.check_origin(origin):
            self.set_status(403)
            self.finish("Cross origin websockets not allowed")
            return

        # Take the raw stream away from the HTTP connection; from here on
        # the protocol object owns the socket.
        self.stream = self.request.connection.detach()
        self.stream.set_close_callback(self.on_connection_close)

        if self.request.headers.get("Sec-WebSocket-Version") in ("7", "8", "13"):
            self.ws_connection = WebSocketProtocol13(
                self, compression_options=self.get_compression_options())
            self.ws_connection.accept_connection()
        #--------------- BEGIN PATCH ----------------
        # draft76 clients send no Sec-WebSocket-Version header at all,
        # so its absence (plus allow_draft76) selects the old protocol
        elif (self.allow_draft76() and
              "Sec-WebSocket-Version" not in self.request.headers):
            self.ws_connection = WebSocketProtocol76(self)
            self.ws_connection.accept_connection()
        #--------------- END PATCH ----------------
        else:
            if not self.stream.closed():
                self.stream.write(tornado.escape.utf8(
                    "HTTP/1.1 426 Upgrade Required\r\n"
                    "Sec-WebSocket-Version: 8\r\n\r\n"))
                self.stream.close()

    # 3.2 methods removed in 4.0:
    def allow_draft76(self):
        """Using this class allows draft76 connections by default"""
        return True

    def get_websocket_scheme(self):
        """Return the url scheme used for this request, either "ws" or "wss".

        This is normally decided by HTTPServer, but applications
        may wish to override this if they are using an SSL proxy
        that does not provide the X-Scheme header as understood
        by HTTPServer.

        Note that this is only used by the draft76 protocol.
        """
        return "wss" if self.request.protocol == "https" else "ws"
103
104
105 # No modifications from tornado-3.2.2 below this line
106
class WebSocketProtocol(object):
    """Base class for WebSocket protocol versions."""

    def __init__(self, handler):
        # Keep the owning handler plus direct references to the pieces of
        # it used throughout the protocol (HTTP request and raw stream).
        self.handler = handler
        self.request = handler.request
        self.stream = handler.stream
        # Both flags True means the connection is fully shut down.
        self.client_terminated = False
        self.server_terminated = False

    def async_callback(self, callback, *args, **kwargs):
        """Wrap callbacks with this if they are used on asynchronous requests.

        Catches exceptions properly and closes this WebSocket if an exception
        is uncaught.
        """
        if args or kwargs:
            target = functools.partial(callback, *args, **kwargs)
        else:
            target = callback

        def guarded(*call_args, **call_kwargs):
            try:
                return target(*call_args, **call_kwargs)
            except Exception:
                app_log.error("Uncaught exception in %s",
                              self.request.path, exc_info=True)
                self._abort()

        return guarded

    def on_connection_close(self):
        # The underlying stream was closed out from under us: abort.
        self._abort()

    def _abort(self):
        """Instantly aborts the WebSocket connection by closing the socket"""
        self.client_terminated = True
        self.server_terminated = True
        self.stream.close()  # forcibly tear down the connection
        self.close()  # let the subclass cleanup
145
class WebSocketProtocol76(WebSocketProtocol):
    """Implementation of the WebSockets protocol, version hixie-76.

    This class provides basic functionality to process WebSockets requests as
    specified in
    http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
    """
    def __init__(self, handler):
        WebSocketProtocol.__init__(self, handler)
        self.challenge = None
        self._waiting = None  # pending close-timeout handle, see close()

    def accept_connection(self):
        """Validate the handshake headers and begin the hixie-76 handshake."""
        try:
            self._handle_websocket_headers()
        except ValueError:
            gen_log.debug("Malformed WebSocket request received")
            self._abort()
            return

        scheme = self.handler.get_websocket_scheme()

        # draft76 only allows a single subprotocol
        subprotocol_header = ''
        subprotocol = self.request.headers.get("Sec-WebSocket-Protocol", None)
        if subprotocol:
            selected = self.handler.select_subprotocol([subprotocol])
            if selected:
                assert selected == subprotocol
                subprotocol_header = "Sec-WebSocket-Protocol: %s\r\n" % selected

        # Write the initial headers before attempting to read the challenge.
        # This is necessary when using proxies (such as HAProxy), which
        # need to see the Upgrade headers before passing through the
        # non-HTTP traffic that follows.
        self.stream.write(tornado.escape.utf8(
            "HTTP/1.1 101 WebSocket Protocol Handshake\r\n"
            "Upgrade: WebSocket\r\n"
            "Connection: Upgrade\r\n"
            "Server: TornadoServer/%(version)s\r\n"
            "Sec-WebSocket-Origin: %(origin)s\r\n"
            "Sec-WebSocket-Location: %(scheme)s://%(host)s%(uri)s\r\n"
            "%(subprotocol)s"
            "\r\n" % (dict(
                version=tornado.version,
                origin=self.request.headers["Origin"],
                scheme=scheme,
                host=self.request.host,
                uri=self.request.uri,
                subprotocol=subprotocol_header))))
        # The client sends an 8-byte challenge body after its headers.
        self.stream.read_bytes(8, self._handle_challenge)

    def challenge_response(self, challenge):
        """Generates the challenge response that's needed in the handshake

        The challenge parameter should be the raw bytes as sent from the
        client.
        """
        key_1 = self.request.headers.get("Sec-Websocket-Key1")
        key_2 = self.request.headers.get("Sec-Websocket-Key2")
        try:
            part_1 = self._calculate_part(key_1)
            part_2 = self._calculate_part(key_2)
        except ValueError:
            raise ValueError("Invalid Keys/Challenge")
        return self._generate_challenge_response(part_1, part_2, challenge)

    def _handle_challenge(self, challenge):
        # challenge: the 8 raw bytes read in accept_connection
        try:
            challenge_response = self.challenge_response(challenge)
        except ValueError:
            gen_log.debug("Malformed key data in WebSocket request")
            self._abort()
            return
        self._write_response(challenge_response)

    def _write_response(self, challenge):
        # Send the 16-byte md5 response, notify the application via open(),
        # then start the frame-reading loop.
        self.stream.write(challenge)
        self.async_callback(self.handler.open)(*self.handler.open_args, **self.handler.open_kwargs)
        self._receive_message()

    def _handle_websocket_headers(self):
        """Verifies all invariant- and required headers

        If a header is missing or have an incorrect value ValueError will be
        raised
        """
        fields = ("Origin", "Host", "Sec-Websocket-Key1",
                  "Sec-Websocket-Key2")
        if not all(map(lambda f: self.request.headers.get(f), fields)):
            raise ValueError("Missing/Invalid WebSocket headers")

    def _calculate_part(self, key):
        """Processes the key headers and calculates their key value.

        Raises ValueError when feed invalid key."""
        # pyflakes complains about variable reuse if both of these lines use 'c'
        number = int(''.join(c for c in key if c.isdigit()))
        spaces = len([c2 for c2 in key if c2.isspace()])
        try:
            key_number = number // spaces
        except (ValueError, ZeroDivisionError):
            raise ValueError
        return struct.pack(">I", key_number)

    def _generate_challenge_response(self, part_1, part_2, part_3):
        # md5 over the two decoded key parts plus the 8-byte challenge.
        m = hashlib.md5()
        m.update(part_1)
        m.update(part_2)
        m.update(part_3)
        return m.digest()

    def _receive_message(self):
        self.stream.read_bytes(1, self._on_frame_type)

    def _on_frame_type(self, byte):
        frame_type = ord(byte)
        if frame_type == 0x00:
            # 0x00 opens a text frame terminated by 0xff
            self.stream.read_until(b"\xff", self._on_end_delimiter)
        elif frame_type == 0xff:
            # 0xff starts the closing handshake; a 0x00 length byte follows
            self.stream.read_bytes(1, self._on_length_indicator)
        else:
            self._abort()

    def _on_end_delimiter(self, frame):
        if not self.client_terminated:
            self.async_callback(self.handler.on_message)(
                frame[:-1].decode("utf-8", "replace"))
        if not self.client_terminated:
            self._receive_message()

    def _on_length_indicator(self, byte):
        if ord(byte) != 0x00:
            self._abort()
            return
        self.client_terminated = True
        self.close()

    def write_message(self, message, binary=False):
        """Sends the given message to the client of this Web Socket."""
        if binary:
            raise ValueError(
                "Binary messages not supported by this version of websockets")
        if isinstance(message, unicode_type):
            message = message.encode("utf-8")
        assert isinstance(message, bytes_type)
        self.stream.write(b"\x00" + message + b"\xff")

    def write_ping(self, data):
        """Send ping frame."""
        raise ValueError("Ping messages not supported by this version of websockets")

    def close(self):
        """Closes the WebSocket connection."""
        if not self.server_terminated:
            if not self.stream.closed():
                # FIX: was the unicode literal "\xff\x00"; IOStream.write
                # requires bytes on Python 3 and this module supports py2+py3
                # (see the bytes_type/unicode_type handling in write_message).
                self.stream.write(b"\xff\x00")
            self.server_terminated = True
        if self.client_terminated:
            if self._waiting is not None:
                self.stream.io_loop.remove_timeout(self._waiting)
                self._waiting = None
            self.stream.close()
        elif self._waiting is None:
            # Give the client 5 seconds to finish the closing handshake,
            # then force-close the connection via _abort.
            self._waiting = self.stream.io_loop.add_timeout(
                time.time() + 5, self._abort)
@@ -1,271 +1,272
1 # coding: utf-8
1 # coding: utf-8
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
2 """Tornado handlers for WebSocket <-> ZMQ sockets."""
3
3
4 # Copyright (c) IPython Development Team.
4 # Copyright (c) IPython Development Team.
5 # Distributed under the terms of the Modified BSD License.
5 # Distributed under the terms of the Modified BSD License.
6
6
7 import os
7 import json
8 import json
8 import struct
9 import struct
10 import warnings
9
11
10 try:
12 try:
11 from urllib.parse import urlparse # Py 3
13 from urllib.parse import urlparse # Py 3
12 except ImportError:
14 except ImportError:
13 from urlparse import urlparse # Py 2
15 from urlparse import urlparse # Py 2
14
16
15 import tornado
17 import tornado
16 from tornado import gen, ioloop, web, websocket
18 from tornado import gen, ioloop, web
19 from tornado.websocket import WebSocketHandler
17
20
18 from IPython.kernel.zmq.session import Session
21 from IPython.kernel.zmq.session import Session
19 from IPython.utils.jsonutil import date_default, extract_dates
22 from IPython.utils.jsonutil import date_default, extract_dates
20 from IPython.utils.py3compat import cast_unicode
23 from IPython.utils.py3compat import cast_unicode
21
24
22 from .handlers import IPythonHandler
25 from .handlers import IPythonHandler
23
26
24
def serialize_binary_message(msg):
    """serialize a message as a binary blob

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.

    Returns
    -------

    The message serialized to bytes.

    """
    # work on copies so the caller's dict and buffer list stay untouched
    payload = msg.copy()
    raw_buffers = list(payload.pop('buffers'))
    json_part = json.dumps(payload, default=date_default).encode('utf8')
    parts = [json_part] + raw_buffers
    nparts = len(parts)
    # running offsets, starting just past the (4 + 4*nparts)-byte header
    offsets = [4 * (nparts + 1)]
    for part in parts[:-1]:
        offsets.append(offsets[-1] + len(part))
    header = struct.pack('!' + 'I' * (nparts + 1), nparts, *offsets)
    return b''.join([header] + parts)
53
55
54
56
def deserialize_binary_message(bmsg):
    """deserialize a message from a binary blob

    Header:

    4 bytes: number of msg parts (nbufs) as 32b int
    4 * nbufs bytes: offset for each buffer as integer as 32b int

    Offsets are from the start of the buffer, including the header.

    Returns
    -------

    message dictionary
    """
    nbufs = struct.unpack('!i', bmsg[:4])[0]
    offsets = list(struct.unpack('!' + 'I' * nbufs, bmsg[4:4 * (nbufs + 1)]))
    # sentinel so the final slice runs to the end of the blob
    offsets.append(None)
    bufs = [bmsg[start:stop]
            for start, stop in zip(offsets[:-1], offsets[1:])]
    msg = json.loads(bufs[0].decode('utf8'))
    msg['header'] = extract_dates(msg['header'])
    msg['parent_header'] = extract_dates(msg['parent_header'])
    msg['buffers'] = bufs[1:]
    return msg
81
83
# ping interval for keeping websockets alive (30 seconds)
WS_PING_INTERVAL = 30000

# Test-only escape hatch: when this environment variable is set, replace the
# RFC-6455 WebSocketHandler with a subclass that also accepts the obsolete
# draft76 (hixie-76) handshake used by phantomjs.
if os.environ.get('IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS', False):
    warnings.warn("""Allowing draft76 websocket connections!
    This should only be done for testing with phantomjs!""")
    from IPython.html import allow76
    # shadow the imported handler name with the draft76-capable subclass
    WebSocketHandler = allow76.AllowDraftWebSocketHandler
    # draft 76 doesn't support ping
    WS_PING_INTERVAL = 0
class ZMQStreamHandler(WebSocketHandler):
    """Base websocket handler for relaying ZMQ messages to the browser.

    (Module purpose: tornado handlers for WebSocket <-> ZMQ sockets.)
    """

    def check_origin(self, origin):
        """Check Origin == Host or Access-Control-Allow-Origin.

        Tornado >= 4 calls this method automatically, raising 403 if it returns False.
        We call it explicitly in `open` on Tornado < 4.
        """
        if self.allow_origin == '*':
            return True

        host = self.request.headers.get("Host")

        # If no header is provided, assume we can't verify origin
        if origin is None:
            self.log.warn("Missing Origin header, rejecting WebSocket connection.")
            return False
        if host is None:
            self.log.warn("Missing Host header, rejecting WebSocket connection.")
            return False

        origin = origin.lower()
        origin_host = urlparse(origin).netloc

        # OK if origin matches host
        if origin_host == host:
            return True

        # Check CORS headers
        if self.allow_origin:
            allow = self.allow_origin == origin
        elif self.allow_origin_pat:
            allow = bool(self.allow_origin_pat.match(origin))
        else:
            # No CORS headers deny the request
            allow = False
        if not allow:
            self.log.warn("Blocking Cross Origin WebSocket Attempt. Origin: %s, Host: %s",
                          origin, host,
            )
        return allow

    def clear_cookie(self, *args, **kwargs):
        """meaningless for websockets"""
        pass

    def _reserialize_reply(self, msg_list):
        """Reserialize a reply message using JSON.

        This takes the msg list from the ZMQ socket, deserializes it using
        self.session and then serializes the result using JSON. This method
        should be used by self._on_zmq_reply to build messages that can
        be sent back to the browser.
        """
        idents, msg_list = self.session.feed_identities(msg_list)
        msg = self.session.deserialize(msg_list)
        if msg['buffers']:
            # messages carrying binary buffers go out as one binary blob
            buf = serialize_binary_message(msg)
            return buf
        else:
            smsg = json.dumps(msg, default=date_default)
            return cast_unicode(smsg)

    def _on_zmq_reply(self, msg_list):
        # Sometimes this gets triggered when the on_close method is scheduled in the
        # eventloop but hasn't been called.
        if self.stream.closed(): return
        try:
            msg = self._reserialize_reply(msg_list)
        except Exception:
            self.log.critical("Malformed message: %r" % msg_list, exc_info=True)
        else:
            self.write_message(msg, binary=isinstance(msg, bytes))
167
class AuthenticatedZMQStreamHandler(ZMQStreamHandler, IPythonHandler):
    # keep-alive ping state, managed by open()/send_ping()/on_pong()
    ping_callback = None
    last_ping = 0
    last_pong = 0

    @property
    def ping_interval(self):
        """The interval for websocket keep-alive pings.

        Set ws_ping_interval = 0 to disable pings.
        """
        return self.settings.get('ws_ping_interval', WS_PING_INTERVAL)

    @property
    def ping_timeout(self):
        """If no ping is received in this many milliseconds,
        close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
        Default is max of 3 pings or 30 seconds.
        """
        return self.settings.get('ws_ping_timeout',
            max(3 * self.ping_interval, WS_PING_INTERVAL)
        )

    def set_default_headers(self):
        """Undo the set_default_headers in IPythonHandler

        which doesn't make sense for websockets
        """
        pass

    def pre_get(self):
        """Run before finishing the GET request

        Extend this method to add logic that should fire before
        the websocket finishes completing.
        """
        # Check to see that origin matches host directly, including ports
        # Tornado 4 already does CORS checking
        if tornado.version_info[0] < 4:
            if not self.check_origin(self.get_origin()):
                raise web.HTTPError(403)

        # authenticate the request before opening the websocket
        if self.get_current_user() is None:
            self.log.warn("Couldn't authenticate WebSocket connection")
            raise web.HTTPError(403)

        if self.get_argument('session_id', False):
            self.session.session = cast_unicode(self.get_argument('session_id'))
        else:
            self.log.warn("No session ID specified")

    @gen.coroutine
    def get(self, *args, **kwargs):
        """Authenticate, then delegate to the websocket handshake."""
        # pre_get can be a coroutine in subclasses
        # assign and yield in two step to avoid tornado 3 issues
        res = self.pre_get()
        yield gen.maybe_future(res)
        # FIXME: only do super get on tornado >= 4
        # tornado 3 has no get, will raise 405
        if tornado.version_info >= (4,):
            super(AuthenticatedZMQStreamHandler, self).get(*args, **kwargs)

    def initialize(self):
        """Create the Session used to (de)serialize kernel messages."""
        self.log.debug("Initializing websocket connection %s", self.request.path)
        self.session = Session(config=self.config)

    def open(self, *args, **kwargs):
        """Called when the websocket is established; starts keep-alive pings."""
        self.log.debug("Opening websocket %s", self.request.path)
        if tornado.version_info < (4,):
            # on tornado 3, auth checks in get() run here instead
            try:
                self.get(*self.open_args, **self.open_kwargs)
            except web.HTTPError:
                self.close()
                raise

        # start the pinging
        if self.ping_interval > 0:
            self.last_ping = ioloop.IOLoop.instance().time()  # Remember time of last ping
            self.last_pong = self.last_ping
            self.ping_callback = ioloop.PeriodicCallback(self.send_ping, self.ping_interval)
            self.ping_callback.start()

    def send_ping(self):
        """send a ping to keep the websocket alive"""
        if self.stream.closed() and self.ping_callback is not None:
            self.ping_callback.stop()
            return

        # check for timeout on pong. Make sure that we really have sent a recent ping in
        # case the machine with both server and client has been suspended since the last ping.
        now = ioloop.IOLoop.instance().time()
        since_last_pong = 1e3 * (now - self.last_pong)
        since_last_ping = 1e3 * (now - self.last_ping)
        if since_last_ping < 2*self.ping_interval and since_last_pong > self.ping_timeout:
            self.log.warn("WebSocket ping timeout after %i ms.", since_last_pong)
            self.close()
            return

        self.ping(b'')
        self.last_ping = now

    def on_pong(self, data):
        # record when the client last answered, for the timeout check above
        self.last_pong = ioloop.IOLoop.instance().time()
@@ -1,705 +1,713
1 # -*- coding: utf-8 -*-
1 # -*- coding: utf-8 -*-
2 """IPython Test Process Controller
2 """IPython Test Process Controller
3
3
4 This module runs one or more subprocesses which will actually run the IPython
4 This module runs one or more subprocesses which will actually run the IPython
5 test suite.
5 test suite.
6
6
7 """
7 """
8
8
9 # Copyright (c) IPython Development Team.
9 # Copyright (c) IPython Development Team.
10 # Distributed under the terms of the Modified BSD License.
10 # Distributed under the terms of the Modified BSD License.
11
11
12 from __future__ import print_function
12 from __future__ import print_function
13
13
14 import argparse
14 import argparse
15 import json
15 import json
16 import multiprocessing.pool
16 import multiprocessing.pool
17 import os
17 import os
18 import re
18 import re
19 import requests
19 import requests
20 import shutil
20 import shutil
21 import signal
21 import signal
22 import sys
22 import sys
23 import subprocess
23 import subprocess
24 import time
24 import time
25
25
26 from .iptest import (
26 from .iptest import (
27 have, test_group_names as py_test_group_names, test_sections, StreamCapturer,
27 have, test_group_names as py_test_group_names, test_sections, StreamCapturer,
28 test_for,
28 test_for,
29 )
29 )
30 from IPython.utils.path import compress_user
30 from IPython.utils.path import compress_user
31 from IPython.utils.py3compat import bytes_to_str
31 from IPython.utils.py3compat import bytes_to_str
32 from IPython.utils.sysinfo import get_sys_info
32 from IPython.utils.sysinfo import get_sys_info
33 from IPython.utils.tempdir import TemporaryDirectory
33 from IPython.utils.tempdir import TemporaryDirectory
34 from IPython.utils.text import strip_ansi
34 from IPython.utils.text import strip_ansi
35
35
36 try:
36 try:
37 # Python >= 3.3
37 # Python >= 3.3
38 from subprocess import TimeoutExpired
38 from subprocess import TimeoutExpired
39 def popen_wait(p, timeout):
39 def popen_wait(p, timeout):
40 return p.wait(timeout)
40 return p.wait(timeout)
41 except ImportError:
41 except ImportError:
42 class TimeoutExpired(Exception):
42 class TimeoutExpired(Exception):
43 pass
43 pass
44 def popen_wait(p, timeout):
44 def popen_wait(p, timeout):
45 """backport of Popen.wait from Python 3"""
45 """backport of Popen.wait from Python 3"""
46 for i in range(int(10 * timeout)):
46 for i in range(int(10 * timeout)):
47 if p.poll() is not None:
47 if p.poll() is not None:
48 return
48 return
49 time.sleep(0.1)
49 time.sleep(0.1)
50 if p.poll() is None:
50 if p.poll() is None:
51 raise TimeoutExpired
51 raise TimeoutExpired
52
52
53 NOTEBOOK_SHUTDOWN_TIMEOUT = 10
53 NOTEBOOK_SHUTDOWN_TIMEOUT = 10
54
54
55 class TestController(object):
55 class TestController(object):
56 """Run tests in a subprocess
56 """Run tests in a subprocess
57 """
57 """
58 #: str, IPython test suite to be executed.
58 #: str, IPython test suite to be executed.
59 section = None
59 section = None
60 #: list, command line arguments to be executed
60 #: list, command line arguments to be executed
61 cmd = None
61 cmd = None
62 #: dict, extra environment variables to set for the subprocess
62 #: dict, extra environment variables to set for the subprocess
63 env = None
63 env = None
64 #: list, TemporaryDirectory instances to clear up when the process finishes
64 #: list, TemporaryDirectory instances to clear up when the process finishes
65 dirs = None
65 dirs = None
66 #: subprocess.Popen instance
66 #: subprocess.Popen instance
67 process = None
67 process = None
68 #: str, process stdout+stderr
68 #: str, process stdout+stderr
69 stdout = None
69 stdout = None
70
70
71 def __init__(self):
71 def __init__(self):
72 self.cmd = []
72 self.cmd = []
73 self.env = {}
73 self.env = {}
74 self.dirs = []
74 self.dirs = []
75
75
76 def setup(self):
76 def setup(self):
77 """Create temporary directories etc.
77 """Create temporary directories etc.
78
78
79 This is only called when we know the test group will be run. Things
79 This is only called when we know the test group will be run. Things
80 created here may be cleaned up by self.cleanup().
80 created here may be cleaned up by self.cleanup().
81 """
81 """
82 pass
82 pass
83
83
84 def launch(self, buffer_output=False, capture_output=False):
84 def launch(self, buffer_output=False, capture_output=False):
85 # print('*** ENV:', self.env) # dbg
85 # print('*** ENV:', self.env) # dbg
86 # print('*** CMD:', self.cmd) # dbg
86 # print('*** CMD:', self.cmd) # dbg
87 env = os.environ.copy()
87 env = os.environ.copy()
88 env.update(self.env)
88 env.update(self.env)
89 if buffer_output:
89 if buffer_output:
90 capture_output = True
90 capture_output = True
91 self.stdout_capturer = c = StreamCapturer(echo=not buffer_output)
91 self.stdout_capturer = c = StreamCapturer(echo=not buffer_output)
92 c.start()
92 c.start()
93 stdout = c.writefd if capture_output else None
93 stdout = c.writefd if capture_output else None
94 stderr = subprocess.STDOUT if capture_output else None
94 stderr = subprocess.STDOUT if capture_output else None
95 self.process = subprocess.Popen(self.cmd, stdout=stdout,
95 self.process = subprocess.Popen(self.cmd, stdout=stdout,
96 stderr=stderr, env=env)
96 stderr=stderr, env=env)
97
97
98 def wait(self):
98 def wait(self):
99 self.process.wait()
99 self.process.wait()
100 self.stdout_capturer.halt()
100 self.stdout_capturer.halt()
101 self.stdout = self.stdout_capturer.get_buffer()
101 self.stdout = self.stdout_capturer.get_buffer()
102 return self.process.returncode
102 return self.process.returncode
103
103
104 def print_extra_info(self):
104 def print_extra_info(self):
105 """Print extra information about this test run.
105 """Print extra information about this test run.
106
106
107 If we're running in parallel and showing the concise view, this is only
107 If we're running in parallel and showing the concise view, this is only
108 called if the test group fails. Otherwise, it's called before the test
108 called if the test group fails. Otherwise, it's called before the test
109 group is started.
109 group is started.
110
110
111 The base implementation does nothing, but it can be overridden by
111 The base implementation does nothing, but it can be overridden by
112 subclasses.
112 subclasses.
113 """
113 """
114 return
114 return
115
115
116 def cleanup_process(self):
116 def cleanup_process(self):
117 """Cleanup on exit by killing any leftover processes."""
117 """Cleanup on exit by killing any leftover processes."""
118 subp = self.process
118 subp = self.process
119 if subp is None or (subp.poll() is not None):
119 if subp is None or (subp.poll() is not None):
120 return # Process doesn't exist, or is already dead.
120 return # Process doesn't exist, or is already dead.
121
121
122 try:
122 try:
123 print('Cleaning up stale PID: %d' % subp.pid)
123 print('Cleaning up stale PID: %d' % subp.pid)
124 subp.kill()
124 subp.kill()
125 except: # (OSError, WindowsError) ?
125 except: # (OSError, WindowsError) ?
126 # This is just a best effort, if we fail or the process was
126 # This is just a best effort, if we fail or the process was
127 # really gone, ignore it.
127 # really gone, ignore it.
128 pass
128 pass
129 else:
129 else:
130 for i in range(10):
130 for i in range(10):
131 if subp.poll() is None:
131 if subp.poll() is None:
132 time.sleep(0.1)
132 time.sleep(0.1)
133 else:
133 else:
134 break
134 break
135
135
136 if subp.poll() is None:
136 if subp.poll() is None:
137 # The process did not die...
137 # The process did not die...
138 print('... failed. Manual cleanup may be required.')
138 print('... failed. Manual cleanup may be required.')
139
139
140 def cleanup(self):
140 def cleanup(self):
141 "Kill process if it's still alive, and clean up temporary directories"
141 "Kill process if it's still alive, and clean up temporary directories"
142 self.cleanup_process()
142 self.cleanup_process()
143 for td in self.dirs:
143 for td in self.dirs:
144 td.cleanup()
144 td.cleanup()
145
145
146 __del__ = cleanup
146 __del__ = cleanup
147
147
148
148
149 class PyTestController(TestController):
149 class PyTestController(TestController):
150 """Run Python tests using IPython.testing.iptest"""
150 """Run Python tests using IPython.testing.iptest"""
151 #: str, Python command to execute in subprocess
151 #: str, Python command to execute in subprocess
152 pycmd = None
152 pycmd = None
153
153
154 def __init__(self, section, options):
154 def __init__(self, section, options):
155 """Create new test runner."""
155 """Create new test runner."""
156 TestController.__init__(self)
156 TestController.__init__(self)
157 self.section = section
157 self.section = section
158 # pycmd is put into cmd[2] in PyTestController.launch()
158 # pycmd is put into cmd[2] in PyTestController.launch()
159 self.cmd = [sys.executable, '-c', None, section]
159 self.cmd = [sys.executable, '-c', None, section]
160 self.pycmd = "from IPython.testing.iptest import run_iptest; run_iptest()"
160 self.pycmd = "from IPython.testing.iptest import run_iptest; run_iptest()"
161 self.options = options
161 self.options = options
162
162
163 def setup(self):
163 def setup(self):
164 ipydir = TemporaryDirectory()
164 ipydir = TemporaryDirectory()
165 self.dirs.append(ipydir)
165 self.dirs.append(ipydir)
166 self.env['IPYTHONDIR'] = ipydir.name
166 self.env['IPYTHONDIR'] = ipydir.name
167 self.workingdir = workingdir = TemporaryDirectory()
167 self.workingdir = workingdir = TemporaryDirectory()
168 self.dirs.append(workingdir)
168 self.dirs.append(workingdir)
169 self.env['IPTEST_WORKING_DIR'] = workingdir.name
169 self.env['IPTEST_WORKING_DIR'] = workingdir.name
170 # This means we won't get odd effects from our own matplotlib config
170 # This means we won't get odd effects from our own matplotlib config
171 self.env['MPLCONFIGDIR'] = workingdir.name
171 self.env['MPLCONFIGDIR'] = workingdir.name
172
172
173 # From options:
173 # From options:
174 if self.options.xunit:
174 if self.options.xunit:
175 self.add_xunit()
175 self.add_xunit()
176 if self.options.coverage:
176 if self.options.coverage:
177 self.add_coverage()
177 self.add_coverage()
178 self.env['IPTEST_SUBPROC_STREAMS'] = self.options.subproc_streams
178 self.env['IPTEST_SUBPROC_STREAMS'] = self.options.subproc_streams
179 self.cmd.extend(self.options.extra_args)
179 self.cmd.extend(self.options.extra_args)
180
180
181 @property
181 @property
182 def will_run(self):
182 def will_run(self):
183 try:
183 try:
184 return test_sections[self.section].will_run
184 return test_sections[self.section].will_run
185 except KeyError:
185 except KeyError:
186 return True
186 return True
187
187
188 def add_xunit(self):
188 def add_xunit(self):
189 xunit_file = os.path.abspath(self.section + '.xunit.xml')
189 xunit_file = os.path.abspath(self.section + '.xunit.xml')
190 self.cmd.extend(['--with-xunit', '--xunit-file', xunit_file])
190 self.cmd.extend(['--with-xunit', '--xunit-file', xunit_file])
191
191
192 def add_coverage(self):
192 def add_coverage(self):
193 try:
193 try:
194 sources = test_sections[self.section].includes
194 sources = test_sections[self.section].includes
195 except KeyError:
195 except KeyError:
196 sources = ['IPython']
196 sources = ['IPython']
197
197
198 coverage_rc = ("[run]\n"
198 coverage_rc = ("[run]\n"
199 "data_file = {data_file}\n"
199 "data_file = {data_file}\n"
200 "source =\n"
200 "source =\n"
201 " {source}\n"
201 " {source}\n"
202 ).format(data_file=os.path.abspath('.coverage.'+self.section),
202 ).format(data_file=os.path.abspath('.coverage.'+self.section),
203 source="\n ".join(sources))
203 source="\n ".join(sources))
204 config_file = os.path.join(self.workingdir.name, '.coveragerc')
204 config_file = os.path.join(self.workingdir.name, '.coveragerc')
205 with open(config_file, 'w') as f:
205 with open(config_file, 'w') as f:
206 f.write(coverage_rc)
206 f.write(coverage_rc)
207
207
208 self.env['COVERAGE_PROCESS_START'] = config_file
208 self.env['COVERAGE_PROCESS_START'] = config_file
209 self.pycmd = "import coverage; coverage.process_startup(); " + self.pycmd
209 self.pycmd = "import coverage; coverage.process_startup(); " + self.pycmd
210
210
211 def launch(self, buffer_output=False):
211 def launch(self, buffer_output=False):
212 self.cmd[2] = self.pycmd
212 self.cmd[2] = self.pycmd
213 super(PyTestController, self).launch(buffer_output=buffer_output)
213 super(PyTestController, self).launch(buffer_output=buffer_output)
214
214
215
215
216 js_prefix = 'js/'
216 js_prefix = 'js/'
217
217
218 def get_js_test_dir():
218 def get_js_test_dir():
219 import IPython.html.tests as t
219 import IPython.html.tests as t
220 return os.path.join(os.path.dirname(t.__file__), '')
220 return os.path.join(os.path.dirname(t.__file__), '')
221
221
222 def all_js_groups():
222 def all_js_groups():
223 import glob
223 import glob
224 test_dir = get_js_test_dir()
224 test_dir = get_js_test_dir()
225 all_subdirs = glob.glob(test_dir + '[!_]*/')
225 all_subdirs = glob.glob(test_dir + '[!_]*/')
226 return [js_prefix+os.path.relpath(x, test_dir) for x in all_subdirs]
226 return [js_prefix+os.path.relpath(x, test_dir) for x in all_subdirs]
227
227
228 class JSController(TestController):
228 class JSController(TestController):
229 """Run CasperJS tests """
229 """Run CasperJS tests """
230
230
231 requirements = ['zmq', 'tornado', 'jinja2', 'casperjs', 'sqlite3',
231 requirements = ['zmq', 'tornado', 'jinja2', 'casperjs', 'sqlite3',
232 'jsonschema']
232 'jsonschema']
233
233
234 def __init__(self, section, xunit=True, engine='phantomjs', url=None):
234 def __init__(self, section, xunit=True, engine='phantomjs', url=None):
235 """Create new test runner."""
235 """Create new test runner."""
236 TestController.__init__(self)
236 TestController.__init__(self)
237 self.engine = engine
237 self.engine = engine
238 self.section = section
238 self.section = section
239 self.xunit = xunit
239 self.xunit = xunit
240 self.url = url
240 self.url = url
241 self.slimer_failure = re.compile('^FAIL.*', flags=re.MULTILINE)
241 self.slimer_failure = re.compile('^FAIL.*', flags=re.MULTILINE)
242 js_test_dir = get_js_test_dir()
242 js_test_dir = get_js_test_dir()
243 includes = '--includes=' + os.path.join(js_test_dir,'util.js')
243 includes = '--includes=' + os.path.join(js_test_dir,'util.js')
244 test_cases = os.path.join(js_test_dir, self.section[len(js_prefix):])
244 test_cases = os.path.join(js_test_dir, self.section[len(js_prefix):])
245 self.cmd = ['casperjs', 'test', includes, test_cases, '--engine=%s' % self.engine]
245 self.cmd = ['casperjs', 'test', includes, test_cases, '--engine=%s' % self.engine]
246
246
247 def setup(self):
247 def setup(self):
248 self.ipydir = TemporaryDirectory()
248 self.ipydir = TemporaryDirectory()
249 self.nbdir = TemporaryDirectory()
249 self.nbdir = TemporaryDirectory()
250 self.dirs.append(self.ipydir)
250 self.dirs.append(self.ipydir)
251 self.dirs.append(self.nbdir)
251 self.dirs.append(self.nbdir)
252 os.makedirs(os.path.join(self.nbdir.name, os.path.join(u'sub ∂ir1', u'sub ∂ir 1a')))
252 os.makedirs(os.path.join(self.nbdir.name, os.path.join(u'sub ∂ir1', u'sub ∂ir 1a')))
253 os.makedirs(os.path.join(self.nbdir.name, os.path.join(u'sub ∂ir2', u'sub ∂ir 1b')))
253 os.makedirs(os.path.join(self.nbdir.name, os.path.join(u'sub ∂ir2', u'sub ∂ir 1b')))
254
254
255 if self.xunit:
255 if self.xunit:
256 self.add_xunit()
256 self.add_xunit()
257
257
258 # If a url was specified, use that for the testing.
258 # If a url was specified, use that for the testing.
259 if self.url:
259 if self.url:
260 try:
260 try:
261 alive = requests.get(self.url).status_code == 200
261 alive = requests.get(self.url).status_code == 200
262 except:
262 except:
263 alive = False
263 alive = False
264
264
265 if alive:
265 if alive:
266 self.cmd.append("--url=%s" % self.url)
266 self.cmd.append("--url=%s" % self.url)
267 else:
267 else:
268 raise Exception('Could not reach "%s".' % self.url)
268 raise Exception('Could not reach "%s".' % self.url)
269 else:
269 else:
270 # start the ipython notebook, so we get the port number
270 # start the ipython notebook, so we get the port number
271 self.server_port = 0
271 self.server_port = 0
272 self._init_server()
272 self._init_server()
273 if self.server_port:
273 if self.server_port:
274 self.cmd.append("--port=%i" % self.server_port)
274 self.cmd.append("--port=%i" % self.server_port)
275 else:
275 else:
276 # don't launch tests if the server didn't start
276 # don't launch tests if the server didn't start
277 self.cmd = [sys.executable, '-c', 'raise SystemExit(1)']
277 self.cmd = [sys.executable, '-c', 'raise SystemExit(1)']
278
278
279 def add_xunit(self):
279 def add_xunit(self):
280 xunit_file = os.path.abspath(self.section.replace('/','.') + '.xunit.xml')
280 xunit_file = os.path.abspath(self.section.replace('/','.') + '.xunit.xml')
281 self.cmd.append('--xunit=%s' % xunit_file)
281 self.cmd.append('--xunit=%s' % xunit_file)
282
282
283 def launch(self, buffer_output):
283 def launch(self, buffer_output):
284 # If the engine is SlimerJS, we need to buffer the output because
284 # If the engine is SlimerJS, we need to buffer the output because
285 # SlimerJS does not support exit codes, so CasperJS always returns 0.
285 # SlimerJS does not support exit codes, so CasperJS always returns 0.
286 if self.engine == 'slimerjs' and not buffer_output:
286 if self.engine == 'slimerjs' and not buffer_output:
287 return super(JSController, self).launch(capture_output=True)
287 return super(JSController, self).launch(capture_output=True)
288
288
289 else:
289 else:
290 return super(JSController, self).launch(buffer_output=buffer_output)
290 return super(JSController, self).launch(buffer_output=buffer_output)
291
291
292 def wait(self, *pargs, **kwargs):
292 def wait(self, *pargs, **kwargs):
293 """Wait for the JSController to finish"""
293 """Wait for the JSController to finish"""
294 ret = super(JSController, self).wait(*pargs, **kwargs)
294 ret = super(JSController, self).wait(*pargs, **kwargs)
295 # If this is a SlimerJS controller, check the captured stdout for
295 # If this is a SlimerJS controller, check the captured stdout for
296 # errors. Otherwise, just return the return code.
296 # errors. Otherwise, just return the return code.
297 if self.engine == 'slimerjs':
297 if self.engine == 'slimerjs':
298 stdout = bytes_to_str(self.stdout)
298 stdout = bytes_to_str(self.stdout)
299 if ret != 0:
299 if ret != 0:
300 # This could still happen e.g. if it's stopped by SIGINT
300 # This could still happen e.g. if it's stopped by SIGINT
301 return ret
301 return ret
302 return bool(self.slimer_failure.search(strip_ansi(stdout)))
302 return bool(self.slimer_failure.search(strip_ansi(stdout)))
303 else:
303 else:
304 return ret
304 return ret
305
305
306 def print_extra_info(self):
306 def print_extra_info(self):
307 print("Running tests with notebook directory %r" % self.nbdir.name)
307 print("Running tests with notebook directory %r" % self.nbdir.name)
308
308
309 @property
309 @property
310 def will_run(self):
310 def will_run(self):
311 should_run = all(have[a] for a in self.requirements + [self.engine])
311 should_run = all(have[a] for a in self.requirements + [self.engine])
312 return should_run
312 return should_run
313
313
314 def _init_server(self):
314 def _init_server(self):
315 "Start the notebook server in a separate process"
315 "Start the notebook server in a separate process"
316 self.server_command = command = [sys.executable,
316 self.server_command = command = [sys.executable,
317 '-m', 'IPython.html',
317 '-m', 'IPython.html',
318 '--no-browser',
318 '--no-browser',
319 '--ipython-dir', self.ipydir.name,
319 '--ipython-dir', self.ipydir.name,
320 '--notebook-dir', self.nbdir.name,
320 '--notebook-dir', self.nbdir.name,
321 ]
321 ]
322 # ipc doesn't work on Windows, and darwin has crazy-long temp paths,
322 # ipc doesn't work on Windows, and darwin has crazy-long temp paths,
323 # which run afoul of ipc's maximum path length.
323 # which run afoul of ipc's maximum path length.
324 if sys.platform.startswith('linux'):
324 if sys.platform.startswith('linux'):
325 command.append('--KernelManager.transport=ipc')
325 command.append('--KernelManager.transport=ipc')
326 self.stream_capturer = c = StreamCapturer()
326 self.stream_capturer = c = StreamCapturer()
327 c.start()
327 c.start()
328 self.server = subprocess.Popen(command, stdout=c.writefd, stderr=subprocess.STDOUT, cwd=self.nbdir.name)
328 env = os.environ.copy()
329 if self.engine == 'phantomjs':
330 env['IPYTHON_ALLOW_DRAFT_WEBSOCKETS_FOR_PHANTOMJS'] = '1'
331 self.server = subprocess.Popen(command,
332 stdout=c.writefd,
333 stderr=subprocess.STDOUT,
334 cwd=self.nbdir.name,
335 env=env,
336 )
329 self.server_info_file = os.path.join(self.ipydir.name,
337 self.server_info_file = os.path.join(self.ipydir.name,
330 'profile_default', 'security', 'nbserver-%i.json' % self.server.pid
338 'profile_default', 'security', 'nbserver-%i.json' % self.server.pid
331 )
339 )
332 self._wait_for_server()
340 self._wait_for_server()
333
341
334 def _wait_for_server(self):
342 def _wait_for_server(self):
335 """Wait 30 seconds for the notebook server to start"""
343 """Wait 30 seconds for the notebook server to start"""
336 for i in range(300):
344 for i in range(300):
337 if self.server.poll() is not None:
345 if self.server.poll() is not None:
338 return self._failed_to_start()
346 return self._failed_to_start()
339 if os.path.exists(self.server_info_file):
347 if os.path.exists(self.server_info_file):
340 try:
348 try:
341 self._load_server_info()
349 self._load_server_info()
342 except ValueError:
350 except ValueError:
343 # If the server is halfway through writing the file, we may
351 # If the server is halfway through writing the file, we may
344 # get invalid JSON; it should be ready next iteration.
352 # get invalid JSON; it should be ready next iteration.
345 pass
353 pass
346 else:
354 else:
347 return
355 return
348 time.sleep(0.1)
356 time.sleep(0.1)
349 print("Notebook server-info file never arrived: %s" % self.server_info_file,
357 print("Notebook server-info file never arrived: %s" % self.server_info_file,
350 file=sys.stderr
358 file=sys.stderr
351 )
359 )
352
360
353 def _failed_to_start(self):
361 def _failed_to_start(self):
354 """Notebook server exited prematurely"""
362 """Notebook server exited prematurely"""
355 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
363 captured = self.stream_capturer.get_buffer().decode('utf-8', 'replace')
356 print("Notebook failed to start: ", file=sys.stderr)
364 print("Notebook failed to start: ", file=sys.stderr)
357 print(self.server_command)
365 print(self.server_command)
358 print(captured, file=sys.stderr)
366 print(captured, file=sys.stderr)
359
367
360 def _load_server_info(self):
368 def _load_server_info(self):
361 """Notebook server started, load connection info from JSON"""
369 """Notebook server started, load connection info from JSON"""
362 with open(self.server_info_file) as f:
370 with open(self.server_info_file) as f:
363 info = json.load(f)
371 info = json.load(f)
364 self.server_port = info['port']
372 self.server_port = info['port']
365
373
366 def cleanup(self):
374 def cleanup(self):
367 try:
375 try:
368 self.server.terminate()
376 self.server.terminate()
369 except OSError:
377 except OSError:
370 # already dead
378 # already dead
371 pass
379 pass
372 # wait 10s for the server to shutdown
380 # wait 10s for the server to shutdown
373 try:
381 try:
374 popen_wait(self.server, NOTEBOOK_SHUTDOWN_TIMEOUT)
382 popen_wait(self.server, NOTEBOOK_SHUTDOWN_TIMEOUT)
375 except TimeoutExpired:
383 except TimeoutExpired:
376 # server didn't terminate, kill it
384 # server didn't terminate, kill it
377 try:
385 try:
378 print("Failed to terminate notebook server, killing it.",
386 print("Failed to terminate notebook server, killing it.",
379 file=sys.stderr
387 file=sys.stderr
380 )
388 )
381 self.server.kill()
389 self.server.kill()
382 except OSError:
390 except OSError:
383 # already dead
391 # already dead
384 pass
392 pass
385 # wait another 10s
393 # wait another 10s
386 try:
394 try:
387 popen_wait(self.server, NOTEBOOK_SHUTDOWN_TIMEOUT)
395 popen_wait(self.server, NOTEBOOK_SHUTDOWN_TIMEOUT)
388 except TimeoutExpired:
396 except TimeoutExpired:
389 print("Notebook server still running (%s)" % self.server_info_file,
397 print("Notebook server still running (%s)" % self.server_info_file,
390 file=sys.stderr
398 file=sys.stderr
391 )
399 )
392
400
393 self.stream_capturer.halt()
401 self.stream_capturer.halt()
394 TestController.cleanup(self)
402 TestController.cleanup(self)
395
403
396
404
397 def prepare_controllers(options):
405 def prepare_controllers(options):
398 """Returns two lists of TestController instances, those to run, and those
406 """Returns two lists of TestController instances, those to run, and those
399 not to run."""
407 not to run."""
400 testgroups = options.testgroups
408 testgroups = options.testgroups
401 if testgroups:
409 if testgroups:
402 if 'js' in testgroups:
410 if 'js' in testgroups:
403 js_testgroups = all_js_groups()
411 js_testgroups = all_js_groups()
404 else:
412 else:
405 js_testgroups = [g for g in testgroups if g.startswith(js_prefix)]
413 js_testgroups = [g for g in testgroups if g.startswith(js_prefix)]
406
414
407 py_testgroups = [g for g in testgroups if g not in ['js'] + js_testgroups]
415 py_testgroups = [g for g in testgroups if g not in ['js'] + js_testgroups]
408 else:
416 else:
409 py_testgroups = py_test_group_names
417 py_testgroups = py_test_group_names
410 if not options.all:
418 if not options.all:
411 js_testgroups = []
419 js_testgroups = []
412 test_sections['parallel'].enabled = False
420 test_sections['parallel'].enabled = False
413 else:
421 else:
414 js_testgroups = all_js_groups()
422 js_testgroups = all_js_groups()
415
423
416 engine = 'slimerjs' if options.slimerjs else 'phantomjs'
424 engine = 'slimerjs' if options.slimerjs else 'phantomjs'
417 c_js = [JSController(name, xunit=options.xunit, engine=engine, url=options.url) for name in js_testgroups]
425 c_js = [JSController(name, xunit=options.xunit, engine=engine, url=options.url) for name in js_testgroups]
418 c_py = [PyTestController(name, options) for name in py_testgroups]
426 c_py = [PyTestController(name, options) for name in py_testgroups]
419
427
420 controllers = c_py + c_js
428 controllers = c_py + c_js
421 to_run = [c for c in controllers if c.will_run]
429 to_run = [c for c in controllers if c.will_run]
422 not_run = [c for c in controllers if not c.will_run]
430 not_run = [c for c in controllers if not c.will_run]
423 return to_run, not_run
431 return to_run, not_run
424
432
425 def do_run(controller, buffer_output=True):
433 def do_run(controller, buffer_output=True):
426 """Setup and run a test controller.
434 """Setup and run a test controller.
427
435
428 If buffer_output is True, no output is displayed, to avoid it appearing
436 If buffer_output is True, no output is displayed, to avoid it appearing
429 interleaved. In this case, the caller is responsible for displaying test
437 interleaved. In this case, the caller is responsible for displaying test
430 output on failure.
438 output on failure.
431
439
432 Returns
440 Returns
433 -------
441 -------
434 controller : TestController
442 controller : TestController
435 The same controller as passed in, as a convenience for using map() type
443 The same controller as passed in, as a convenience for using map() type
436 APIs.
444 APIs.
437 exitcode : int
445 exitcode : int
438 The exit code of the test subprocess. Non-zero indicates failure.
446 The exit code of the test subprocess. Non-zero indicates failure.
439 """
447 """
440 try:
448 try:
441 try:
449 try:
442 controller.setup()
450 controller.setup()
443 if not buffer_output:
451 if not buffer_output:
444 controller.print_extra_info()
452 controller.print_extra_info()
445 controller.launch(buffer_output=buffer_output)
453 controller.launch(buffer_output=buffer_output)
446 except Exception:
454 except Exception:
447 import traceback
455 import traceback
448 traceback.print_exc()
456 traceback.print_exc()
449 return controller, 1 # signal failure
457 return controller, 1 # signal failure
450
458
451 exitcode = controller.wait()
459 exitcode = controller.wait()
452 return controller, exitcode
460 return controller, exitcode
453
461
454 except KeyboardInterrupt:
462 except KeyboardInterrupt:
455 return controller, -signal.SIGINT
463 return controller, -signal.SIGINT
456 finally:
464 finally:
457 controller.cleanup()
465 controller.cleanup()
458
466
459 def report():
467 def report():
460 """Return a string with a summary report of test-related variables."""
468 """Return a string with a summary report of test-related variables."""
461 inf = get_sys_info()
469 inf = get_sys_info()
462 out = []
470 out = []
463 def _add(name, value):
471 def _add(name, value):
464 out.append((name, value))
472 out.append((name, value))
465
473
466 _add('IPython version', inf['ipython_version'])
474 _add('IPython version', inf['ipython_version'])
467 _add('IPython commit', "{} ({})".format(inf['commit_hash'], inf['commit_source']))
475 _add('IPython commit', "{} ({})".format(inf['commit_hash'], inf['commit_source']))
468 _add('IPython package', compress_user(inf['ipython_path']))
476 _add('IPython package', compress_user(inf['ipython_path']))
469 _add('Python version', inf['sys_version'].replace('\n',''))
477 _add('Python version', inf['sys_version'].replace('\n',''))
470 _add('sys.executable', compress_user(inf['sys_executable']))
478 _add('sys.executable', compress_user(inf['sys_executable']))
471 _add('Platform', inf['platform'])
479 _add('Platform', inf['platform'])
472
480
473 width = max(len(n) for (n,v) in out)
481 width = max(len(n) for (n,v) in out)
474 out = ["{:<{width}}: {}\n".format(n, v, width=width) for (n,v) in out]
482 out = ["{:<{width}}: {}\n".format(n, v, width=width) for (n,v) in out]
475
483
476 avail = []
484 avail = []
477 not_avail = []
485 not_avail = []
478
486
479 for k, is_avail in have.items():
487 for k, is_avail in have.items():
480 if is_avail:
488 if is_avail:
481 avail.append(k)
489 avail.append(k)
482 else:
490 else:
483 not_avail.append(k)
491 not_avail.append(k)
484
492
485 if avail:
493 if avail:
486 out.append('\nTools and libraries available at test time:\n')
494 out.append('\nTools and libraries available at test time:\n')
487 avail.sort()
495 avail.sort()
488 out.append(' ' + ' '.join(avail)+'\n')
496 out.append(' ' + ' '.join(avail)+'\n')
489
497
490 if not_avail:
498 if not_avail:
491 out.append('\nTools and libraries NOT available at test time:\n')
499 out.append('\nTools and libraries NOT available at test time:\n')
492 not_avail.sort()
500 not_avail.sort()
493 out.append(' ' + ' '.join(not_avail)+'\n')
501 out.append(' ' + ' '.join(not_avail)+'\n')
494
502
495 return ''.join(out)
503 return ''.join(out)
496
504
def run_iptestall(options):
    """Run the entire IPython test suite by calling nose and trial.

    This function constructs :class:`IPTester` instances for all IPython
    modules and package and then runs each of them.  This causes the modules
    and packages of IPython to be tested each in their own subprocess using
    nose.

    Parameters
    ----------

    All parameters are passed as attributes of the options object.

    testgroups : list of str
      Run only these sections of the test suite. If empty, run all the available
      sections.

    fast : int or None
      Run the test suite in parallel, using n simultaneous processes. If None
      is passed, one process is used per CPU core. Default 1 (i.e. sequential)

    inc_slow : bool
      Include slow tests, like IPython.parallel. By default, these tests aren't
      run.

    slimerjs : bool
      Use slimerjs if it's installed instead of phantomjs for casperjs tests.

    url : unicode
      Address:port to use when running the JS tests.

    xunit : bool
      Produce Xunit XML output. This is written to multiple foo.xunit.xml files.

    coverage : bool or str
      Measure code coverage from tests. True will store the raw coverage data,
      or pass 'html' or 'xml' to get reports.

    extra_args : list
      Extra arguments to pass to the test subprocesses, e.g. '-v'
    """
    # Controllers selected for this run vs. those skipped (e.g. missing deps).
    to_run, not_run = prepare_controllers(options)

    def justify(ltext, rtext, width=70, fill='-'):
        # Produce "left text ---- right text" status lines of fixed width.
        ltext += ' '
        rtext = (' ' + rtext).rjust(width - len(ltext), fill)
        return ltext + rtext

    # Run all test runners, tracking execution time
    failed = []
    t_start = time.time()

    print()
    if options.fast == 1:
        # This actually means sequential, i.e. with 1 job
        for controller in to_run:
            print('Test group:', controller.section)
            sys.stdout.flush()  # Show in correct order when output is piped
            controller, res = do_run(controller, buffer_output=False)
            # res is the subprocess return code; nonzero (or negative signal
            # number) means the group failed.
            if res:
                failed.append(controller)
                if res == -signal.SIGINT:
                    print("Interrupted")
                    break
            print()

    else:
        # Run tests concurrently
        try:
            # options.fast may be None, which lets ThreadPool default to one
            # worker per CPU core.
            pool = multiprocessing.pool.ThreadPool(options.fast)
            for (controller, res) in pool.imap_unordered(do_run, to_run):
                res_string = 'OK' if res == 0 else 'FAILED'
                print(justify('Test group: ' + controller.section, res_string))
                if res:
                    # Output was buffered for concurrent runs; replay it now
                    # so failures are visible.
                    controller.print_extra_info()
                    print(bytes_to_str(controller.stdout))
                    failed.append(controller)
                    if res == -signal.SIGINT:
                        print("Interrupted")
                        break
        except KeyboardInterrupt:
            return

    for controller in not_run:
        print(justify('Test group: ' + controller.section, 'NOT RUN'))

    t_end = time.time()
    t_tests = t_end - t_start
    nrunners = len(to_run)
    nfail = len(failed)
    # summarize results
    print('_'*70)
    print('Test suite completed for system with the following information:')
    print(report())
    took = "Took %.3fs." % t_tests
    print('Status: ', end='')
    if not failed:
        print('OK (%d test groups).' % nrunners, took)
    else:
        # If anything went wrong, point out what command to rerun manually to
        # see the actual errors and individual summary
        failed_sections = [c.section for c in failed]
        print('ERROR - {} out of {} test groups failed ({}).'.format(nfail,
                      nrunners, ', '.join(failed_sections)), took)
        print()
        print('You may wish to rerun these, with:')
        print('  iptest', *failed_sections)
        print()

    if options.coverage:
        # Merge the per-subprocess coverage data files into one .coverage.
        from coverage import coverage
        cov = coverage(data_file='.coverage')
        cov.combine()
        cov.save()

        # Coverage HTML report
        if options.coverage == 'html':
            html_dir = 'ipy_htmlcov'
            shutil.rmtree(html_dir, ignore_errors=True)
            print("Writing HTML coverage report to %s/ ... " % html_dir, end="")
            sys.stdout.flush()

            # Custom HTML reporter to clean up module names.
            from coverage.html import HtmlReporter
            class CustomHtmlReporter(HtmlReporter):
                def find_code_units(self, morfs):
                    super(CustomHtmlReporter, self).find_code_units(morfs)
                    # Rewrite file-system paths into dotted module names,
                    # rooted at the 'IPython' package component.
                    for cu in self.code_units:
                        nameparts = cu.name.split(os.sep)
                        if 'IPython' not in nameparts:
                            continue
                        ix = nameparts.index('IPython')
                        cu.name = '.'.join(nameparts[ix:])

            # Reimplement the html_report method with our custom reporter
            # NOTE(review): _harvest_data is a private coverage.py API and is
            # tied to the coverage version in use at the time — verify against
            # the installed coverage release.
            cov._harvest_data()
            cov.config.from_args(omit='*{0}tests{0}*'.format(os.sep), html_dir=html_dir,
                                 html_title='IPython test coverage',
                                 )
            reporter = CustomHtmlReporter(cov, cov.config)
            reporter.report(None)
            print('done.')

        # Coverage XML report
        elif options.coverage == 'xml':
            cov.xml_report(outfile='ipy_coverage.xml')

    if failed:
        # Ensure that our exit code indicates failure
        sys.exit(1)
647
655
# Command-line interface for iptest; each option maps onto an attribute of
# the namespace consumed by run_iptestall().
argparser = argparse.ArgumentParser(description='Run IPython test suite')
argparser.add_argument(
    'testgroups', nargs='*',
    help='Run specified groups of tests. If omitted, run all tests.')
argparser.add_argument(
    '--all', action='store_true',
    help='Include slow tests not run by default.')
argparser.add_argument(
    '--slimerjs', action='store_true',
    help="Use slimerjs if it's installed instead of phantomjs for casperjs tests.")
argparser.add_argument(
    '--url', help="URL to use for the JS tests.")
# -j with no value parses as None (one process per core); omitting -j keeps
# the sequential default of 1.
argparser.add_argument(
    '-j', '--fast', nargs='?', const=None, default=1, type=int,
    help='Run test sections in parallel. This starts as many '
         'processes as you have cores, or you can specify a number.')
argparser.add_argument(
    '--xunit', action='store_true',
    help='Produce Xunit XML results')
# --coverage with no value is True (raw data only); 'html'/'xml' add reports.
argparser.add_argument(
    '--coverage', nargs='?', const=True, default=False,
    help="Measure test coverage. Specify 'html' or "
         "'xml' to get reports.")
argparser.add_argument(
    '--subproc-streams', default='capture',
    help="What to do with stdout/stderr from subprocesses. "
         "'capture' (default), 'show' and 'discard' are the options.")
668
676
def default_options():
    """Build the default argparse Namespace for :func:`run_iptestall`.

    Equivalent to parsing an empty command line, with the ``extra_args``
    attribute (not an argparse option) initialised to an empty list.
    """
    defaults = argparser.parse_args([])
    defaults.extra_args = []
    return defaults
676
684
def main():
    """Parse the command line and run the IPython test suite."""
    # iptest misbehaves when invoked from the root of the IPython source
    # tree; detect that case and refuse up front to avoid frustration.
    marker = os.path.join(os.getcwd(), 'IPython', 'testing', '__main__.py')
    if os.path.exists(marker):
        print("Don't run iptest from the IPython source directory",
              file=sys.stderr)
        sys.exit(1)

    # Everything after a literal '--' is passed through to the nose
    # subprocesses untouched; argparse would otherwise treat those tokens
    # as ordinary positional arguments, so split them off first.
    argv = sys.argv[1:]
    if '--' in argv:
        split_at = argv.index('--')
        to_parse = argv[:split_at]
        extra_args = argv[split_at + 1:]
    else:
        to_parse = argv
        extra_args = []

    options = argparser.parse_args(to_parse)
    options.extra_args = extra_args

    run_iptestall(options)
702
710
703
711
# Script entry point: run the suite only when executed directly, not on import.
if __name__ == '__main__':
    main()
General Comments 0
You need to be logged in to leave comments. Login now