##// END OF EJS Templates
testapp: moved login/csrf session methods into TestApp itself....
marcink -
r2374:e331d3e6 default
parent child Browse files
Show More
@@ -1,411 +1,429 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2017 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import threading
22 22 import time
23 23 import logging
24 24 import os.path
25 25 import subprocess32
26 26 import tempfile
27 27 import urllib2
28 28 from lxml.html import fromstring, tostring
29 29 from lxml.cssselect import CSSSelector
30 30 from urlparse import urlparse, parse_qsl
31 31 from urllib import unquote_plus
32 32 import webob
33 33
34 34 from webtest.app import TestResponse, TestApp, string_types
35 35 from webtest.compat import print_stderr
36 36
37 37 import pytest
38 38 import rc_testdata
39 39
40 40 from rhodecode.model.db import User, Repository
41 41 from rhodecode.model.meta import Session
42 42 from rhodecode.model.scm import ScmModel
43 43 from rhodecode.lib.vcs.backends.svn.repository import SubversionRepository
44 44 from rhodecode.lib.vcs.backends.base import EmptyCommit
45
45 from rhodecode.tests import login_user_session
46 46
47 47 log = logging.getLogger(__name__)
48 48
49 49
50 50 class CustomTestResponse(TestResponse):
51 51 def _save_output(self, out):
52 52 f = tempfile.NamedTemporaryFile(
53 53 delete=False, prefix='rc-test-', suffix='.html')
54 54 f.write(out)
55 55 return f.name
56 56
57 57 def mustcontain(self, *strings, **kw):
58 58 """
59 59 Assert that the response contains all of the strings passed
60 60 in as arguments.
61 61
62 62 Equivalent to::
63 63
64 64 assert string in res
65 65 """
66 66 if 'no' in kw:
67 67 no = kw['no']
68 68 del kw['no']
69 69 if isinstance(no, string_types):
70 70 no = [no]
71 71 else:
72 72 no = []
73 73 if kw:
74 74 raise TypeError(
75 "The only keyword argument allowed is 'no'")
75 "The only keyword argument allowed is 'no' got %s" % kw)
76 76
77 77 f = self._save_output(str(self))
78 78
79 79 for s in strings:
80 80 if not s in self:
81 81 print_stderr("Actual response (no %r):" % s)
82 82 print_stderr(str(self))
83 83 raise IndexError(
84 84 "Body does not contain string %r, output saved as %s" % (
85 85 s, f))
86 86
87 87 for no_s in no:
88 88 if no_s in self:
89 89 print_stderr("Actual response (has %r)" % no_s)
90 90 print_stderr(str(self))
91 91 raise IndexError(
92 92 "Body contains bad string %r, output saved as %s" % (
93 93 no_s, f))
94 94
95 95 def assert_response(self):
96 96 return AssertResponse(self)
97 97
98 98 def get_session_from_response(self):
99 99 """
100 100 This returns the session from a response object.
101 101 """
102 102
103 103 from pyramid_beaker import session_factory_from_settings
104 104 session = session_factory_from_settings(
105 105 self.test_app.app.config.get_settings())
106 106 return session(self.request)
107 107
108 108
109 109 class TestRequest(webob.BaseRequest):
110 110
111 111 # for py.test
112 112 disabled = True
113 113 ResponseClass = CustomTestResponse
114 114
115 115 def add_response_callback(self, callback):
116 116 pass
117 117
118 118
119 119 class CustomTestApp(TestApp):
120 120 """
121 Custom app to make mustcontain more usefull
121 Custom app to make mustcontain more usefull, and extract special methods
122 122 """
123 123 RequestClass = TestRequest
124 rc_login_data = {}
125 rc_current_session = None
126
127 def login(self, username=None, password=None):
128 from rhodecode.lib import auth
129
130 if username and password:
131 session = login_user_session(self, username, password)
132 else:
133 session = login_user_session(self)
134
135 self.rc_login_data['csrf_token'] = auth.get_csrf_token(session)
136 self.rc_current_session = session
137 return session['rhodecode_user']
138
139 @property
140 def csrf_token(self):
141 return self.rc_login_data['csrf_token']
124 142
125 143
126 144 def set_anonymous_access(enabled):
127 145 """(Dis)allows anonymous access depending on parameter `enabled`"""
128 146 user = User.get_default_user()
129 147 user.active = enabled
130 148 Session().add(user)
131 149 Session().commit()
132 150 time.sleep(1.5) # must sleep for cache (1s to expire)
133 151 log.info('anonymous access is now: %s', enabled)
134 152 assert enabled == User.get_default_user().active, (
135 153 'Cannot set anonymous access')
136 154
137 155
138 156 def check_xfail_backends(node, backend_alias):
139 157 # Using "xfail_backends" here intentionally, since this marks work
140 158 # which is "to be done" soon.
141 159 skip_marker = node.get_marker('xfail_backends')
142 160 if skip_marker and backend_alias in skip_marker.args:
143 161 msg = "Support for backend %s to be developed." % (backend_alias, )
144 162 msg = skip_marker.kwargs.get('reason', msg)
145 163 pytest.xfail(msg)
146 164
147 165
148 166 def check_skip_backends(node, backend_alias):
149 167 # Using "skip_backends" here intentionally, since this marks work which is
150 168 # not supported.
151 169 skip_marker = node.get_marker('skip_backends')
152 170 if skip_marker and backend_alias in skip_marker.args:
153 171 msg = "Feature not supported for backend %s." % (backend_alias, )
154 172 msg = skip_marker.kwargs.get('reason', msg)
155 173 pytest.skip(msg)
156 174
157 175
158 176 def extract_git_repo_from_dump(dump_name, repo_name):
159 177 """Create git repo `repo_name` from dump `dump_name`."""
160 178 repos_path = ScmModel().repos_path
161 179 target_path = os.path.join(repos_path, repo_name)
162 180 rc_testdata.extract_git_dump(dump_name, target_path)
163 181 return target_path
164 182
165 183
166 184 def extract_hg_repo_from_dump(dump_name, repo_name):
167 185 """Create hg repo `repo_name` from dump `dump_name`."""
168 186 repos_path = ScmModel().repos_path
169 187 target_path = os.path.join(repos_path, repo_name)
170 188 rc_testdata.extract_hg_dump(dump_name, target_path)
171 189 return target_path
172 190
173 191
174 192 def extract_svn_repo_from_dump(dump_name, repo_name):
175 193 """Create a svn repo `repo_name` from dump `dump_name`."""
176 194 repos_path = ScmModel().repos_path
177 195 target_path = os.path.join(repos_path, repo_name)
178 196 SubversionRepository(target_path, create=True)
179 197 _load_svn_dump_into_repo(dump_name, target_path)
180 198 return target_path
181 199
182 200
183 201 def assert_message_in_log(log_records, message, levelno, module):
184 202 messages = [
185 203 r.message for r in log_records
186 204 if r.module == module and r.levelno == levelno
187 205 ]
188 206 assert message in messages
189 207
190 208
191 209 def _load_svn_dump_into_repo(dump_name, repo_path):
192 210 """
193 211 Utility to populate a svn repository with a named dump
194 212
195 213 Currently the dumps are in rc_testdata. They might later on be
196 214 integrated with the main repository once they stabilize more.
197 215 """
198 216 dump = rc_testdata.load_svn_dump(dump_name)
199 217 load_dump = subprocess32.Popen(
200 218 ['svnadmin', 'load', repo_path],
201 219 stdin=subprocess32.PIPE, stdout=subprocess32.PIPE,
202 220 stderr=subprocess32.PIPE)
203 221 out, err = load_dump.communicate(dump)
204 222 if load_dump.returncode != 0:
205 223 log.error("Output of load_dump command: %s", out)
206 224 log.error("Error output of load_dump command: %s", err)
207 225 raise Exception(
208 226 'Failed to load dump "%s" into repository at path "%s".'
209 227 % (dump_name, repo_path))
210 228
211 229
212 230 class AssertResponse(object):
213 231 """
214 232 Utility that helps to assert things about a given HTML response.
215 233 """
216 234
217 235 def __init__(self, response):
218 236 self.response = response
219 237
220 238 def get_imports(self):
221 239 return fromstring, tostring, CSSSelector
222 240
223 241 def one_element_exists(self, css_selector):
224 242 self.get_element(css_selector)
225 243
226 244 def no_element_exists(self, css_selector):
227 245 assert not self._get_elements(css_selector)
228 246
229 247 def element_equals_to(self, css_selector, expected_content):
230 248 element = self.get_element(css_selector)
231 249 element_text = self._element_to_string(element)
232 250 assert expected_content in element_text
233 251
234 252 def element_contains(self, css_selector, expected_content):
235 253 element = self.get_element(css_selector)
236 254 assert expected_content in element.text_content()
237 255
238 256 def element_value_contains(self, css_selector, expected_content):
239 257 element = self.get_element(css_selector)
240 258 assert expected_content in element.value
241 259
242 260 def contains_one_link(self, link_text, href):
243 261 fromstring, tostring, CSSSelector = self.get_imports()
244 262 doc = fromstring(self.response.body)
245 263 sel = CSSSelector('a[href]')
246 264 elements = [
247 265 e for e in sel(doc) if e.text_content().strip() == link_text]
248 266 assert len(elements) == 1, "Did not find link or found multiple links"
249 267 self._ensure_url_equal(elements[0].attrib.get('href'), href)
250 268
251 269 def contains_one_anchor(self, anchor_id):
252 270 fromstring, tostring, CSSSelector = self.get_imports()
253 271 doc = fromstring(self.response.body)
254 272 sel = CSSSelector('#' + anchor_id)
255 273 elements = sel(doc)
256 274 assert len(elements) == 1, 'cannot find 1 element {}'.format(anchor_id)
257 275
258 276 def _ensure_url_equal(self, found, expected):
259 277 assert _Url(found) == _Url(expected)
260 278
261 279 def get_element(self, css_selector):
262 280 elements = self._get_elements(css_selector)
263 281 assert len(elements) == 1, 'cannot find 1 element {}'.format(css_selector)
264 282 return elements[0]
265 283
266 284 def get_elements(self, css_selector):
267 285 return self._get_elements(css_selector)
268 286
269 287 def _get_elements(self, css_selector):
270 288 fromstring, tostring, CSSSelector = self.get_imports()
271 289 doc = fromstring(self.response.body)
272 290 sel = CSSSelector(css_selector)
273 291 elements = sel(doc)
274 292 return elements
275 293
276 294 def _element_to_string(self, element):
277 295 fromstring, tostring, CSSSelector = self.get_imports()
278 296 return tostring(element)
279 297
280 298
281 299 class _Url(object):
282 300 """
283 301 A url object that can be compared with other url orbjects
284 302 without regard to the vagaries of encoding, escaping, and ordering
285 303 of parameters in query strings.
286 304
287 305 Inspired by
288 306 http://stackoverflow.com/questions/5371992/comparing-two-urls-in-python
289 307 """
290 308
291 309 def __init__(self, url):
292 310 parts = urlparse(url)
293 311 _query = frozenset(parse_qsl(parts.query))
294 312 _path = unquote_plus(parts.path)
295 313 parts = parts._replace(query=_query, path=_path)
296 314 self.parts = parts
297 315
298 316 def __eq__(self, other):
299 317 return self.parts == other.parts
300 318
301 319 def __hash__(self):
302 320 return hash(self.parts)
303 321
304 322
305 323 def run_test_concurrently(times, raise_catched_exc=True):
306 324 """
307 325 Add this decorator to small pieces of code that you want to test
308 326 concurrently
309 327
310 328 ex:
311 329
312 330 @test_concurrently(25)
313 331 def my_test_function():
314 332 ...
315 333 """
316 334 def test_concurrently_decorator(test_func):
317 335 def wrapper(*args, **kwargs):
318 336 exceptions = []
319 337
320 338 def call_test_func():
321 339 try:
322 340 test_func(*args, **kwargs)
323 341 except Exception as e:
324 342 exceptions.append(e)
325 343 if raise_catched_exc:
326 344 raise
327 345 threads = []
328 346 for i in range(times):
329 347 threads.append(threading.Thread(target=call_test_func))
330 348 for t in threads:
331 349 t.start()
332 350 for t in threads:
333 351 t.join()
334 352 if exceptions:
335 353 raise Exception(
336 354 'test_concurrently intercepted %s exceptions: %s' % (
337 355 len(exceptions), exceptions))
338 356 return wrapper
339 357 return test_concurrently_decorator
340 358
341 359
342 360 def wait_for_url(url, timeout=10):
343 361 """
344 362 Wait until URL becomes reachable.
345 363
346 364 It polls the URL until the timeout is reached or it became reachable.
347 365 If will call to `py.test.fail` in case the URL is not reachable.
348 366 """
349 367 timeout = time.time() + timeout
350 368 last = 0
351 369 wait = 0.1
352 370
353 371 while timeout > last:
354 372 last = time.time()
355 373 if is_url_reachable(url):
356 374 break
357 375 elif (last + wait) > time.time():
358 376 # Go to sleep because not enough time has passed since last check.
359 377 time.sleep(wait)
360 378 else:
361 379 pytest.fail("Timeout while waiting for URL {}".format(url))
362 380
363 381
364 382 def is_url_reachable(url):
365 383 try:
366 384 urllib2.urlopen(url)
367 385 except urllib2.URLError:
368 386 return False
369 387 return True
370 388
371 389
372 390 def repo_on_filesystem(repo_name):
373 391 from rhodecode.lib import vcs
374 392 from rhodecode.tests import TESTS_TMP_PATH
375 393 repo = vcs.get_vcs_instance(
376 394 os.path.join(TESTS_TMP_PATH, repo_name), create=False)
377 395 return repo is not None
378 396
379 397
380 398 def commit_change(
381 399 repo, filename, content, message, vcs_type, parent=None, newfile=False):
382 400 from rhodecode.tests import TEST_USER_ADMIN_LOGIN
383 401
384 402 repo = Repository.get_by_repo_name(repo)
385 403 _commit = parent
386 404 if not parent:
387 405 _commit = EmptyCommit(alias=vcs_type)
388 406
389 407 if newfile:
390 408 nodes = {
391 409 filename: {
392 410 'content': content
393 411 }
394 412 }
395 413 commit = ScmModel().create_nodes(
396 414 user=TEST_USER_ADMIN_LOGIN, repo=repo,
397 415 message=message,
398 416 nodes=nodes,
399 417 parent_commit=_commit,
400 418 author=TEST_USER_ADMIN_LOGIN,
401 419 )
402 420 else:
403 421 commit = ScmModel().commit_change(
404 422 repo=repo.scm_instance(), repo_name=repo.repo_name,
405 423 commit=parent, user=TEST_USER_ADMIN_LOGIN,
406 424 author=TEST_USER_ADMIN_LOGIN,
407 425 message=message,
408 426 content=content,
409 427 f_path=filename
410 428 )
411 429 return commit
General Comments 0
You need to be logged in to leave comments. Login now