##// END OF EJS Templates
feat(celery-hooks): added HooksCeleryClient, removed support od HooksDummyClient, updated tests. Fixes: RCCE-55
ilin.s -
r1204:79967b24 default
parent child Browse files
Show More
@@ -1,57 +1,77 b''
1 1 # deps, generated via pipdeptree --exclude setuptools,wheel,pipdeptree,pip -f | tr '[:upper:]' '[:lower:]'
2 2
3 3 async-timeout==4.0.3
4 4 atomicwrites==1.4.1
5 celery==5.3.4
6 billiard==4.1.0
7 click==8.1.3
8 click-didyoumean==0.3.0
9 click==8.1.3
10 click-plugins==1.1.1
11 click==8.1.3
12 click-repl==0.2.0
13 click==8.1.3
14 prompt-toolkit==3.0.38
15 wcwidth==0.2.6
16 six==1.16.0
17 kombu==5.3.2
18 amqp==5.1.1
19 vine==5.1.0
20 vine==5.1.0
21 python-dateutil==2.8.2
22 six==1.16.0
23 tzdata==2023.4
24 vine==5.1.0
5 25 contextlib2==21.6.0
6 26 cov-core==1.15.0
7 27 coverage==7.2.3
8 28 diskcache==5.6.3
9 29 dogpile.cache==1.3.0
10 30 decorator==5.1.1
11 31 stevedore==5.1.0
12 32 pbr==5.11.1
13 33 dulwich==0.21.6
14 34 urllib3==1.26.14
15 35 gunicorn==21.2.0
16 36 packaging==23.1
17 37 hg-evolve==11.0.2
18 38 importlib-metadata==6.0.0
19 39 zipp==3.15.0
20 40 mercurial==6.3.3
21 41 mock==5.0.2
22 42 more-itertools==9.1.0
23 43 msgpack==1.0.7
24 44 orjson==3.9.13
25 45 psutil==5.9.8
26 46 py==1.11.0
27 47 pygit2==1.13.3
28 48 cffi==1.16.0
29 49 pycparser==2.21
30 50 pygments==2.15.1
31 51 pyparsing==3.1.1
32 52 pyramid==2.0.2
33 53 hupper==1.12
34 54 plaster==1.1.2
35 55 plaster-pastedeploy==1.0.1
36 56 pastedeploy==3.1.0
37 57 plaster==1.1.2
38 58 translationstring==1.4
39 59 venusian==3.0.0
40 60 webob==1.8.7
41 61 zope.deprecation==5.0.0
42 62 zope.interface==6.1.0
43 63 redis==5.0.1
44 64 async-timeout==4.0.3
45 65 repoze.lru==0.7
46 66 scandir==1.10.0
47 67 setproctitle==1.3.3
48 68 subvertpy==0.11.0
49 69 waitress==3.0.0
50 70 wcwidth==0.2.6
51 71
52 72
53 73 ## test related requirements
54 74 #-r requirements_test.txt
55 75
56 76 ## uncomment to add the debug libraries
57 77 #-r requirements_debug.txt
@@ -1,785 +1,796 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import io
19 19 import os
20 20 import sys
21 21 import logging
22 22 import collections
23 23 import importlib
24 24 import base64
25 25 import msgpack
26 26 import dataclasses
27 27 import pygit2
28 28
29 29 import http.client
30 from celery import Celery
30 31
31 32
32 33 import mercurial.scmutil
33 34 import mercurial.node
34 35
35 36 from vcsserver.lib.rc_json import json
36 37 from vcsserver import exceptions, subprocessio, settings
37 38 from vcsserver.str_utils import ascii_str, safe_str
38 39 from vcsserver.remote.git_remote import Repository
39 40
41 celery_app = Celery()
40 42 log = logging.getLogger(__name__)
41 43
42 44
43 45 class HooksHttpClient:
44 46 proto = 'msgpack.v1'
45 47 connection = None
46 48
47 49 def __init__(self, hooks_uri):
48 50 self.hooks_uri = hooks_uri
49 51
50 52 def __repr__(self):
51 53 return f'{self.__class__}(hook_uri={self.hooks_uri}, proto={self.proto})'
52 54
53 55 def __call__(self, method, extras):
54 56 connection = http.client.HTTPConnection(self.hooks_uri)
55 57 # binary msgpack body
56 58 headers, body = self._serialize(method, extras)
57 59 log.debug('Doing a new hooks call using HTTPConnection to %s', self.hooks_uri)
58 60
59 61 try:
60 62 try:
61 63 connection.request('POST', '/', body, headers)
62 64 except Exception as error:
63 65 log.error('Hooks calling Connection failed on %s, org error: %s', connection.__dict__, error)
64 66 raise
65 67
66 68 response = connection.getresponse()
67 69 try:
68 70 return msgpack.load(response)
69 71 except Exception:
70 72 response_data = response.read()
71 73 log.exception('Failed to decode hook response json data. '
72 74 'response_code:%s, raw_data:%s',
73 75 response.status, response_data)
74 76 raise
75 77 finally:
76 78 connection.close()
77 79
78 80 @classmethod
79 81 def _serialize(cls, hook_name, extras):
80 82 data = {
81 83 'method': hook_name,
82 84 'extras': extras
83 85 }
84 86 headers = {
85 87 "rc-hooks-protocol": cls.proto,
86 88 "Connection": "keep-alive"
87 89 }
88 90 return headers, msgpack.packb(data)
89 91
90 92
91 class HooksDummyClient:
92 def __init__(self, hooks_module):
93 log.debug('HooksDummyClient import: %s', hooks_module)
94 self._hooks_module = importlib.import_module(hooks_module)
93 class HooksCeleryClient:
94 TASK_TIMEOUT = 60 # time in seconds
95 95
96 def __call__(self, hook_name, extras):
97 with self._hooks_module.Hooks() as hooks:
98 return getattr(hooks, hook_name)(extras)
96 def __init__(self, queue, backend):
97 celery_app.config_from_object({'broker_url': queue, 'result_backend': backend,
98 'broker_connection_retry_on_startup': True,
99 'task_serializer': 'msgpack',
100 'accept_content': ['json', 'msgpack'],
101 'result_serializer': 'msgpack',
102 'result_accept_content': ['json', 'msgpack']
103 })
104 self.celery_app = celery_app
105
106 def __call__(self, method, extras):
107 inquired_task = self.celery_app.signature(
108 f'rhodecode.lib.celerylib.tasks.{method}'
109 )
110 return inquired_task.delay(extras).get(timeout=self.TASK_TIMEOUT)
99 111
100 112
101 113 class HooksShadowRepoClient:
102 114
103 115 def __call__(self, hook_name, extras):
104 116 return {'output': '', 'status': 0}
105 117
106 118
107 119 class RemoteMessageWriter:
108 120 """Writer base class."""
109 121 def write(self, message):
110 122 raise NotImplementedError()
111 123
112 124
113 125 class HgMessageWriter(RemoteMessageWriter):
114 126 """Writer that knows how to send messages to mercurial clients."""
115 127
116 128 def __init__(self, ui):
117 129 self.ui = ui
118 130
119 131 def write(self, message: str):
120 132 # TODO: Check why the quiet flag is set by default.
121 133 old = self.ui.quiet
122 134 self.ui.quiet = False
123 135 self.ui.status(message.encode('utf-8'))
124 136 self.ui.quiet = old
125 137
126 138
127 139 class GitMessageWriter(RemoteMessageWriter):
128 140 """Writer that knows how to send messages to git clients."""
129 141
130 142 def __init__(self, stdout=None):
131 143 self.stdout = stdout or sys.stdout
132 144
133 145 def write(self, message: str):
134 146 self.stdout.write(message)
135 147
136 148
137 149 class SvnMessageWriter(RemoteMessageWriter):
138 150 """Writer that knows how to send messages to svn clients."""
139 151
140 152 def __init__(self, stderr=None):
141 153 # SVN needs data sent to stderr for back-to-client messaging
142 154 self.stderr = stderr or sys.stderr
143 155
144 156 def write(self, message):
145 157 self.stderr.write(message.encode('utf-8'))
146 158
147 159
148 160 def _handle_exception(result):
149 161 exception_class = result.get('exception')
150 162 exception_traceback = result.get('exception_traceback')
151 163 log.debug('Handling hook-call exception: %s', exception_class)
152 164
153 165 if exception_traceback:
154 166 log.error('Got traceback from remote call:%s', exception_traceback)
155 167
156 168 if exception_class == 'HTTPLockedRC':
157 169 raise exceptions.RepositoryLockedException()(*result['exception_args'])
158 170 elif exception_class == 'HTTPBranchProtected':
159 171 raise exceptions.RepositoryBranchProtectedException()(*result['exception_args'])
160 172 elif exception_class == 'RepositoryError':
161 173 raise exceptions.VcsException()(*result['exception_args'])
162 174 elif exception_class:
163 175 raise Exception(
164 176 f"""Got remote exception "{exception_class}" with args "{result['exception_args']}" """
165 177 )
166 178
167 179
168 180 def _get_hooks_client(extras):
169 181 hooks_uri = extras.get('hooks_uri')
182 task_queue = extras.get('task_queue')
183 task_backend = extras.get('task_backend')
170 184 is_shadow_repo = extras.get('is_shadow_repo')
171 185
172 186 if hooks_uri:
173 return HooksHttpClient(extras['hooks_uri'])
187 return HooksHttpClient(hooks_uri)
188 elif task_queue and task_backend:
189 return HooksCeleryClient(task_queue, task_backend)
174 190 elif is_shadow_repo:
175 191 return HooksShadowRepoClient()
176 192 else:
177 try:
178 import_module = extras['hooks_module']
179 except KeyError:
180 log.error('Failed to get "hooks_module" from extras: %s', extras)
181 raise
182 return HooksDummyClient(import_module)
193 raise Exception("Hooks client not found!")
183 194
184 195
185 196 def _call_hook(hook_name, extras, writer):
186 197 hooks_client = _get_hooks_client(extras)
187 198 log.debug('Hooks, using client:%s', hooks_client)
188 199 result = hooks_client(hook_name, extras)
189 200 log.debug('Hooks got result: %s', result)
190 201 _handle_exception(result)
191 202 writer.write(result['output'])
192 203
193 204 return result['status']
194 205
195 206
196 207 def _extras_from_ui(ui):
197 208 hook_data = ui.config(b'rhodecode', b'RC_SCM_DATA')
198 209 if not hook_data:
199 210 # maybe it's inside environ ?
200 211 env_hook_data = os.environ.get('RC_SCM_DATA')
201 212 if env_hook_data:
202 213 hook_data = env_hook_data
203 214
204 215 extras = {}
205 216 if hook_data:
206 217 extras = json.loads(hook_data)
207 218 return extras
208 219
209 220
210 221 def _rev_range_hash(repo, node, check_heads=False):
211 222 from vcsserver.hgcompat import get_ctx
212 223
213 224 commits = []
214 225 revs = []
215 226 start = get_ctx(repo, node).rev()
216 227 end = len(repo)
217 228 for rev in range(start, end):
218 229 revs.append(rev)
219 230 ctx = get_ctx(repo, rev)
220 231 commit_id = ascii_str(mercurial.node.hex(ctx.node()))
221 232 branch = safe_str(ctx.branch())
222 233 commits.append((commit_id, branch))
223 234
224 235 parent_heads = []
225 236 if check_heads:
226 237 parent_heads = _check_heads(repo, start, end, revs)
227 238 return commits, parent_heads
228 239
229 240
230 241 def _check_heads(repo, start, end, commits):
231 242 from vcsserver.hgcompat import get_ctx
232 243 changelog = repo.changelog
233 244 parents = set()
234 245
235 246 for new_rev in commits:
236 247 for p in changelog.parentrevs(new_rev):
237 248 if p == mercurial.node.nullrev:
238 249 continue
239 250 if p < start:
240 251 parents.add(p)
241 252
242 253 for p in parents:
243 254 branch = get_ctx(repo, p).branch()
244 255 # The heads descending from that parent, on the same branch
245 256 parent_heads = {p}
246 257 reachable = {p}
247 258 for x in range(p + 1, end):
248 259 if get_ctx(repo, x).branch() != branch:
249 260 continue
250 261 for pp in changelog.parentrevs(x):
251 262 if pp in reachable:
252 263 reachable.add(x)
253 264 parent_heads.discard(pp)
254 265 parent_heads.add(x)
255 266 # More than one head? Suggest merging
256 267 if len(parent_heads) > 1:
257 268 return list(parent_heads)
258 269
259 270 return []
260 271
261 272
262 273 def _get_git_env():
263 274 env = {}
264 275 for k, v in os.environ.items():
265 276 if k.startswith('GIT'):
266 277 env[k] = v
267 278
268 279 # serialized version
269 280 return [(k, v) for k, v in env.items()]
270 281
271 282
272 283 def _get_hg_env(old_rev, new_rev, txnid, repo_path):
273 284 env = {}
274 285 for k, v in os.environ.items():
275 286 if k.startswith('HG'):
276 287 env[k] = v
277 288
278 289 env['HG_NODE'] = old_rev
279 290 env['HG_NODE_LAST'] = new_rev
280 291 env['HG_TXNID'] = txnid
281 292 env['HG_PENDING'] = repo_path
282 293
283 294 return [(k, v) for k, v in env.items()]
284 295
285 296
286 297 def repo_size(ui, repo, **kwargs):
287 298 extras = _extras_from_ui(ui)
288 299 return _call_hook('repo_size', extras, HgMessageWriter(ui))
289 300
290 301
291 302 def pre_pull(ui, repo, **kwargs):
292 303 extras = _extras_from_ui(ui)
293 304 return _call_hook('pre_pull', extras, HgMessageWriter(ui))
294 305
295 306
296 307 def pre_pull_ssh(ui, repo, **kwargs):
297 308 extras = _extras_from_ui(ui)
298 309 if extras and extras.get('SSH'):
299 310 return pre_pull(ui, repo, **kwargs)
300 311 return 0
301 312
302 313
303 314 def post_pull(ui, repo, **kwargs):
304 315 extras = _extras_from_ui(ui)
305 316 return _call_hook('post_pull', extras, HgMessageWriter(ui))
306 317
307 318
308 319 def post_pull_ssh(ui, repo, **kwargs):
309 320 extras = _extras_from_ui(ui)
310 321 if extras and extras.get('SSH'):
311 322 return post_pull(ui, repo, **kwargs)
312 323 return 0
313 324
314 325
315 326 def pre_push(ui, repo, node=None, **kwargs):
316 327 """
317 328 Mercurial pre_push hook
318 329 """
319 330 extras = _extras_from_ui(ui)
320 331 detect_force_push = extras.get('detect_force_push')
321 332
322 333 rev_data = []
323 334 hook_type: str = safe_str(kwargs.get('hooktype'))
324 335
325 336 if node and hook_type == 'pretxnchangegroup':
326 337 branches = collections.defaultdict(list)
327 338 commits, _heads = _rev_range_hash(repo, node, check_heads=detect_force_push)
328 339 for commit_id, branch in commits:
329 340 branches[branch].append(commit_id)
330 341
331 342 for branch, commits in branches.items():
332 343 old_rev = ascii_str(kwargs.get('node_last')) or commits[0]
333 344 rev_data.append({
334 345 'total_commits': len(commits),
335 346 'old_rev': old_rev,
336 347 'new_rev': commits[-1],
337 348 'ref': '',
338 349 'type': 'branch',
339 350 'name': branch,
340 351 })
341 352
342 353 for push_ref in rev_data:
343 354 push_ref['multiple_heads'] = _heads
344 355
345 356 repo_path = os.path.join(
346 357 extras.get('repo_store', ''), extras.get('repository', ''))
347 358 push_ref['hg_env'] = _get_hg_env(
348 359 old_rev=push_ref['old_rev'],
349 360 new_rev=push_ref['new_rev'], txnid=ascii_str(kwargs.get('txnid')),
350 361 repo_path=repo_path)
351 362
352 363 extras['hook_type'] = hook_type or 'pre_push'
353 364 extras['commit_ids'] = rev_data
354 365
355 366 return _call_hook('pre_push', extras, HgMessageWriter(ui))
356 367
357 368
358 369 def pre_push_ssh(ui, repo, node=None, **kwargs):
359 370 extras = _extras_from_ui(ui)
360 371 if extras.get('SSH'):
361 372 return pre_push(ui, repo, node, **kwargs)
362 373
363 374 return 0
364 375
365 376
366 377 def pre_push_ssh_auth(ui, repo, node=None, **kwargs):
367 378 """
368 379 Mercurial pre_push hook for SSH
369 380 """
370 381 extras = _extras_from_ui(ui)
371 382 if extras.get('SSH'):
372 383 permission = extras['SSH_PERMISSIONS']
373 384
374 385 if 'repository.write' == permission or 'repository.admin' == permission:
375 386 return 0
376 387
377 388 # non-zero ret code
378 389 return 1
379 390
380 391 return 0
381 392
382 393
383 394 def post_push(ui, repo, node, **kwargs):
384 395 """
385 396 Mercurial post_push hook
386 397 """
387 398 extras = _extras_from_ui(ui)
388 399
389 400 commit_ids = []
390 401 branches = []
391 402 bookmarks = []
392 403 tags = []
393 404 hook_type: str = safe_str(kwargs.get('hooktype'))
394 405
395 406 commits, _heads = _rev_range_hash(repo, node)
396 407 for commit_id, branch in commits:
397 408 commit_ids.append(commit_id)
398 409 if branch not in branches:
399 410 branches.append(branch)
400 411
401 412 if hasattr(ui, '_rc_pushkey_bookmarks'):
402 413 bookmarks = ui._rc_pushkey_bookmarks
403 414
404 415 extras['hook_type'] = hook_type or 'post_push'
405 416 extras['commit_ids'] = commit_ids
406 417
407 418 extras['new_refs'] = {
408 419 'branches': branches,
409 420 'bookmarks': bookmarks,
410 421 'tags': tags
411 422 }
412 423
413 424 return _call_hook('post_push', extras, HgMessageWriter(ui))
414 425
415 426
416 427 def post_push_ssh(ui, repo, node, **kwargs):
417 428 """
418 429 Mercurial post_push hook for SSH
419 430 """
420 431 if _extras_from_ui(ui).get('SSH'):
421 432 return post_push(ui, repo, node, **kwargs)
422 433 return 0
423 434
424 435
425 436 def key_push(ui, repo, **kwargs):
426 437 from vcsserver.hgcompat import get_ctx
427 438
428 439 if kwargs['new'] != b'0' and kwargs['namespace'] == b'bookmarks':
429 440 # store new bookmarks in our UI object propagated later to post_push
430 441 ui._rc_pushkey_bookmarks = get_ctx(repo, kwargs['key']).bookmarks()
431 442 return
432 443
433 444
434 445 # backward compat
435 446 log_pull_action = post_pull
436 447
437 448 # backward compat
438 449 log_push_action = post_push
439 450
440 451
441 452 def handle_git_pre_receive(unused_repo_path, unused_revs, unused_env):
442 453 """
443 454 Old hook name: keep here for backward compatibility.
444 455
445 456 This is only required when the installed git hooks are not upgraded.
446 457 """
447 458 pass
448 459
449 460
450 461 def handle_git_post_receive(unused_repo_path, unused_revs, unused_env):
451 462 """
452 463 Old hook name: keep here for backward compatibility.
453 464
454 465 This is only required when the installed git hooks are not upgraded.
455 466 """
456 467 pass
457 468
458 469
459 470 @dataclasses.dataclass
460 471 class HookResponse:
461 472 status: int
462 473 output: str
463 474
464 475
465 476 def git_pre_pull(extras) -> HookResponse:
466 477 """
467 478 Pre pull hook.
468 479
469 480 :param extras: dictionary containing the keys defined in simplevcs
470 481 :type extras: dict
471 482
472 483 :return: status code of the hook. 0 for success.
473 484 :rtype: int
474 485 """
475 486
476 487 if 'pull' not in extras['hooks']:
477 488 return HookResponse(0, '')
478 489
479 490 stdout = io.StringIO()
480 491 try:
481 492 status_code = _call_hook('pre_pull', extras, GitMessageWriter(stdout))
482 493
483 494 except Exception as error:
484 495 log.exception('Failed to call pre_pull hook')
485 496 status_code = 128
486 497 stdout.write(f'ERROR: {error}\n')
487 498
488 499 return HookResponse(status_code, stdout.getvalue())
489 500
490 501
491 502 def git_post_pull(extras) -> HookResponse:
492 503 """
493 504 Post pull hook.
494 505
495 506 :param extras: dictionary containing the keys defined in simplevcs
496 507 :type extras: dict
497 508
498 509 :return: status code of the hook. 0 for success.
499 510 :rtype: int
500 511 """
501 512 if 'pull' not in extras['hooks']:
502 513 return HookResponse(0, '')
503 514
504 515 stdout = io.StringIO()
505 516 try:
506 517 status = _call_hook('post_pull', extras, GitMessageWriter(stdout))
507 518 except Exception as error:
508 519 status = 128
509 520 stdout.write(f'ERROR: {error}\n')
510 521
511 522 return HookResponse(status, stdout.getvalue())
512 523
513 524
514 525 def _parse_git_ref_lines(revision_lines):
515 526 rev_data = []
516 527 for revision_line in revision_lines or []:
517 528 old_rev, new_rev, ref = revision_line.strip().split(' ')
518 529 ref_data = ref.split('/', 2)
519 530 if ref_data[1] in ('tags', 'heads'):
520 531 rev_data.append({
521 532 # NOTE(marcink):
522 533 # we're unable to tell total_commits for git at this point
523 534 # but we set the variable for consistency with GIT
524 535 'total_commits': -1,
525 536 'old_rev': old_rev,
526 537 'new_rev': new_rev,
527 538 'ref': ref,
528 539 'type': ref_data[1],
529 540 'name': ref_data[2],
530 541 })
531 542 return rev_data
532 543
533 544
534 545 def git_pre_receive(unused_repo_path, revision_lines, env) -> int:
535 546 """
536 547 Pre push hook.
537 548
538 549 :return: status code of the hook. 0 for success.
539 550 """
540 551 extras = json.loads(env['RC_SCM_DATA'])
541 552 rev_data = _parse_git_ref_lines(revision_lines)
542 553 if 'push' not in extras['hooks']:
543 554 return 0
544 555 empty_commit_id = '0' * 40
545 556
546 557 detect_force_push = extras.get('detect_force_push')
547 558
548 559 for push_ref in rev_data:
549 560 # store our git-env which holds the temp store
550 561 push_ref['git_env'] = _get_git_env()
551 562 push_ref['pruned_sha'] = ''
552 563 if not detect_force_push:
553 564 # don't check for forced-push when we don't need to
554 565 continue
555 566
556 567 type_ = push_ref['type']
557 568 new_branch = push_ref['old_rev'] == empty_commit_id
558 569 delete_branch = push_ref['new_rev'] == empty_commit_id
559 570 if type_ == 'heads' and not (new_branch or delete_branch):
560 571 old_rev = push_ref['old_rev']
561 572 new_rev = push_ref['new_rev']
562 573 cmd = [settings.GIT_EXECUTABLE, 'rev-list', old_rev, f'^{new_rev}']
563 574 stdout, stderr = subprocessio.run_command(
564 575 cmd, env=os.environ.copy())
565 576 # means we're having some non-reachable objects, this forced push was used
566 577 if stdout:
567 578 push_ref['pruned_sha'] = stdout.splitlines()
568 579
569 580 extras['hook_type'] = 'pre_receive'
570 581 extras['commit_ids'] = rev_data
571 582
572 583 stdout = sys.stdout
573 584 status_code = _call_hook('pre_push', extras, GitMessageWriter(stdout))
574 585
575 586 return status_code
576 587
577 588
578 589 def git_post_receive(unused_repo_path, revision_lines, env) -> int:
579 590 """
580 591 Post push hook.
581 592
582 593 :return: status code of the hook. 0 for success.
583 594 """
584 595 extras = json.loads(env['RC_SCM_DATA'])
585 596 if 'push' not in extras['hooks']:
586 597 return 0
587 598
588 599 rev_data = _parse_git_ref_lines(revision_lines)
589 600
590 601 git_revs = []
591 602
592 603 # N.B.(skreft): it is ok to just call git, as git before calling a
593 604 # subcommand sets the PATH environment variable so that it point to the
594 605 # correct version of the git executable.
595 606 empty_commit_id = '0' * 40
596 607 branches = []
597 608 tags = []
598 609 for push_ref in rev_data:
599 610 type_ = push_ref['type']
600 611
601 612 if type_ == 'heads':
602 613 # starting new branch case
603 614 if push_ref['old_rev'] == empty_commit_id:
604 615 push_ref_name = push_ref['name']
605 616
606 617 if push_ref_name not in branches:
607 618 branches.append(push_ref_name)
608 619
609 620 need_head_set = ''
610 621 with Repository(os.getcwd()) as repo:
611 622 try:
612 623 repo.head
613 624 except pygit2.GitError:
614 625 need_head_set = f'refs/heads/{push_ref_name}'
615 626
616 627 if need_head_set:
617 628 repo.set_head(need_head_set)
618 629 print(f"Setting default branch to {push_ref_name}")
619 630
620 631 cmd = [settings.GIT_EXECUTABLE, 'for-each-ref', '--format=%(refname)', 'refs/heads/*']
621 632 stdout, stderr = subprocessio.run_command(
622 633 cmd, env=os.environ.copy())
623 634 heads = safe_str(stdout)
624 635 heads = heads.replace(push_ref['ref'], '')
625 636 heads = ' '.join(head for head
626 637 in heads.splitlines() if head) or '.'
627 638 cmd = [settings.GIT_EXECUTABLE, 'log', '--reverse',
628 639 '--pretty=format:%H', '--', push_ref['new_rev'],
629 640 '--not', heads]
630 641 stdout, stderr = subprocessio.run_command(
631 642 cmd, env=os.environ.copy())
632 643 git_revs.extend(list(map(ascii_str, stdout.splitlines())))
633 644
634 645 # delete branch case
635 646 elif push_ref['new_rev'] == empty_commit_id:
636 647 git_revs.append(f'delete_branch=>{push_ref["name"]}')
637 648 else:
638 649 if push_ref['name'] not in branches:
639 650 branches.append(push_ref['name'])
640 651
641 652 cmd = [settings.GIT_EXECUTABLE, 'log',
642 653 f'{push_ref["old_rev"]}..{push_ref["new_rev"]}',
643 654 '--reverse', '--pretty=format:%H']
644 655 stdout, stderr = subprocessio.run_command(
645 656 cmd, env=os.environ.copy())
646 657 # we get bytes from stdout, we need str to be consistent
647 658 log_revs = list(map(ascii_str, stdout.splitlines()))
648 659 git_revs.extend(log_revs)
649 660
650 661 # Pure pygit2 impl. but still 2-3x slower :/
651 662 # results = []
652 663 #
653 664 # with Repository(os.getcwd()) as repo:
654 665 # repo_new_rev = repo[push_ref['new_rev']]
655 666 # repo_old_rev = repo[push_ref['old_rev']]
656 667 # walker = repo.walk(repo_new_rev.id, pygit2.GIT_SORT_TOPOLOGICAL)
657 668 #
658 669 # for commit in walker:
659 670 # if commit.id == repo_old_rev.id:
660 671 # break
661 672 # results.append(commit.id.hex)
662 673 # # reverse the order, can't use GIT_SORT_REVERSE
663 674 # log_revs = results[::-1]
664 675
665 676 elif type_ == 'tags':
666 677 if push_ref['name'] not in tags:
667 678 tags.append(push_ref['name'])
668 679 git_revs.append(f'tag=>{push_ref["name"]}')
669 680
670 681 extras['hook_type'] = 'post_receive'
671 682 extras['commit_ids'] = git_revs
672 683 extras['new_refs'] = {
673 684 'branches': branches,
674 685 'bookmarks': [],
675 686 'tags': tags,
676 687 }
677 688
678 689 stdout = sys.stdout
679 690
680 691 if 'repo_size' in extras['hooks']:
681 692 try:
682 693 _call_hook('repo_size', extras, GitMessageWriter(stdout))
683 694 except Exception:
684 695 pass
685 696
686 697 status_code = _call_hook('post_push', extras, GitMessageWriter(stdout))
687 698 return status_code
688 699
689 700
690 701 def _get_extras_from_txn_id(path, txn_id):
691 702 extras = {}
692 703 try:
693 704 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
694 705 '-t', txn_id,
695 706 '--revprop', path, 'rc-scm-extras']
696 707 stdout, stderr = subprocessio.run_command(
697 708 cmd, env=os.environ.copy())
698 709 extras = json.loads(base64.urlsafe_b64decode(stdout))
699 710 except Exception:
700 711 log.exception('Failed to extract extras info from txn_id')
701 712
702 713 return extras
703 714
704 715
705 716 def _get_extras_from_commit_id(commit_id, path):
706 717 extras = {}
707 718 try:
708 719 cmd = [settings.SVNLOOK_EXECUTABLE, 'pget',
709 720 '-r', commit_id,
710 721 '--revprop', path, 'rc-scm-extras']
711 722 stdout, stderr = subprocessio.run_command(
712 723 cmd, env=os.environ.copy())
713 724 extras = json.loads(base64.urlsafe_b64decode(stdout))
714 725 except Exception:
715 726 log.exception('Failed to extract extras info from commit_id')
716 727
717 728 return extras
718 729
719 730
720 731 def svn_pre_commit(repo_path, commit_data, env):
721 732 path, txn_id = commit_data
722 733 branches = []
723 734 tags = []
724 735
725 736 if env.get('RC_SCM_DATA'):
726 737 extras = json.loads(env['RC_SCM_DATA'])
727 738 else:
728 739 # fallback method to read from TXN-ID stored data
729 740 extras = _get_extras_from_txn_id(path, txn_id)
730 741 if not extras:
731 742 return 0
732 743
733 744 extras['hook_type'] = 'pre_commit'
734 745 extras['commit_ids'] = [txn_id]
735 746 extras['txn_id'] = txn_id
736 747 extras['new_refs'] = {
737 748 'total_commits': 1,
738 749 'branches': branches,
739 750 'bookmarks': [],
740 751 'tags': tags,
741 752 }
742 753
743 754 return _call_hook('pre_push', extras, SvnMessageWriter())
744 755
745 756
746 757 def svn_post_commit(repo_path, commit_data, env):
747 758 """
748 759 commit_data is path, rev, txn_id
749 760 """
750 761 if len(commit_data) == 3:
751 762 path, commit_id, txn_id = commit_data
752 763 elif len(commit_data) == 2:
753 764 log.error('Failed to extract txn_id from commit_data using legacy method. '
754 765 'Some functionality might be limited')
755 766 path, commit_id = commit_data
756 767 txn_id = None
757 768
758 769 branches = []
759 770 tags = []
760 771
761 772 if env.get('RC_SCM_DATA'):
762 773 extras = json.loads(env['RC_SCM_DATA'])
763 774 else:
764 775 # fallback method to read from TXN-ID stored data
765 776 extras = _get_extras_from_commit_id(commit_id, path)
766 777 if not extras:
767 778 return 0
768 779
769 780 extras['hook_type'] = 'post_commit'
770 781 extras['commit_ids'] = [commit_id]
771 782 extras['txn_id'] = txn_id
772 783 extras['new_refs'] = {
773 784 'branches': branches,
774 785 'bookmarks': [],
775 786 'tags': tags,
776 787 'total_commits': 1,
777 788 }
778 789
779 790 if 'repo_size' in extras['hooks']:
780 791 try:
781 792 _call_hook('repo_size', extras, SvnMessageWriter())
782 793 except Exception:
783 794 pass
784 795
785 796 return _call_hook('post_push', extras, SvnMessageWriter())
@@ -1,286 +1,257 b''
1 1 # RhodeCode VCSServer provides access to different vcs backends via network.
2 2 # Copyright (C) 2014-2023 RhodeCode GmbH
3 3 #
4 4 # This program is free software; you can redistribute it and/or modify
5 5 # it under the terms of the GNU General Public License as published by
6 6 # the Free Software Foundation; either version 3 of the License, or
7 7 # (at your option) any later version.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU General Public License
15 15 # along with this program; if not, write to the Free Software Foundation,
16 16 # Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 17
18 18 import threading
19 19 import msgpack
20 20
21 21 from http.server import BaseHTTPRequestHandler
22 22 from socketserver import TCPServer
23 23
24 24 import mercurial.ui
25 25 import mock
26 26 import pytest
27 27
28 28 from vcsserver.hooks import HooksHttpClient
29 29 from vcsserver.lib.rc_json import json
30 30 from vcsserver import hooks
31 31
32 32
33 33 def get_hg_ui(extras=None):
34 34 """Create a Config object with a valid RC_SCM_DATA entry."""
35 35 extras = extras or {}
36 36 required_extras = {
37 37 'username': '',
38 38 'repository': '',
39 39 'locked_by': '',
40 40 'scm': '',
41 41 'make_lock': '',
42 42 'action': '',
43 43 'ip': '',
44 44 'hooks_uri': 'fake_hooks_uri',
45 45 }
46 46 required_extras.update(extras)
47 47 hg_ui = mercurial.ui.ui()
48 48 hg_ui.setconfig(b'rhodecode', b'RC_SCM_DATA', json.dumps(required_extras))
49 49
50 50 return hg_ui
51 51
52 52
53 53 def test_git_pre_receive_is_disabled():
54 54 extras = {'hooks': ['pull']}
55 55 response = hooks.git_pre_receive(None, None,
56 56 {'RC_SCM_DATA': json.dumps(extras)})
57 57
58 58 assert response == 0
59 59
60 60
61 61 def test_git_post_receive_is_disabled():
62 62 extras = {'hooks': ['pull']}
63 63 response = hooks.git_post_receive(None, '',
64 64 {'RC_SCM_DATA': json.dumps(extras)})
65 65
66 66 assert response == 0
67 67
68 68
69 69 def test_git_post_receive_calls_repo_size():
70 70 extras = {'hooks': ['push', 'repo_size']}
71 71
72 72 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
73 73 hooks.git_post_receive(
74 74 None, '', {'RC_SCM_DATA': json.dumps(extras)})
75 75 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
76 76 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
77 77 expected_calls = [
78 78 mock.call('repo_size', extras, mock.ANY),
79 79 mock.call('post_push', extras, mock.ANY),
80 80 ]
81 81 assert call_hook_mock.call_args_list == expected_calls
82 82
83 83
84 84 def test_git_post_receive_does_not_call_disabled_repo_size():
85 85 extras = {'hooks': ['push']}
86 86
87 87 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
88 88 hooks.git_post_receive(
89 89 None, '', {'RC_SCM_DATA': json.dumps(extras)})
90 90 extras.update({'commit_ids': [], 'hook_type': 'post_receive',
91 91 'new_refs': {'bookmarks': [], 'branches': [], 'tags': []}})
92 92 expected_calls = [
93 93 mock.call('post_push', extras, mock.ANY)
94 94 ]
95 95 assert call_hook_mock.call_args_list == expected_calls
96 96
97 97
98 98 def test_repo_size_exception_does_not_affect_git_post_receive():
99 99 extras = {'hooks': ['push', 'repo_size']}
100 100 status = 0
101 101
102 102 def side_effect(name, *args, **kwargs):
103 103 if name == 'repo_size':
104 104 raise Exception('Fake exception')
105 105 else:
106 106 return status
107 107
108 108 with mock.patch.object(hooks, '_call_hook') as call_hook_mock:
109 109 call_hook_mock.side_effect = side_effect
110 110 result = hooks.git_post_receive(
111 111 None, '', {'RC_SCM_DATA': json.dumps(extras)})
112 112 assert result == status
113 113
114 114
115 115 def test_git_pre_pull_is_disabled():
116 116 assert hooks.git_pre_pull({'hooks': ['push']}) == hooks.HookResponse(0, '')
117 117
118 118
119 119 def test_git_post_pull_is_disabled():
120 120 assert (
121 121 hooks.git_post_pull({'hooks': ['push']}) == hooks.HookResponse(0, ''))
122 122
123 123
124 124 class TestGetHooksClient:
125 125
126 126 def test_returns_http_client_when_protocol_matches(self):
127 127 hooks_uri = 'localhost:8000'
128 128 result = hooks._get_hooks_client({
129 129 'hooks_uri': hooks_uri,
130 130 'hooks_protocol': 'http'
131 131 })
132 132 assert isinstance(result, hooks.HooksHttpClient)
133 133 assert result.hooks_uri == hooks_uri
134 134
135 def test_returns_dummy_client_when_hooks_uri_not_specified(self):
136 fake_module = mock.Mock()
137 import_patcher = mock.patch.object(
138 hooks.importlib, 'import_module', return_value=fake_module)
139 fake_module_name = 'fake.module'
140 with import_patcher as import_mock:
141 result = hooks._get_hooks_client(
142 {'hooks_module': fake_module_name})
143
144 import_mock.assert_called_once_with(fake_module_name)
145 assert isinstance(result, hooks.HooksDummyClient)
146 assert result._hooks_module == fake_module
135 def test_return_celery_client_when_queue_and_backend_provided(self):
136 task_queue = 'redis://task_queue:0'
137 task_backend = task_queue
138 result = hooks._get_hooks_client({
139 'task_queue': task_queue,
140 'task_backend': task_backend
141 })
142 assert isinstance(result, hooks.HooksCeleryClient)
147 143
148 144
149 145 class TestHooksHttpClient:
150 146 def test_init_sets_hooks_uri(self):
151 147 uri = 'localhost:3000'
152 148 client = hooks.HooksHttpClient(uri)
153 149 assert client.hooks_uri == uri
154 150
155 151 def test_serialize_returns_serialized_string(self):
156 152 client = hooks.HooksHttpClient('localhost:3000')
157 153 hook_name = 'test'
158 154 extras = {
159 155 'first': 1,
160 156 'second': 'two'
161 157 }
162 158 hooks_proto, result = client._serialize(hook_name, extras)
163 159 expected_result = msgpack.packb({
164 160 'method': hook_name,
165 161 'extras': extras,
166 162 })
167 163 assert hooks_proto == {'rc-hooks-protocol': 'msgpack.v1', 'Connection': 'keep-alive'}
168 164 assert result == expected_result
169 165
170 166 def test_call_queries_http_server(self, http_mirror):
171 167 client = hooks.HooksHttpClient(http_mirror.uri)
172 168 hook_name = 'test'
173 169 extras = {
174 170 'first': 1,
175 171 'second': 'two'
176 172 }
177 173 result = client(hook_name, extras)
178 174 expected_result = msgpack.unpackb(msgpack.packb({
179 175 'method': hook_name,
180 176 'extras': extras
181 177 }), raw=False)
182 178 assert result == expected_result
183 179
184 180
185 class TestHooksDummyClient:
186 def test_init_imports_hooks_module(self):
187 hooks_module_name = 'rhodecode.fake.module'
188 hooks_module = mock.MagicMock()
189
190 import_patcher = mock.patch.object(
191 hooks.importlib, 'import_module', return_value=hooks_module)
192 with import_patcher as import_mock:
193 client = hooks.HooksDummyClient(hooks_module_name)
194 import_mock.assert_called_once_with(hooks_module_name)
195 assert client._hooks_module == hooks_module
196
197 def test_call_returns_hook_result(self):
198 hooks_module_name = 'rhodecode.fake.module'
199 hooks_module = mock.MagicMock()
200 import_patcher = mock.patch.object(
201 hooks.importlib, 'import_module', return_value=hooks_module)
202 with import_patcher:
203 client = hooks.HooksDummyClient(hooks_module_name)
204
205 result = client('post_push', {})
206 hooks_module.Hooks.assert_called_once_with()
207 assert result == hooks_module.Hooks().__enter__().post_push()
208
209
210 181 @pytest.fixture
211 182 def http_mirror(request):
212 183 server = MirrorHttpServer()
213 184 request.addfinalizer(server.stop)
214 185 return server
215 186
216 187
217 188 class MirrorHttpHandler(BaseHTTPRequestHandler):
218 189
219 190 def do_POST(self):
220 191 length = int(self.headers['Content-Length'])
221 192 body = self.rfile.read(length)
222 193 self.send_response(200)
223 194 self.end_headers()
224 195 self.wfile.write(body)
225 196
226 197
227 198 class MirrorHttpServer:
228 199 ip_address = '127.0.0.1'
229 200 port = 0
230 201
231 202 def __init__(self):
232 203 self._daemon = TCPServer((self.ip_address, 0), MirrorHttpHandler)
233 204 _, self.port = self._daemon.server_address
234 205 self._thread = threading.Thread(target=self._daemon.serve_forever)
235 206 self._thread.daemon = True
236 207 self._thread.start()
237 208
238 209 def stop(self):
239 210 self._daemon.shutdown()
240 211 self._thread.join()
241 212 self._daemon = None
242 213 self._thread = None
243 214
244 215 @property
245 216 def uri(self):
246 217 return '{}:{}'.format(self.ip_address, self.port)
247 218
248 219
249 220 def test_hooks_http_client_init():
250 221 hooks_uri = 'http://localhost:8000'
251 222 client = HooksHttpClient(hooks_uri)
252 223 assert client.hooks_uri == hooks_uri
253 224
254 225
255 226 def test_hooks_http_client_call():
256 227 hooks_uri = 'http://localhost:8000'
257 228
258 229 method = 'test_method'
259 230 extras = {'key': 'value'}
260 231
261 232 with \
262 233 mock.patch('http.client.HTTPConnection') as mock_connection,\
263 234 mock.patch('msgpack.load') as mock_load:
264 235
265 236 client = HooksHttpClient(hooks_uri)
266 237
267 238 mock_load.return_value = {'result': 'success'}
268 239 response = mock.MagicMock()
269 240 response.status = 200
270 241 mock_connection.request.side_effect = None
271 242 mock_connection.getresponse.return_value = response
272 243
273 244 result = client(method, extras)
274 245
275 246 mock_connection.assert_called_with(hooks_uri)
276 247 mock_connection.return_value.request.assert_called_once()
277 248 assert result == {'result': 'success'}
278 249
279 250
280 251 def test_hooks_http_client_serialize():
281 252 method = 'test_method'
282 253 extras = {'key': 'value'}
283 254 headers, body = HooksHttpClient._serialize(method, extras)
284 255
285 256 assert headers == {'rc-hooks-protocol': HooksHttpClient.proto, 'Connection': 'keep-alive'}
286 257 assert msgpack.unpackb(body) == {'method': method, 'extras': extras}
General Comments 0
You need to be logged in to leave comments. Login now