##// END OF EJS Templates
hooks: store hook type for extensions.
marcink -
r555:315f8a04 default
parent child Browse files
Show More
@@ -1,658 +1,662 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # RhodeCode VCSServer provides access to different vcs backends via network.
4 4 # Copyright (C) 2014-2018 RhodeCode GmbH
5 5 #
6 6 # This program is free software; you can redistribute it and/or modify
7 7 # it under the terms of the GNU General Public License as published by
8 8 # the Free Software Foundation; either version 3 of the License, or
9 9 # (at your option) any later version.
10 10 #
11 11 # This program is distributed in the hope that it will be useful,
12 12 # but WITHOUT ANY WARRANTY; without even the implied warranty of
13 13 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 14 # GNU General Public License for more details.
15 15 #
16 16 # You should have received a copy of the GNU General Public License
17 17 # along with this program; if not, write to the Free Software Foundation,
18 18 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
19 19
20 20 import io
21 21 import os
22 22 import sys
23 23 import logging
24 24 import collections
25 25 import importlib
26 26 import base64
27 27
28 28 from httplib import HTTPConnection
29 29
30 30
31 31 import mercurial.scmutil
32 32 import mercurial.node
33 33 import simplejson as json
34 34
35 35 from vcsserver import exceptions, subprocessio, settings
36 36
37 37 log = logging.getLogger(__name__)
38 38
39 39
40 40 class HooksHttpClient(object):
41 41 connection = None
42 42
43 43 def __init__(self, hooks_uri):
44 44 self.hooks_uri = hooks_uri
45 45
46 46 def __call__(self, method, extras):
47 47 connection = HTTPConnection(self.hooks_uri)
48 48 body = self._serialize(method, extras)
49 49 try:
50 50 connection.request('POST', '/', body)
51 51 except Exception:
52 52 log.error('Connection failed on %s', connection)
53 53 raise
54 54 response = connection.getresponse()
55 55 return json.loads(response.read())
56 56
57 57 def _serialize(self, hook_name, extras):
58 58 data = {
59 59 'method': hook_name,
60 60 'extras': extras
61 61 }
62 62 return json.dumps(data)
63 63
64 64
65 65 class HooksDummyClient(object):
66 66 def __init__(self, hooks_module):
67 67 self._hooks_module = importlib.import_module(hooks_module)
68 68
69 69 def __call__(self, hook_name, extras):
70 70 with self._hooks_module.Hooks() as hooks:
71 71 return getattr(hooks, hook_name)(extras)
72 72
73 73
74 74 class RemoteMessageWriter(object):
75 75 """Writer base class."""
76 76 def write(self, message):
77 77 raise NotImplementedError()
78 78
79 79
80 80 class HgMessageWriter(RemoteMessageWriter):
81 81 """Writer that knows how to send messages to mercurial clients."""
82 82
83 83 def __init__(self, ui):
84 84 self.ui = ui
85 85
86 86 def write(self, message):
87 87 # TODO: Check why the quiet flag is set by default.
88 88 old = self.ui.quiet
89 89 self.ui.quiet = False
90 90 self.ui.status(message.encode('utf-8'))
91 91 self.ui.quiet = old
92 92
93 93
94 94 class GitMessageWriter(RemoteMessageWriter):
95 95 """Writer that knows how to send messages to git clients."""
96 96
97 97 def __init__(self, stdout=None):
98 98 self.stdout = stdout or sys.stdout
99 99
100 100 def write(self, message):
101 101 self.stdout.write(message.encode('utf-8'))
102 102
103 103
104 104 class SvnMessageWriter(RemoteMessageWriter):
105 105 """Writer that knows how to send messages to svn clients."""
106 106
107 107 def __init__(self, stderr=None):
108 108 # SVN needs data sent to stderr for back-to-client messaging
109 109 self.stderr = stderr or sys.stderr
110 110
111 111 def write(self, message):
112 112 self.stderr.write(message.encode('utf-8'))
113 113
114 114
115 115 def _handle_exception(result):
116 116 exception_class = result.get('exception')
117 117 exception_traceback = result.get('exception_traceback')
118 118
119 119 if exception_traceback:
120 120 log.error('Got traceback from remote call:%s', exception_traceback)
121 121
122 122 if exception_class == 'HTTPLockedRC':
123 123 raise exceptions.RepositoryLockedException()(*result['exception_args'])
124 124 elif exception_class == 'HTTPBranchProtected':
125 125 raise exceptions.RepositoryBranchProtectedException()(*result['exception_args'])
126 126 elif exception_class == 'RepositoryError':
127 127 raise exceptions.VcsException()(*result['exception_args'])
128 128 elif exception_class:
129 129 raise Exception('Got remote exception "%s" with args "%s"' %
130 130 (exception_class, result['exception_args']))
131 131
132 132
133 133 def _get_hooks_client(extras):
134 134 if 'hooks_uri' in extras:
135 135 protocol = extras.get('hooks_protocol')
136 136 return HooksHttpClient(extras['hooks_uri'])
137 137 else:
138 138 return HooksDummyClient(extras['hooks_module'])
139 139
140 140
141 141 def _call_hook(hook_name, extras, writer):
142 142 hooks_client = _get_hooks_client(extras)
143 143 log.debug('Hooks, using client:%s', hooks_client)
144 144 result = hooks_client(hook_name, extras)
145 145 log.debug('Hooks got result: %s', result)
146 146
147 147 _handle_exception(result)
148 148 writer.write(result['output'])
149 149
150 150 return result['status']
151 151
152 152
153 153 def _extras_from_ui(ui):
154 154 hook_data = ui.config('rhodecode', 'RC_SCM_DATA')
155 155 if not hook_data:
156 156 # maybe it's inside environ ?
157 157 env_hook_data = os.environ.get('RC_SCM_DATA')
158 158 if env_hook_data:
159 159 hook_data = env_hook_data
160 160
161 161 extras = {}
162 162 if hook_data:
163 163 extras = json.loads(hook_data)
164 164 return extras
165 165
166 166
167 167 def _rev_range_hash(repo, node, check_heads=False):
168 168
169 169 commits = []
170 170 revs = []
171 171 start = repo[node].rev()
172 172 end = len(repo)
173 173 for rev in range(start, end):
174 174 revs.append(rev)
175 175 ctx = repo[rev]
176 176 commit_id = mercurial.node.hex(ctx.node())
177 177 branch = ctx.branch()
178 178 commits.append((commit_id, branch))
179 179
180 180 parent_heads = []
181 181 if check_heads:
182 182 parent_heads = _check_heads(repo, start, end, revs)
183 183 return commits, parent_heads
184 184
185 185
186 186 def _check_heads(repo, start, end, commits):
187 187 changelog = repo.changelog
188 188 parents = set()
189 189
190 190 for new_rev in commits:
191 191 for p in changelog.parentrevs(new_rev):
192 192 if p == mercurial.node.nullrev:
193 193 continue
194 194 if p < start:
195 195 parents.add(p)
196 196
197 197 for p in parents:
198 198 branch = repo[p].branch()
199 199 # The heads descending from that parent, on the same branch
200 200 parent_heads = set([p])
201 201 reachable = set([p])
202 202 for x in xrange(p + 1, end):
203 203 if repo[x].branch() != branch:
204 204 continue
205 205 for pp in changelog.parentrevs(x):
206 206 if pp in reachable:
207 207 reachable.add(x)
208 208 parent_heads.discard(pp)
209 209 parent_heads.add(x)
210 210 # More than one head? Suggest merging
211 211 if len(parent_heads) > 1:
212 212 return list(parent_heads)
213 213
214 214 return []
215 215
216 216
217 217 def repo_size(ui, repo, **kwargs):
218 218 extras = _extras_from_ui(ui)
219 219 return _call_hook('repo_size', extras, HgMessageWriter(ui))
220 220
221 221
222 222 def pre_pull(ui, repo, **kwargs):
223 223 extras = _extras_from_ui(ui)
224 224 return _call_hook('pre_pull', extras, HgMessageWriter(ui))
225 225
226 226
227 227 def pre_pull_ssh(ui, repo, **kwargs):
228 228 extras = _extras_from_ui(ui)
229 229 if extras and extras.get('SSH'):
230 230 return pre_pull(ui, repo, **kwargs)
231 231 return 0
232 232
233 233
234 234 def post_pull(ui, repo, **kwargs):
235 235 extras = _extras_from_ui(ui)
236 236 return _call_hook('post_pull', extras, HgMessageWriter(ui))
237 237
238 238
239 239 def post_pull_ssh(ui, repo, **kwargs):
240 240 extras = _extras_from_ui(ui)
241 241 if extras and extras.get('SSH'):
242 242 return post_pull(ui, repo, **kwargs)
243 243 return 0
244 244
245 245
246 246 def pre_push(ui, repo, node=None, **kwargs):
247 247 """
248 248 Mercurial pre_push hook
249 249 """
250 250 extras = _extras_from_ui(ui)
251 251 detect_force_push = extras.get('detect_force_push')
252 252
253 253 rev_data = []
254 254 if node and kwargs.get('hooktype') == 'pretxnchangegroup':
255 255 branches = collections.defaultdict(list)
256 256 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
257 257 for commit_id, branch in commits:
258 258 branches[branch].append(commit_id)
259 259
260 260 for branch, commits in branches.items():
261 261 old_rev = kwargs.get('node_last') or commits[0]
262 262 rev_data.append({
263 263 'old_rev': old_rev,
264 264 'new_rev': commits[-1],
265 265 'ref': '',
266 266 'type': 'branch',
267 267 'name': branch,
268 268 })
269 269
270 270 for push_ref in rev_data:
271 271 push_ref['multiple_heads'] = _heads
272 272
273 extras['hook_type'] = kwargs.get('hooktype', 'pre_push')
273 274 extras['commit_ids'] = rev_data
274 275 return _call_hook('pre_push', extras, HgMessageWriter(ui))
275 276
276 277
277 278 def pre_push_ssh(ui, repo, node=None, **kwargs):
278 279 extras = _extras_from_ui(ui)
279 280 if extras.get('SSH'):
280 281 return pre_push(ui, repo, node, **kwargs)
281 282
282 283 return 0
283 284
284 285
285 286 def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
286 287 """
287 288 Mercurial pre_push hook for SSH
288 289 """
289 290 extras = _extras_from_ui(ui)
290 291 if extras.get('SSH'):
291 292 permission = extras['SSH_PERMISSIONS']
292 293
293 294 if 'repository.write' == permission or 'repository.admin' == permission:
294 295 return 0
295 296
296 297 # non-zero ret code
297 298 return 1
298 299
299 300 return 0
300 301
301 302
302 303 def post_push(ui, repo, node, **kwargs):
303 304 """
304 305 Mercurial post_push hook
305 306 """
306 307 extras = _extras_from_ui(ui)
307 308
308 309 commit_ids = []
309 310 branches = []
310 311 bookmarks = []
311 312 tags = []
312 313
313 314 commits, _heads = _rev_range_hash(repo, node)
314 315 for commit_id, branch in commits:
315 316 commit_ids.append(commit_id)
316 317 if branch not in branches:
317 318 branches.append(branch)
318 319
319 320 if hasattr(ui, '_rc_pushkey_branches'):
320 321 bookmarks = ui._rc_pushkey_branches
321 322
323 extras['hook_type'] = kwargs.get('hooktype', 'post_push')
322 324 extras['commit_ids'] = commit_ids
323 325 extras['new_refs'] = {
324 326 'branches': branches,
325 327 'bookmarks': bookmarks,
326 328 'tags': tags
327 329 }
328 330
329 331 return _call_hook('post_push', extras, HgMessageWriter(ui))
330 332
331 333
332 334 def post_push_ssh(ui, repo, node, **kwargs):
333 335 """
334 336 Mercurial post_push hook for SSH
335 337 """
336 338 if _extras_from_ui(ui).get('SSH'):
337 339 return post_push(ui, repo, node, **kwargs)
338 340 return 0
339 341
340 342
341 343 def key_push(ui, repo, **kwargs):
342 344 if kwargs['new'] != '0' and kwargs['namespace'] == 'bookmarks':
343 345 # store new bookmarks in our UI object propagated later to post_push
344 346 ui._rc_pushkey_branches = repo[kwargs['key']].bookmarks()
345 347 return
346 348
347 349
348 350 # backward compat
349 351 log_pull_action = post_pull
350 352
351 353 # backward compat
352 354 log_push_action = post_push
353 355
354 356
355 357 def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
356 358 """
357 359 Old hook name: keep here for backward compatibility.
358 360
359 361 This is only required when the installed git hooks are not upgraded.
360 362 """
361 363 pass
362 364
363 365
364 366 def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
365 367 """
366 368 Old hook name: keep here for backward compatibility.
367 369
368 370 This is only required when the installed git hooks are not upgraded.
369 371 """
370 372 pass
371 373
372 374
373 375 HookResponse = collections.namedtuple('HookResponse', ('status', 'output'))
374 376
375 377
376 378 def git_pre_pull(extras):
377 379 """
378 380 Pre pull hook.
379 381
380 382 :param extras: dictionary containing the keys defined in simplevcs
381 383 :type extras: dict
382 384
383 385 :return: status code of the hook. 0 for success.
384 386 :rtype: int
385 387 """
386 388 if 'pull' not in extras['hooks']:
387 389 return HookResponse(0, '')
388 390
389 391 stdout = io.BytesIO()
390 392 try:
391 393 status = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
392 394 except Exception as error:
393 395 status = 128
394 396 stdout.write('ERROR: %s\n' % str(error))
395 397
396 398 return HookResponse(status, stdout.getvalue())
397 399
398 400
399 401 def git_post_pull(extras):
400 402 """
401 403 Post pull hook.
402 404
403 405 :param extras: dictionary containing the keys defined in simplevcs
404 406 :type extras: dict
405 407
406 408 :return: status code of the hook. 0 for success.
407 409 :rtype: int
408 410 """
409 411 if 'pull' not in extras['hooks']:
410 412 return HookResponse(0, '')
411 413
412 414 stdout = io.BytesIO()
413 415 try:
414 416 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
415 417 except Exception as error:
416 418 status = 128
417 419 stdout.write('ERROR: %s\n' % error)
418 420
419 421 return HookResponse(status, stdout.getvalue())
420 422
421 423
422 424 def _parse_git_ref_lines(revision_lines):
423 425 rev_data = []
424 426 for revision_line in revision_lines or []:
425 427 old_rev, new_rev, ref = revision_line.strip().split(' ')
426 428 ref_data = ref.split('/', 2)
427 429 if ref_data[1] in ('tags', 'heads'):
428 430 rev_data.append({
429 431 'old_rev': old_rev,
430 432 'new_rev': new_rev,
431 433 'ref': ref,
432 434 'type': ref_data[1],
433 435 'name': ref_data[2],
434 436 })
435 437 return rev_data
436 438
437 439
438 440 def git_pre_receive(unused_repo_path, revision_lines, env):
439 441 """
440 442 Pre push hook.
441 443
442 444 :param extras: dictionary containing the keys defined in simplevcs
443 445 :type extras: dict
444 446
445 447 :return: status code of the hook. 0 for success.
446 448 :rtype: int
447 449 """
448 450 extras = json.loads(env['RC_SCM_DATA'])
449 451 rev_data = _parse_git_ref_lines(revision_lines)
450 452 if 'push' not in extras['hooks']:
451 453 return 0
452 454 empty_commit_id = '0' * 40
453 455
454 456 detect_force_push = extras.get('detect_force_push')
455 457
456 458 for push_ref in rev_data:
457 459 # store our git-env which holds the temp store
458 460 push_ref['git_env'] = [
459 461 (k, v) for k, v in os.environ.items() if k.startswith('GIT')]
460 462 push_ref['pruned_sha'] = ''
461 463 if not detect_force_push:
462 464 # don't check for forced-push when we don't need to
463 465 continue
464 466
465 467 type_ = push_ref['type']
466 468 new_branch = push_ref['old_rev'] == empty_commit_id
467 469 if type_ == 'heads' and not new_branch:
468 470 old_rev = push_ref['old_rev']
469 471 new_rev = push_ref['new_rev']
470 472 cmd = [settings.GIT_EXECUTABLE, 'rev-list',
471 473 old_rev, '^{}'.format(new_rev)]
472 474 stdout, stderr = subprocessio.run_command(
473 475 cmd, env=os.environ.copy())
474 476 # means we're having some non-reachable objects, this forced push
475 477 # was used
476 478 if stdout:
477 479 push_ref['pruned_sha'] = stdout.splitlines()
478 480
481 extras['hook_type'] = 'pre_receive'
479 482 extras['commit_ids'] = rev_data
480 483 return _call_hook('pre_push', extras, GitMessageWriter())
481 484
482 485
483 486 def git_post_receive(unused_repo_path, revision_lines, env):
484 487 """
485 488 Post push hook.
486 489
487 490 :param extras: dictionary containing the keys defined in simplevcs
488 491 :type extras: dict
489 492
490 493 :return: status code of the hook. 0 for success.
491 494 :rtype: int
492 495 """
493 496 extras = json.loads(env['RC_SCM_DATA'])
494 497 if 'push' not in extras['hooks']:
495 498 return 0
496 499
497 500 rev_data = _parse_git_ref_lines(revision_lines)
498 501
499 502 git_revs = []
500 503
501 504 # N.B.(skreft): it is ok to just call git, as git before calling a
502 505 # subcommand sets the PATH environment variable so that it point to the
503 506 # correct version of the git executable.
504 507 empty_commit_id = '0' * 40
505 508 branches = []
506 509 tags = []
507 510 for push_ref in rev_data:
508 511 type_ = push_ref['type']
509 512
510 513 if type_ == 'heads':
511 514 if push_ref['old_rev'] == empty_commit_id:
512 515 # starting new branch case
513 516 if push_ref['name'] not in branches:
514 517 branches.append(push_ref['name'])
515 518
516 519 # Fix up head revision if needed
517 520 cmd = [settings.GIT_EXECUTABLE, 'show', 'HEAD']
518 521 try:
519 522 subprocessio.run_command(cmd, env=os.environ.copy())
520 523 except Exception:
521 524 cmd = [settings.GIT_EXECUTABLE, 'symbolic-ref', 'HEAD',
522 525 'refs/heads/%s' % push_ref['name']]
523 526 print("Setting default branch to %s" % push_ref['name'])
524 527 subprocessio.run_command(cmd, env=os.environ.copy())
525 528
526 529 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref',
527 530 '--format=%(refname)', 'refs/heads/*']
528 531 stdout, stderr = subprocessio.run_command(
529 532 cmd, env=os.environ.copy())
530 533 heads = stdout
531 534 heads = heads.replace(push_ref['ref'], '')
532 535 heads = ' '.join(head for head
533 536 in heads.splitlines() if head) or '.'
534 537 cmd = [settings.GIT_EXECUTABLE, 'log', '--reverse',
535 538 '--pretty=format:%H', '--', push_ref['new_rev'],
536 539 '--not', heads]
537 540 stdout, stderr = subprocessio.run_command(
538 541 cmd, env=os.environ.copy())
539 542 git_revs.extend(stdout.splitlines())
540 543 elif push_ref['new_rev'] == empty_commit_id:
541 544 # delete branch case
542 545 git_revs.append('delete_branch=>%s' % push_ref['name'])
543 546 else:
544 547 if push_ref['name'] not in branches:
545 548 branches.append(push_ref['name'])
546 549
547 550 cmd = [settings.GIT_EXECUTABLE, 'log',
548 551 '{old_rev}..{new_rev}'.format(**push_ref),
549 552 '--reverse', '--pretty=format:%H']
550 553 stdout, stderr = subprocessio.run_command(
551 554 cmd, env=os.environ.copy())
552 555 git_revs.extend(stdout.splitlines())
553 556 elif type_ == 'tags':
554 557 if push_ref['name'] not in tags:
555 558 tags.append(push_ref['name'])
556 559 git_revs.append('tag=>%s' % push_ref['name'])
557 560
561 extras['hook_type'] = 'post_receive'
558 562 extras['commit_ids'] = git_revs
559 563 extras['new_refs'] = {
560 564 'branches': branches,
561 565 'bookmarks': [],
562 566 'tags': tags,
563 567 }
564 568
565 569 if 'repo_size' in extras['hooks']:
566 570 try:
567 571 _call_hook('repo_size', extras, GitMessageWriter())
568 572 except:
569 573 pass
570 574
571 575 return _call_hook('post_push', extras, GitMessageWriter())
572 576
573 577
574 578 def _get_extras_from_txn_id(path, txn_id):
575 579 extras = {}
576 580 try:
577 581 cmd = ['svnlook', 'pget',
578 582 '-t', txn_id,
579 583 '--revprop', path, 'rc-scm-extras']
580 584 stdout, stderr = subprocessio.run_command(
581 585 cmd, env=os.environ.copy())
582 586 extras = json.loads(base64.urlsafe_b64decode(stdout))
583 587 except Exception:
584 588 log.exception('Failed to extract extras info from txn_id')
585 589
586 590 return extras
587 591
588 592
589 593 def svn_pre_commit(repo_path, commit_data, env):
590 594 path, txn_id = commit_data
591 595 branches = []
592 596 tags = []
593 597
594 598 if env.get('RC_SCM_DATA'):
595 599 extras = json.loads(env['RC_SCM_DATA'])
596 600 else:
597 601 # fallback method to read from TXN-ID stored data
598 602 extras = _get_extras_from_txn_id(path, txn_id)
599 603 if not extras:
600 604 return 0
601 605
602 606 extras['commit_ids'] = []
603 607 extras['txn_id'] = txn_id
604 608 extras['new_refs'] = {
605 609 'branches': branches,
606 610 'bookmarks': [],
607 611 'tags': tags,
608 612 }
609 613
610 614 return _call_hook('pre_push', extras, SvnMessageWriter())
611 615
612 616
613 617 def _get_extras_from_commit_id(commit_id, path):
614 618 extras = {}
615 619 try:
616 620 cmd = ['svnlook', 'pget',
617 621 '-r', commit_id,
618 622 '--revprop', path, 'rc-scm-extras']
619 623 stdout, stderr = subprocessio.run_command(
620 624 cmd, env=os.environ.copy())
621 625 extras = json.loads(base64.urlsafe_b64decode(stdout))
622 626 except Exception:
623 627 log.exception('Failed to extract extras info from commit_id')
624 628
625 629 return extras
626 630
627 631
628 632 def svn_post_commit(repo_path, commit_data, env):
629 633 """
630 634 commit_data is path, rev, txn_id
631 635 """
632 636 path, commit_id, txn_id = commit_data
633 637 branches = []
634 638 tags = []
635 639
636 640 if env.get('RC_SCM_DATA'):
637 641 extras = json.loads(env['RC_SCM_DATA'])
638 642 else:
639 643 # fallback method to read from TXN-ID stored data
640 644 extras = _get_extras_from_commit_id(commit_id, path)
641 645 if not extras:
642 646 return 0
643 647
644 648 extras['commit_ids'] = [commit_id]
645 649 extras['txn_id'] = txn_id
646 650 extras['new_refs'] = {
647 651 'branches': branches,
648 652 'bookmarks': [],
649 653 'tags': tags,
650 654 }
651 655
652 656 if 'repo_size' in extras['hooks']:
653 657 try:
654 658 _call_hook('repo_size', extras, SvnMessageWriter())
655 659 except Exception:
656 660 pass
657 661
658 662 return _call_hook('post_push', extras, SvnMessageWriter())
@@ -1,86 +1,86 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2018 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import os
19 19 import shutil
20 20 import tempfile
21 21
22 22 import configobj
23 23
24 24
25 25 class ContextINI(object):
26 26 """
27 27 Allows to create a new test.ini file as a copy of existing one with edited
28 28 data. If existing file is not present, it creates a new one. Example usage::
29 29
30 30 with TestINI('test.ini', [{'section': {'key': 'val'}}]) as new_test_ini_path:
31 31 print 'vcsserver --config=%s' % new_test_ini
32 32 """
33 33
34 34 def __init__(self, ini_file_path, ini_params, new_file_prefix=None,
35 35 destroy=True):
36 36 self.ini_file_path = ini_file_path
37 37 self.ini_params = ini_params
38 38 self.new_path = None
39 39 self.new_path_prefix = new_file_prefix or 'test'
40 40 self.destroy = destroy
41 41
42 42 def __enter__(self):
43 43 _, pref = tempfile.mkstemp()
44 44 loc = tempfile.gettempdir()
45 45 self.new_path = os.path.join(loc, '{}_{}_{}'.format(
46 46 pref, self.new_path_prefix, self.ini_file_path))
47 47
48 48 # copy ini file and modify according to the params, if we re-use a file
49 49 if os.path.isfile(self.ini_file_path):
50 50 shutil.copy(self.ini_file_path, self.new_path)
51 51 else:
52 52 # create new dump file for configObj to write to.
53 53 with open(self.new_path, 'wb'):
54 54 pass
55 55
56 56 config = configobj.ConfigObj(
57 57 self.new_path, file_error=True, write_empty_values=True)
58 58
59 59 for data in self.ini_params:
60 60 section, ini_params = data.items()[0]
61 61 key, val = ini_params.items()[0]
62 62 if section not in config:
63 63 config[section] = {}
64 64 config[section][key] = val
65 65
66 66 config.write()
67 67 return self.new_path
68 68
69 69 def __exit__(self, exc_type, exc_val, exc_tb):
70 70 if self.destroy:
71 71 os.remove(self.new_path)
72 72
73 73
74 74 def no_newline_id_generator(test_name):
75 75 """
76 76 Generates a test name without spaces or newlines characters. Used for
77 77 nicer output of progress of test
78 78 """
79 79 org_name = test_name
80 test_name = test_name\
80 test_name = str(test_name)\
81 81 .replace('\n', '_N') \
82 82 .replace('\r', '_N') \
83 83 .replace('\t', '_T') \
84 84 .replace(' ', '_S')
85 85
86 86 return test_name or 'test-with-empty-name'
@@ -1,241 +1,241 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2018 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import contextlib
19 19 import io
20 20 import threading
21 21 from BaseHTTPServer import BaseHTTPRequestHandler
22 22 from SocketServer import TCPServer
23 23
24 24 import mercurial.ui
25 25 import mock
26 26 import pytest
27 27 import simplejson as json
28 28
29 29 from vcsserver import hooks
30 30
31 31
32 32 def get_hg_ui(extras=None):
33 33 """Create a Config object with a valid RC_SCM_DATA entry."""
34 34 extras = extras or {}
35 35 required_extras = {
36 36 'username': '',
37 37 'repository': '',
38 38 'locked_by': '',
39 39 'scm': '',
40 40 'make_lock': '',
41 41 'action': '',
42 42 'ip': '',
43 43 'hooks_uri': 'fake_hooks_uri',
44 44 }
45 45 required_extras.update(extras)
46 46 hg_ui = mercurial.ui.ui()
47 47 hg_ui.setconfig('rhodecode', 'RC_SCM_DATA', json.dumps(required_extras))
48 48
49 49 return hg_ui
50 50
51 51
52 52 def test_git_pre_receive_is_disabled():
53 53 extras = {'hooks': ['pull']}
54 54 response = hooks.git_pre_receive(None, None,
55 55 {'RC_SCM_DATA': json.dumps(extras)})
56 56
57 57 assert response == 0
58 58
59 59
60 60 def test_git_post_receive_is_disabled():
61 61 extras = {'hooks': ['pull']}
62 62 response = hooks.git_post_receive(None, '',
63 63 {'RC_SCM_DATA': json.dumps(extras)})
64 64
65 65 assert response == 0
66 66
67 67
68 68 def test_git_post_receive_calls_repo_size():
69 69 extras = {'hooks': ['push', 'repo_size']}
70 70 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
71 71 hooks.git_post_receive(
72 72 None, '', {'RC_SCM_DATA': json.dumps(extras)})
73 extras.update({'commit_ids': [],
73 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
74 74 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
75 75 expected_calls = [
76 76 mock.call('repo_size', extras, mock.ANY),
77 77 mock.call('post_push', extras, mock.ANY),
78 78 ]
79 79 assert call_hook_mock.call_args_list == expected_calls
80 80
81 81
82 82 def test_git_post_receive_does_not_call_disabled_repo_size():
83 83 extras = {'hooks': ['push']}
84 84 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
85 85 hooks.git_post_receive(
86 86 None, '', {'RC_SCM_DATA': json.dumps(extras)})
87 extras.update({'commit_ids': [],
87 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
88 88 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
89 89 expected_calls = [
90 90 mock.call('post_push', extras, mock.ANY)
91 91 ]
92 92 assert call_hook_mock.call_args_list == expected_calls
93 93
94 94
95 95 def test_repo_size_exception_does_not_affect_git_post_receive():
96 96 extras = {'hooks': ['push', 'repo_size']}
97 97 status = 0
98 98
99 99 def side_effect(name, *args, **kwargs):
100 100 if name == 'repo_size':
101 101 raise Exception('Fake exception')
102 102 else:
103 103 return status
104 104
105 105 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
106 106 call_hook_mock.side_effect = side_effect
107 107 result = hooks.git_post_receive(
108 108 None, '', {'RC_SCM_DATA': json.dumps(extras)})
109 109 assert result == status
110 110
111 111
112 112 def test_git_pre_pull_is_disabled():
113 113 assert hooks.git_pre_pull({'hooks': ['push']}) == hooks.HookResponse(0, '')
114 114
115 115
116 116 def test_git_post_pull_is_disabled():
117 117 assert (
118 118 hooks.git_post_pull({'hooks': ['push']}) == hooks.HookResponse(0, ''))
119 119
120 120
121 121 class TestGetHooksClient(object):
122 122
123 123 def test_returns_http_client_when_protocol_matches(self):
124 124 hooks_uri = 'localhost:8000'
125 125 result = hooks._get_hooks_client({
126 126 'hooks_uri': hooks_uri,
127 127 'hooks_protocol': 'http'
128 128 })
129 129 assert isinstance(result, hooks.HooksHttpClient)
130 130 assert result.hooks_uri == hooks_uri
131 131
132 132 def test_returns_dummy_client_when_hooks_uri_not_specified(self):
133 133 fake_module = mock.Mock()
134 134 import_patcher = mock.patch.object(
135 135 hooks.importlib, 'import_module', return_value=fake_module)
136 136 fake_module_name = 'fake.module'
137 137 with import_patcher as import_mock:
138 138 result = hooks._get_hooks_client(
139 139 {'hooks_module': fake_module_name})
140 140
141 141 import_mock.assert_called_once_with(fake_module_name)
142 142 assert isinstance(result, hooks.HooksDummyClient)
143 143 assert result._hooks_module == fake_module
144 144
145 145
146 146 class TestHooksHttpClient(object):
147 147 def test_init_sets_hooks_uri(self):
148 148 uri = 'localhost:3000'
149 149 client = hooks.HooksHttpClient(uri)
150 150 assert client.hooks_uri == uri
151 151
152 152 def test_serialize_returns_json_string(self):
153 153 client = hooks.HooksHttpClient('localhost:3000')
154 154 hook_name = 'test'
155 155 extras = {
156 156 'first': 1,
157 157 'second': 'two'
158 158 }
159 159 result = client._serialize(hook_name, extras)
160 160 expected_result = json.dumps({
161 161 'method': hook_name,
162 162 'extras': extras
163 163 })
164 164 assert result == expected_result
165 165
166 166 def test_call_queries_http_server(self, http_mirror):
167 167 client = hooks.HooksHttpClient(http_mirror.uri)
168 168 hook_name = 'test'
169 169 extras = {
170 170 'first': 1,
171 171 'second': 'two'
172 172 }
173 173 result = client(hook_name, extras)
174 174 expected_result = {
175 175 'method': hook_name,
176 176 'extras': extras
177 177 }
178 178 assert result == expected_result
179 179
180 180
181 181 class TestHooksDummyClient(object):
182 182 def test_init_imports_hooks_module(self):
183 183 hooks_module_name = 'rhodecode.fake.module'
184 184 hooks_module = mock.MagicMock()
185 185
186 186 import_patcher = mock.patch.object(
187 187 hooks.importlib, 'import_module', return_value=hooks_module)
188 188 with import_patcher as import_mock:
189 189 client = hooks.HooksDummyClient(hooks_module_name)
190 190 import_mock.assert_called_once_with(hooks_module_name)
191 191 assert client._hooks_module == hooks_module
192 192
193 193 def test_call_returns_hook_result(self):
194 194 hooks_module_name = 'rhodecode.fake.module'
195 195 hooks_module = mock.MagicMock()
196 196 import_patcher = mock.patch.object(
197 197 hooks.importlib, 'import_module', return_value=hooks_module)
198 198 with import_patcher:
199 199 client = hooks.HooksDummyClient(hooks_module_name)
200 200
201 201 result = client('post_push', {})
202 202 hooks_module.Hooks.assert_called_once_with()
203 203 assert result == hooks_module.Hooks().__enter__().post_push()
204 204
205 205
206 206 @pytest.fixture
207 207 def http_mirror(request):
208 208 server = MirrorHttpServer()
209 209 request.addfinalizer(server.stop)
210 210 return server
211 211
212 212
213 213 class MirrorHttpHandler(BaseHTTPRequestHandler):
214 214 def do_POST(self):
215 215 length = int(self.headers['Content-Length'])
216 216 body = self.rfile.read(length).decode('utf-8')
217 217 self.send_response(200)
218 218 self.end_headers()
219 219 self.wfile.write(body)
220 220
221 221
222 222 class MirrorHttpServer(object):
223 223 ip_address = '127.0.0.1'
224 224 port = 0
225 225
226 226 def __init__(self):
227 227 self._daemon = TCPServer((self.ip_address, 0), MirrorHttpHandler)
228 228 _, self.port = self._daemon.server_address
229 229 self._thread = threading.Thread(target=self._daemon.serve_forever)
230 230 self._thread.daemon = True
231 231 self._thread.start()
232 232
233 233 def stop(self):
234 234 self._daemon.shutdown()
235 235 self._thread.join()
236 236 self._daemon = None
237 237 self._thread = None
238 238
239 239 @property
240 240 def uri(self):
241 241 return '{}:{}'.format(self.ip_address, self.port)
General Comments 0
You need to be logged in to leave comments. Login now