code: code cleanups, use is None instead of == None.
marcink
r3231:462664f1 default
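
The commit swaps equality checks against None for identity checks. A minimal sketch (not part of the commit) of why `is None` is the safer test: `==` dispatches to a user-defined __eq__, while `is` compares object identity and cannot be overridden.

    class AlwaysEqual(object):
        # hypothetical class whose __eq__ answers True for everything
        def __eq__(self, other):
            return True

    obj = AlwaysEqual()
    print obj == None    # True  -- __eq__ gets the last word
    print obj is None    # False -- identity check, cannot be fooled
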
@@ -1,781 +1,780 @@
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 import logging
23 23 import collections
24 24
25 25 import datetime
26 26 import formencode
27 27 import formencode.htmlfill
28 28
29 29 import rhodecode
30 30 from pyramid.view import view_config
31 31 from pyramid.httpexceptions import HTTPFound, HTTPNotFound
32 32 from pyramid.renderers import render
33 33 from pyramid.response import Response
34 34
35 35 from rhodecode.apps._base import BaseAppView
36 36 from rhodecode.apps.admin.navigation import navigation_list
37 37 from rhodecode.apps.svn_support.config_keys import generate_config
38 38 from rhodecode.lib import helpers as h
39 39 from rhodecode.lib.auth import (
40 40 LoginRequired, HasPermissionAllDecorator, CSRFRequired)
41 41 from rhodecode.lib.celerylib import tasks, run_task
42 42 from rhodecode.lib.utils import repo2db_mapper
43 43 from rhodecode.lib.utils2 import str2bool, safe_unicode, AttributeDict
44 44 from rhodecode.lib.index import searcher_from_config
45 45
46 46 from rhodecode.model.db import RhodeCodeUi, Repository
47 47 from rhodecode.model.forms import (ApplicationSettingsForm,
48 48 ApplicationUiSettingsForm, ApplicationVisualisationForm,
49 49 LabsSettingsForm, IssueTrackerPatternsForm)
50 50 from rhodecode.model.repo_group import RepoGroupModel
51 51
52 52 from rhodecode.model.scm import ScmModel
53 53 from rhodecode.model.notification import EmailNotificationModel
54 54 from rhodecode.model.meta import Session
55 55 from rhodecode.model.settings import (
56 56 IssueTrackerSettingsModel, VcsSettingsModel, SettingNotFound,
57 57 SettingsModel)
58 58
59 59
60 60 log = logging.getLogger(__name__)
61 61
62 62
63 63 class AdminSettingsView(BaseAppView):
64 64
65 65 def load_default_context(self):
66 66 c = self._get_local_tmpl_context()
67 67 c.labs_active = str2bool(
68 68 rhodecode.CONFIG.get('labs_settings_active', 'true'))
69 69 c.navlist = navigation_list(self.request)
70 70
71 71 return c
72 72
73 73 @classmethod
74 74 def _get_ui_settings(cls):
75 75 ret = RhodeCodeUi.query().all()
76 76
77 77 if not ret:
78 78 raise Exception('Could not get application ui settings !')
79 79 settings = {}
80 80 for each in ret:
81 81 k = each.ui_key
82 82 v = each.ui_value
83 83 if k == '/':
84 84 k = 'root_path'
85 85
86 86 if k in ['push_ssl', 'publish', 'enabled']:
87 87 v = str2bool(v)
88 88
89 89 if k.find('.') != -1:
90 90 k = k.replace('.', '_')
91 91
92 92 if each.ui_section in ['hooks', 'extensions']:
93 93 v = each.ui_active
94 94
95 95 settings[each.ui_section + '_' + k] = v
96 96 return settings
97 97
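
_get_ui_settings flattens each RhodeCodeUi row into a '<section>_<key>' entry, renaming the root path and coercing booleans along the way. A standalone sketch of that mapping, using hypothetical rows:

    # hypothetical (ui_section, ui_key, ui_value) rows
    rows = [('web', 'push_ssl', 'true'), ('paths', '/', '/srv/repos')]
    settings = {}
    for section, key, value in rows:
        if key == '/':
            key = 'root_path'
        if key in ['push_ssl', 'publish', 'enabled']:
            value = value == 'true'  # stands in for str2bool
        settings[section + '_' + key.replace('.', '_')] = value
    # settings == {'web_push_ssl': True, 'paths_root_path': '/srv/repos'}
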
98 98 @classmethod
99 99 def _form_defaults(cls):
100 100 defaults = SettingsModel().get_all_settings()
101 101 defaults.update(cls._get_ui_settings())
102 102
103 103 defaults.update({
104 104 'new_svn_branch': '',
105 105 'new_svn_tag': '',
106 106 })
107 107 return defaults
108 108
109 109 @LoginRequired()
110 110 @HasPermissionAllDecorator('hg.admin')
111 111 @view_config(
112 112 route_name='admin_settings_vcs', request_method='GET',
113 113 renderer='rhodecode:templates/admin/settings/settings.mako')
114 114 def settings_vcs(self):
115 115 c = self.load_default_context()
116 116 c.active = 'vcs'
117 117 model = VcsSettingsModel()
118 118 c.svn_branch_patterns = model.get_global_svn_branch_patterns()
119 119 c.svn_tag_patterns = model.get_global_svn_tag_patterns()
120 120
121 121 settings = self.request.registry.settings
122 122 c.svn_proxy_generate_config = settings[generate_config]
123 123
124 124 defaults = self._form_defaults()
125 125
126 126 model.create_largeobjects_dirs_if_needed(defaults['paths_root_path'])
127 127
128 128 data = render('rhodecode:templates/admin/settings/settings.mako',
129 129 self._get_template_context(c), self.request)
130 130 html = formencode.htmlfill.render(
131 131 data,
132 132 defaults=defaults,
133 133 encoding="UTF-8",
134 134 force_defaults=False
135 135 )
136 136 return Response(html)
137 137
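
The GET views below all follow the same pattern: render the Mako template, then have formencode.htmlfill fill the current values into the form fields. A minimal standalone sketch against a hypothetical form:

    import formencode.htmlfill

    form = '<form><input type="text" name="rhodecode_title" /></form>'
    html = formencode.htmlfill.render(
        form, defaults={'rhodecode_title': 'My Site'},
        encoding="UTF-8", force_defaults=False)
    # the input now carries value="My Site"; force_defaults=False leaves
    # fields without a matching default untouched
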
138 138 @LoginRequired()
139 139 @HasPermissionAllDecorator('hg.admin')
140 140 @CSRFRequired()
141 141 @view_config(
142 142 route_name='admin_settings_vcs_update', request_method='POST',
143 143 renderer='rhodecode:templates/admin/settings/settings.mako')
144 144 def settings_vcs_update(self):
145 145 _ = self.request.translate
146 146 c = self.load_default_context()
147 147 c.active = 'vcs'
148 148
149 149 model = VcsSettingsModel()
150 150 c.svn_branch_patterns = model.get_global_svn_branch_patterns()
151 151 c.svn_tag_patterns = model.get_global_svn_tag_patterns()
152 152
153 153 settings = self.request.registry.settings
154 154 c.svn_proxy_generate_config = settings[generate_config]
155 155
156 156 application_form = ApplicationUiSettingsForm(self.request.translate)()
157 157
158 158 try:
159 159 form_result = application_form.to_python(dict(self.request.POST))
160 160 except formencode.Invalid as errors:
161 161 h.flash(
162 162 _("Some form inputs contain invalid data."),
163 163 category='error')
164 164 data = render('rhodecode:templates/admin/settings/settings.mako',
165 165 self._get_template_context(c), self.request)
166 166 html = formencode.htmlfill.render(
167 167 data,
168 168 defaults=errors.value,
169 169 errors=errors.error_dict or {},
170 170 prefix_error=False,
171 171 encoding="UTF-8",
172 172 force_defaults=False
173 173 )
174 174 return Response(html)
175 175
176 176 try:
177 177 if c.visual.allow_repo_location_change:
178 model.update_global_path_setting(
179 form_result['paths_root_path'])
178 model.update_global_path_setting(form_result['paths_root_path'])
180 179
181 180 model.update_global_ssl_setting(form_result['web_push_ssl'])
182 181 model.update_global_hook_settings(form_result)
183 182
184 183 model.create_or_update_global_svn_settings(form_result)
185 184 model.create_or_update_global_hg_settings(form_result)
186 185 model.create_or_update_global_git_settings(form_result)
187 186 model.create_or_update_global_pr_settings(form_result)
188 187 except Exception:
189 188 log.exception("Exception while updating settings")
190 189 h.flash(_('Error occurred during updating '
191 190 'application settings'), category='error')
192 191 else:
193 192 Session().commit()
194 193 h.flash(_('Updated VCS settings'), category='success')
195 194 raise HTTPFound(h.route_path('admin_settings_vcs'))
196 195
197 196 data = render('rhodecode:templates/admin/settings/settings.mako',
198 197 self._get_template_context(c), self.request)
199 198 html = formencode.htmlfill.render(
200 199 data,
201 200 defaults=self._form_defaults(),
202 201 encoding="UTF-8",
203 202 force_defaults=False
204 203 )
205 204 return Response(html)
206 205
207 206 @LoginRequired()
208 207 @HasPermissionAllDecorator('hg.admin')
209 208 @CSRFRequired()
210 209 @view_config(
211 210 route_name='admin_settings_vcs_svn_pattern_delete', request_method='POST',
212 211 renderer='json_ext', xhr=True)
213 212 def settings_vcs_delete_svn_pattern(self):
214 213 delete_pattern_id = self.request.POST.get('delete_svn_pattern')
215 214 model = VcsSettingsModel()
216 215 try:
217 216 model.delete_global_svn_pattern(delete_pattern_id)
218 217 except SettingNotFound:
219 218 log.exception(
220 219 'Failed to delete svn_pattern with id %s', delete_pattern_id)
221 220 raise HTTPNotFound()
222 221
223 222 Session().commit()
224 223 return True
225 224
226 225 @LoginRequired()
227 226 @HasPermissionAllDecorator('hg.admin')
228 227 @view_config(
229 228 route_name='admin_settings_mapping', request_method='GET',
230 229 renderer='rhodecode:templates/admin/settings/settings.mako')
231 230 def settings_mapping(self):
232 231 c = self.load_default_context()
233 232 c.active = 'mapping'
234 233
235 234 data = render('rhodecode:templates/admin/settings/settings.mako',
236 235 self._get_template_context(c), self.request)
237 236 html = formencode.htmlfill.render(
238 237 data,
239 238 defaults=self._form_defaults(),
240 239 encoding="UTF-8",
241 240 force_defaults=False
242 241 )
243 242 return Response(html)
244 243
245 244 @LoginRequired()
246 245 @HasPermissionAllDecorator('hg.admin')
247 246 @CSRFRequired()
248 247 @view_config(
249 248 route_name='admin_settings_mapping_update', request_method='POST',
250 249 renderer='rhodecode:templates/admin/settings/settings.mako')
251 250 def settings_mapping_update(self):
252 251 _ = self.request.translate
253 252 c = self.load_default_context()
254 253 c.active = 'mapping'
255 254 rm_obsolete = self.request.POST.get('destroy', False)
256 255 invalidate_cache = self.request.POST.get('invalidate', False)
257 256 log.debug(
258 257 'rescanning repo location with destroy obsolete=%s', rm_obsolete)
259 258
260 259 if invalidate_cache:
261 260 log.debug('invalidating all repositories cache')
262 261 for repo in Repository.get_all():
263 262 ScmModel().mark_for_invalidation(repo.repo_name, delete=True)
264 263
265 264 filesystem_repos = ScmModel().repo_scan()
266 265 added, removed = repo2db_mapper(filesystem_repos, rm_obsolete)
267 266 _repr = lambda l: ', '.join(map(safe_unicode, l)) or '-'
268 267 h.flash(_('Repositories successfully '
269 268 'rescanned added: %s ; removed: %s') %
270 269 (_repr(added), _repr(removed)),
271 270 category='success')
272 271 raise HTTPFound(h.route_path('admin_settings_mapping'))
273 272
274 273 @LoginRequired()
275 274 @HasPermissionAllDecorator('hg.admin')
276 275 @view_config(
277 276 route_name='admin_settings', request_method='GET',
278 277 renderer='rhodecode:templates/admin/settings/settings.mako')
279 278 @view_config(
280 279 route_name='admin_settings_global', request_method='GET',
281 280 renderer='rhodecode:templates/admin/settings/settings.mako')
282 281 def settings_global(self):
283 282 c = self.load_default_context()
284 283 c.active = 'global'
285 284 c.personal_repo_group_default_pattern = RepoGroupModel()\
286 285 .get_personal_group_name_pattern()
287 286
288 287 data = render('rhodecode:templates/admin/settings/settings.mako',
289 288 self._get_template_context(c), self.request)
290 289 html = formencode.htmlfill.render(
291 290 data,
292 291 defaults=self._form_defaults(),
293 292 encoding="UTF-8",
294 293 force_defaults=False
295 294 )
296 295 return Response(html)
297 296
298 297 @LoginRequired()
299 298 @HasPermissionAllDecorator('hg.admin')
300 299 @CSRFRequired()
301 300 @view_config(
302 301 route_name='admin_settings_update', request_method='POST',
303 302 renderer='rhodecode:templates/admin/settings/settings.mako')
304 303 @view_config(
305 304 route_name='admin_settings_global_update', request_method='POST',
306 305 renderer='rhodecode:templates/admin/settings/settings.mako')
307 306 def settings_global_update(self):
308 307 _ = self.request.translate
309 308 c = self.load_default_context()
310 309 c.active = 'global'
311 310 c.personal_repo_group_default_pattern = RepoGroupModel()\
312 311 .get_personal_group_name_pattern()
313 312 application_form = ApplicationSettingsForm(self.request.translate)()
314 313 try:
315 314 form_result = application_form.to_python(dict(self.request.POST))
316 315 except formencode.Invalid as errors:
317 316 h.flash(
318 317 _("Some form inputs contain invalid data."),
319 318 category='error')
320 319 data = render('rhodecode:templates/admin/settings/settings.mako',
321 320 self._get_template_context(c), self.request)
322 321 html = formencode.htmlfill.render(
323 322 data,
324 323 defaults=errors.value,
325 324 errors=errors.error_dict or {},
326 325 prefix_error=False,
327 326 encoding="UTF-8",
328 327 force_defaults=False
329 328 )
330 329 return Response(html)
331 330
332 331 settings = [
333 332 ('title', 'rhodecode_title', 'unicode'),
334 333 ('realm', 'rhodecode_realm', 'unicode'),
335 334 ('pre_code', 'rhodecode_pre_code', 'unicode'),
336 335 ('post_code', 'rhodecode_post_code', 'unicode'),
337 336 ('captcha_public_key', 'rhodecode_captcha_public_key', 'unicode'),
338 337 ('captcha_private_key', 'rhodecode_captcha_private_key', 'unicode'),
339 338 ('create_personal_repo_group', 'rhodecode_create_personal_repo_group', 'bool'),
340 339 ('personal_repo_group_pattern', 'rhodecode_personal_repo_group_pattern', 'unicode'),
341 340 ]
342 341 try:
343 342 for setting, form_key, type_ in settings:
344 343 sett = SettingsModel().create_or_update_setting(
345 344 setting, form_result[form_key], type_)
346 345 Session().add(sett)
347 346
348 347 Session().commit()
349 348 SettingsModel().invalidate_settings_cache()
350 349 h.flash(_('Updated application settings'), category='success')
351 350 except Exception:
352 351 log.exception("Exception while updating application settings")
353 352 h.flash(
354 353 _('Error occurred during updating application settings'),
355 354 category='error')
356 355
357 356 raise HTTPFound(h.route_path('admin_settings_global'))
358 357
359 358 @LoginRequired()
360 359 @HasPermissionAllDecorator('hg.admin')
361 360 @view_config(
362 361 route_name='admin_settings_visual', request_method='GET',
363 362 renderer='rhodecode:templates/admin/settings/settings.mako')
364 363 def settings_visual(self):
365 364 c = self.load_default_context()
366 365 c.active = 'visual'
367 366
368 367 data = render('rhodecode:templates/admin/settings/settings.mako',
369 368 self._get_template_context(c), self.request)
370 369 html = formencode.htmlfill.render(
371 370 data,
372 371 defaults=self._form_defaults(),
373 372 encoding="UTF-8",
374 373 force_defaults=False
375 374 )
376 375 return Response(html)
377 376
378 377 @LoginRequired()
379 378 @HasPermissionAllDecorator('hg.admin')
380 379 @CSRFRequired()
381 380 @view_config(
382 381 route_name='admin_settings_visual_update', request_method='POST',
383 382 renderer='rhodecode:templates/admin/settings/settings.mako')
384 383 def settings_visual_update(self):
385 384 _ = self.request.translate
386 385 c = self.load_default_context()
387 386 c.active = 'visual'
388 387 application_form = ApplicationVisualisationForm(self.request.translate)()
389 388 try:
390 389 form_result = application_form.to_python(dict(self.request.POST))
391 390 except formencode.Invalid as errors:
392 391 h.flash(
393 392 _("Some form inputs contain invalid data."),
394 393 category='error')
395 394 data = render('rhodecode:templates/admin/settings/settings.mako',
396 395 self._get_template_context(c), self.request)
397 396 html = formencode.htmlfill.render(
398 397 data,
399 398 defaults=errors.value,
400 399 errors=errors.error_dict or {},
401 400 prefix_error=False,
402 401 encoding="UTF-8",
403 402 force_defaults=False
404 403 )
405 404 return Response(html)
406 405
407 406 try:
408 407 settings = [
409 408 ('show_public_icon', 'rhodecode_show_public_icon', 'bool'),
410 409 ('show_private_icon', 'rhodecode_show_private_icon', 'bool'),
411 410 ('stylify_metatags', 'rhodecode_stylify_metatags', 'bool'),
412 411 ('repository_fields', 'rhodecode_repository_fields', 'bool'),
413 412 ('dashboard_items', 'rhodecode_dashboard_items', 'int'),
414 413 ('admin_grid_items', 'rhodecode_admin_grid_items', 'int'),
415 414 ('show_version', 'rhodecode_show_version', 'bool'),
416 415 ('use_gravatar', 'rhodecode_use_gravatar', 'bool'),
417 416 ('markup_renderer', 'rhodecode_markup_renderer', 'unicode'),
418 417 ('gravatar_url', 'rhodecode_gravatar_url', 'unicode'),
419 418 ('clone_uri_tmpl', 'rhodecode_clone_uri_tmpl', 'unicode'),
420 419 ('clone_uri_ssh_tmpl', 'rhodecode_clone_uri_ssh_tmpl', 'unicode'),
421 420 ('support_url', 'rhodecode_support_url', 'unicode'),
422 421 ('show_revision_number', 'rhodecode_show_revision_number', 'bool'),
423 422 ('show_sha_length', 'rhodecode_show_sha_length', 'int'),
424 423 ]
425 424 for setting, form_key, type_ in settings:
426 425 sett = SettingsModel().create_or_update_setting(
427 426 setting, form_result[form_key], type_)
428 427 Session().add(sett)
429 428
430 429 Session().commit()
431 430 SettingsModel().invalidate_settings_cache()
432 431 h.flash(_('Updated visualisation settings'), category='success')
433 432 except Exception:
434 433 log.exception("Exception updating visualization settings")
435 434 h.flash(_('Error occurred during updating '
436 435 'visualisation settings'),
437 436 category='error')
438 437
439 438 raise HTTPFound(h.route_path('admin_settings_visual'))
440 439
441 440 @LoginRequired()
442 441 @HasPermissionAllDecorator('hg.admin')
443 442 @view_config(
444 443 route_name='admin_settings_issuetracker', request_method='GET',
445 444 renderer='rhodecode:templates/admin/settings/settings.mako')
446 445 def settings_issuetracker(self):
447 446 c = self.load_default_context()
448 447 c.active = 'issuetracker'
449 448 defaults = SettingsModel().get_all_settings()
450 449
451 450 entry_key = 'rhodecode_issuetracker_pat_'
452 451
453 452 c.issuetracker_entries = {}
454 453 for k, v in defaults.items():
455 454 if k.startswith(entry_key):
456 455 uid = k[len(entry_key):]
457 456 c.issuetracker_entries[uid] = None
458 457
459 458 for uid in c.issuetracker_entries:
460 459 c.issuetracker_entries[uid] = AttributeDict({
461 460 'pat': defaults.get('rhodecode_issuetracker_pat_' + uid),
462 461 'url': defaults.get('rhodecode_issuetracker_url_' + uid),
463 462 'pref': defaults.get('rhodecode_issuetracker_pref_' + uid),
464 463 'desc': defaults.get('rhodecode_issuetracker_desc_' + uid),
465 464 })
466 465
467 466 return self._get_template_context(c)
468 467
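
settings_issuetracker discovers tracker entries by scanning for the pattern-key prefix and stripping it off to obtain a uid. A sketch with hypothetical stored settings:

    defaults = {
        'rhodecode_issuetracker_pat_ab12': r'#(\d+)',
        'rhodecode_issuetracker_url_ab12': 'https://tracker/${issue_id}',
    }
    entry_key = 'rhodecode_issuetracker_pat_'
    uids = [k[len(entry_key):] for k in defaults if k.startswith(entry_key)]
    # uids == ['ab12']; each uid then gathers its pat/url/pref/desc values
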
469 468 @LoginRequired()
470 469 @HasPermissionAllDecorator('hg.admin')
471 470 @CSRFRequired()
472 471 @view_config(
473 472 route_name='admin_settings_issuetracker_test', request_method='POST',
474 473 renderer='string', xhr=True)
475 474 def settings_issuetracker_test(self):
476 475 return h.urlify_commit_message(
477 476 self.request.POST.get('test_text', ''),
478 477 'repo_group/test_repo1')
479 478
480 479 @LoginRequired()
481 480 @HasPermissionAllDecorator('hg.admin')
482 481 @CSRFRequired()
483 482 @view_config(
484 483 route_name='admin_settings_issuetracker_update', request_method='POST',
485 484 renderer='rhodecode:templates/admin/settings/settings.mako')
486 485 def settings_issuetracker_update(self):
487 486 _ = self.request.translate
488 487 self.load_default_context()
489 488 settings_model = IssueTrackerSettingsModel()
490 489
491 490 try:
492 491 form = IssueTrackerPatternsForm(self.request.translate)()
493 492 data = form.to_python(self.request.POST)
494 493 except formencode.Invalid as errors:
495 494 log.exception('Failed to add new pattern')
496 495 error = errors
497 496 h.flash(_('Invalid issue tracker pattern: {}'.format(error)),
498 497 category='error')
499 498 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
500 499
501 500 if data:
502 501 for uid in data.get('delete_patterns', []):
503 502 settings_model.delete_entries(uid)
504 503
505 504 for pattern in data.get('patterns', []):
506 505 for setting, value, type_ in pattern:
507 506 sett = settings_model.create_or_update_setting(
508 507 setting, value, type_)
509 508 Session().add(sett)
510 509
511 510 Session().commit()
512 511
513 512 SettingsModel().invalidate_settings_cache()
514 513 h.flash(_('Updated issue tracker entries'), category='success')
515 514 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
516 515
517 516 @LoginRequired()
518 517 @HasPermissionAllDecorator('hg.admin')
519 518 @CSRFRequired()
520 519 @view_config(
521 520 route_name='admin_settings_issuetracker_delete', request_method='POST',
522 521 renderer='rhodecode:templates/admin/settings/settings.mako')
523 522 def settings_issuetracker_delete(self):
524 523 _ = self.request.translate
525 524 self.load_default_context()
526 525 uid = self.request.POST.get('uid')
527 526 try:
528 527 IssueTrackerSettingsModel().delete_entries(uid)
529 528 except Exception:
530 529 log.exception('Failed to delete issue tracker setting %s', uid)
531 530 raise HTTPNotFound()
532 531 h.flash(_('Removed issue tracker entry'), category='success')
533 532 raise HTTPFound(h.route_path('admin_settings_issuetracker'))
534 533
535 534 @LoginRequired()
536 535 @HasPermissionAllDecorator('hg.admin')
537 536 @view_config(
538 537 route_name='admin_settings_email', request_method='GET',
539 538 renderer='rhodecode:templates/admin/settings/settings.mako')
540 539 def settings_email(self):
541 540 c = self.load_default_context()
542 541 c.active = 'email'
543 542 c.rhodecode_ini = rhodecode.CONFIG
544 543
545 544 data = render('rhodecode:templates/admin/settings/settings.mako',
546 545 self._get_template_context(c), self.request)
547 546 html = formencode.htmlfill.render(
548 547 data,
549 548 defaults=self._form_defaults(),
550 549 encoding="UTF-8",
551 550 force_defaults=False
552 551 )
553 552 return Response(html)
554 553
555 554 @LoginRequired()
556 555 @HasPermissionAllDecorator('hg.admin')
557 556 @CSRFRequired()
558 557 @view_config(
559 558 route_name='admin_settings_email_update', request_method='POST',
560 559 renderer='rhodecode:templates/admin/settings/settings.mako')
561 560 def settings_email_update(self):
562 561 _ = self.request.translate
563 562 c = self.load_default_context()
564 563 c.active = 'email'
565 564
566 565 test_email = self.request.POST.get('test_email')
567 566
568 567 if not test_email:
569 568 h.flash(_('Please enter email address'), category='error')
570 569 raise HTTPFound(h.route_path('admin_settings_email'))
571 570
572 571 email_kwargs = {
573 572 'date': datetime.datetime.now(),
574 573 'user': c.rhodecode_user,
575 574 'rhodecode_version': c.rhodecode_version
576 575 }
577 576
578 577 (subject, headers, email_body,
579 578 email_body_plaintext) = EmailNotificationModel().render_email(
580 579 EmailNotificationModel.TYPE_EMAIL_TEST, **email_kwargs)
581 580
582 581 recipients = [test_email] if test_email else None
583 582
584 583 run_task(tasks.send_email, recipients, subject,
585 584 email_body_plaintext, email_body)
586 585
587 586 h.flash(_('Send email task created'), category='success')
588 587 raise HTTPFound(h.route_path('admin_settings_email'))
589 588
590 589 @LoginRequired()
591 590 @HasPermissionAllDecorator('hg.admin')
592 591 @view_config(
593 592 route_name='admin_settings_hooks', request_method='GET',
594 593 renderer='rhodecode:templates/admin/settings/settings.mako')
595 594 def settings_hooks(self):
596 595 c = self.load_default_context()
597 596 c.active = 'hooks'
598 597
599 598 model = SettingsModel()
600 599 c.hooks = model.get_builtin_hooks()
601 600 c.custom_hooks = model.get_custom_hooks()
602 601
603 602 data = render('rhodecode:templates/admin/settings/settings.mako',
604 603 self._get_template_context(c), self.request)
605 604 html = formencode.htmlfill.render(
606 605 data,
607 606 defaults=self._form_defaults(),
608 607 encoding="UTF-8",
609 608 force_defaults=False
610 609 )
611 610 return Response(html)
612 611
613 612 @LoginRequired()
614 613 @HasPermissionAllDecorator('hg.admin')
615 614 @CSRFRequired()
616 615 @view_config(
617 616 route_name='admin_settings_hooks_update', request_method='POST',
618 617 renderer='rhodecode:templates/admin/settings/settings.mako')
619 618 @view_config(
620 619 route_name='admin_settings_hooks_delete', request_method='POST',
621 620 renderer='rhodecode:templates/admin/settings/settings.mako')
622 621 def settings_hooks_update(self):
623 622 _ = self.request.translate
624 623 c = self.load_default_context()
625 624 c.active = 'hooks'
626 625 if c.visual.allow_custom_hooks_settings:
627 626 ui_key = self.request.POST.get('new_hook_ui_key')
628 627 ui_value = self.request.POST.get('new_hook_ui_value')
629 628
630 629 hook_id = self.request.POST.get('hook_id')
631 630 new_hook = False
632 631
633 632 model = SettingsModel()
634 633 try:
635 634 if ui_value and ui_key:
636 635 model.create_or_update_hook(ui_key, ui_value)
637 636 h.flash(_('Added new hook'), category='success')
638 637 new_hook = True
639 638 elif hook_id:
640 639 RhodeCodeUi.delete(hook_id)
641 640 Session().commit()
642 641
643 642 # check for edits
644 643 update = False
645 644 _d = self.request.POST.dict_of_lists()
646 645 for k, v in zip(_d.get('hook_ui_key', []),
647 646 _d.get('hook_ui_value_new', [])):
648 647 model.create_or_update_hook(k, v)
649 648 update = True
650 649
651 650 if update and not new_hook:
652 651 h.flash(_('Updated hooks'), category='success')
653 652 Session().commit()
654 653 except Exception:
655 654 log.exception("Exception during hook creation")
656 655 h.flash(_('Error occurred during hook creation'),
657 656 category='error')
658 657
659 658 raise HTTPFound(h.route_path('admin_settings_hooks'))
660 659
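
The edit loop in settings_hooks_update pairs the posted hook_ui_key and hook_ui_value_new columns by position. A sketch of that zip with hypothetical POST data:

    _d = {'hook_ui_key': ['outgoing.pull_logger'],
          'hook_ui_value_new': ['python:myhooks.log_pull']}
    for k, v in zip(_d.get('hook_ui_key', []),
                    _d.get('hook_ui_value_new', [])):
        print k, '->', v  # each pair is fed to create_or_update_hook(k, v)
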
661 660 @LoginRequired()
662 661 @HasPermissionAllDecorator('hg.admin')
663 662 @view_config(
664 663 route_name='admin_settings_search', request_method='GET',
665 664 renderer='rhodecode:templates/admin/settings/settings.mako')
666 665 def settings_search(self):
667 666 c = self.load_default_context()
668 667 c.active = 'search'
669 668
670 669 searcher = searcher_from_config(self.request.registry.settings)
671 670 c.statistics = searcher.statistics(self.request.translate)
672 671
673 672 return self._get_template_context(c)
674 673
675 674 @LoginRequired()
676 675 @HasPermissionAllDecorator('hg.admin')
677 676 @view_config(
678 677 route_name='admin_settings_automation', request_method='GET',
679 678 renderer='rhodecode:templates/admin/settings/settings.mako')
680 679 def settings_automation(self):
681 680 c = self.load_default_context()
682 681 c.active = 'automation'
683 682
684 683 return self._get_template_context(c)
685 684
686 685 @LoginRequired()
687 686 @HasPermissionAllDecorator('hg.admin')
688 687 @view_config(
689 688 route_name='admin_settings_labs', request_method='GET',
690 689 renderer='rhodecode:templates/admin/settings/settings.mako')
691 690 def settings_labs(self):
692 691 c = self.load_default_context()
693 692 if not c.labs_active:
694 693 raise HTTPFound(h.route_path('admin_settings'))
695 694
696 695 c.active = 'labs'
697 696 c.lab_settings = _LAB_SETTINGS
698 697
699 698 data = render('rhodecode:templates/admin/settings/settings.mako',
700 699 self._get_template_context(c), self.request)
701 700 html = formencode.htmlfill.render(
702 701 data,
703 702 defaults=self._form_defaults(),
704 703 encoding="UTF-8",
705 704 force_defaults=False
706 705 )
707 706 return Response(html)
708 707
709 708 @LoginRequired()
710 709 @HasPermissionAllDecorator('hg.admin')
711 710 @CSRFRequired()
712 711 @view_config(
713 712 route_name='admin_settings_labs_update', request_method='POST',
714 713 renderer='rhodecode:templates/admin/settings/settings.mako')
715 714 def settings_labs_update(self):
716 715 _ = self.request.translate
717 716 c = self.load_default_context()
718 717 c.active = 'labs'
719 718
720 719 application_form = LabsSettingsForm(self.request.translate)()
721 720 try:
722 721 form_result = application_form.to_python(dict(self.request.POST))
723 722 except formencode.Invalid as errors:
724 723 h.flash(
725 724 _("Some form inputs contain invalid data."),
726 725 category='error')
727 726 data = render('rhodecode:templates/admin/settings/settings.mako',
728 727 self._get_template_context(c), self.request)
729 728 html = formencode.htmlfill.render(
730 729 data,
731 730 defaults=errors.value,
732 731 errors=errors.error_dict or {},
733 732 prefix_error=False,
734 733 encoding="UTF-8",
735 734 force_defaults=False
736 735 )
737 736 return Response(html)
738 737
739 738 try:
740 739 session = Session()
741 740 for setting in _LAB_SETTINGS:
742 741 setting_name = setting.key[len('rhodecode_'):]
743 742 sett = SettingsModel().create_or_update_setting(
744 743 setting_name, form_result[setting.key], setting.type)
745 744 session.add(sett)
746 745
747 746 except Exception:
748 747 log.exception('Exception while updating lab settings')
749 748 h.flash(_('Error occurred during updating labs settings'),
750 749 category='error')
751 750 else:
752 751 Session().commit()
753 752 SettingsModel().invalidate_settings_cache()
754 753 h.flash(_('Updated Labs settings'), category='success')
755 754 raise HTTPFound(h.route_path('admin_settings_labs'))
756 755
757 756 data = render('rhodecode:templates/admin/settings/settings.mako',
758 757 self._get_template_context(c), self.request)
759 758 html = formencode.htmlfill.render(
760 759 data,
761 760 defaults=self._form_defaults(),
762 761 encoding="UTF-8",
763 762 force_defaults=False
764 763 )
765 764 return Response(html)
766 765
767 766
768 767 # :param key: name of the setting including the 'rhodecode_' prefix
769 768 # :param type: the RhodeCodeSetting type to use.
770 769 # :param group: the i18ned group in which we should display this setting
771 770 # :param label: the i18ned label we should display for this setting
772 771 # :param help: the i18ned help we should display for this setting
773 772 LabSetting = collections.namedtuple(
774 773 'LabSetting', ('key', 'type', 'group', 'label', 'help'))
775 774
776 775
777 776 # This list has to be kept in sync with the form
778 777 # rhodecode.model.forms.LabsSettingsForm.
779 778 _LAB_SETTINGS = [
780 779
781 780 ]
@@ -1,1919 +1,1919 @@
1 1 #!/usr/bin/python2.4
2 2
3 3 from __future__ import division
4 4
5 5 """Diff Match and Patch
6 6
7 7 Copyright 2006 Google Inc.
8 8 http://code.google.com/p/google-diff-match-patch/
9 9
10 10 Licensed under the Apache License, Version 2.0 (the "License");
11 11 you may not use this file except in compliance with the License.
12 12 You may obtain a copy of the License at
13 13
14 14 http://www.apache.org/licenses/LICENSE-2.0
15 15
16 16 Unless required by applicable law or agreed to in writing, software
17 17 distributed under the License is distributed on an "AS IS" BASIS,
18 18 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19 19 See the License for the specific language governing permissions and
20 20 limitations under the License.
21 21 """
22 22
23 23 """Functions for diff, match and patch.
24 24
25 25 Computes the difference between two texts to create a patch.
26 26 Applies the patch onto another text, allowing for errors.
27 27 """
28 28
29 29 __author__ = 'fraser@google.com (Neil Fraser)'
30 30
31 31 import math
32 32 import re
33 33 import sys
34 34 import time
35 35 import urllib
36 36
37 37 class diff_match_patch:
38 38 """Class containing the diff, match and patch methods.
39 39
40 40 Also contains the behaviour settings.
41 41 """
42 42
43 43 def __init__(self):
44 44 """Inits a diff_match_patch object with default settings.
45 45 Redefine these in your program to override the defaults.
46 46 """
47 47
48 48 # Number of seconds to map a diff before giving up (0 for infinity).
49 49 self.Diff_Timeout = 1.0
50 50 # Cost of an empty edit operation in terms of edit characters.
51 51 self.Diff_EditCost = 4
52 52 # At what point is no match declared (0.0 = perfection, 1.0 = very loose).
53 53 self.Match_Threshold = 0.5
54 54 # How far to search for a match (0 = exact location, 1000+ = broad match).
55 55 # A match this many characters away from the expected location will add
56 56 # 1.0 to the score (0.0 is a perfect match).
57 57 self.Match_Distance = 1000
58 58 # When deleting a large block of text (over ~64 characters), how close do
59 59 # the contents have to be to match the expected contents. (0.0 = perfection,
60 60 # 1.0 = very loose). Note that Match_Threshold controls how closely the
61 61 # end points of a delete need to match.
62 62 self.Patch_DeleteThreshold = 0.5
63 63 # Chunk size for context length.
64 64 self.Patch_Margin = 4
65 65
66 66 # The number of bits in an int.
67 67 # Python has no maximum, thus to disable patch splitting set to 0.
68 68 # However to avoid long patches in certain pathological cases, use 32.
69 69 # Multiple short patches (using native ints) are much faster than long ones.
70 70 self.Match_MaxBits = 32
71 71
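
All of these knobs are plain instance attributes, so callers tune them after construction; for instance (a sketch):

    dmp = diff_match_patch()
    dmp.Diff_Timeout = 0       # unlimited time: always compute an optimal diff
    dmp.Match_Threshold = 0.8  # accept looser matches when locating patches
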
72 72 # DIFF FUNCTIONS
73 73
74 74 # The data structure representing a diff is an array of tuples:
75 75 # [(DIFF_DELETE, "Hello"), (DIFF_INSERT, "Goodbye"), (DIFF_EQUAL, " world.")]
76 76 # which means: delete "Hello", add "Goodbye" and keep " world."
77 77 DIFF_DELETE = -1
78 78 DIFF_INSERT = 1
79 79 DIFF_EQUAL = 0
80 80
81 81 def diff_main(self, text1, text2, checklines=True, deadline=None):
82 82 """Find the differences between two texts. Simplifies the problem by
83 83 stripping any common prefix or suffix off the texts before diffing.
84 84
85 85 Args:
86 86 text1: Old string to be diffed.
87 87 text2: New string to be diffed.
88 88 checklines: Optional speedup flag. If present and false, then don't run
89 89 a line-level diff first to identify the changed areas.
90 90 Defaults to true, which does a faster, slightly less optimal diff.
91 91 deadline: Optional time when the diff should be complete by. Used
92 92 internally for recursive calls. Users should set Diff_Timeout instead.
93 93
94 94 Returns:
95 95 Array of changes.
96 96 """
97 97 # Set a deadline by which time the diff must be complete.
98 if deadline == None:
98 if deadline is None:
99 99 # Unlike in most languages, Python counts time in seconds.
100 100 if self.Diff_Timeout <= 0:
101 101 deadline = sys.maxint
102 102 else:
103 103 deadline = time.time() + self.Diff_Timeout
104 104
105 105 # Check for null inputs.
106 if text1 == None or text2 == None:
106 if text1 is None or text2 is None:
107 107 raise ValueError("Null inputs. (diff_main)")
108 108
109 109 # Check for equality (speedup).
110 110 if text1 == text2:
111 111 if text1:
112 112 return [(self.DIFF_EQUAL, text1)]
113 113 return []
114 114
115 115 # Trim off common prefix (speedup).
116 116 commonlength = self.diff_commonPrefix(text1, text2)
117 117 commonprefix = text1[:commonlength]
118 118 text1 = text1[commonlength:]
119 119 text2 = text2[commonlength:]
120 120
121 121 # Trim off common suffix (speedup).
122 122 commonlength = self.diff_commonSuffix(text1, text2)
123 123 if commonlength == 0:
124 124 commonsuffix = ''
125 125 else:
126 126 commonsuffix = text1[-commonlength:]
127 127 text1 = text1[:-commonlength]
128 128 text2 = text2[:-commonlength]
129 129
130 130 # Compute the diff on the middle block.
131 131 diffs = self.diff_compute(text1, text2, checklines, deadline)
132 132
133 133 # Restore the prefix and suffix.
134 134 if commonprefix:
135 135 diffs[:0] = [(self.DIFF_EQUAL, commonprefix)]
136 136 if commonsuffix:
137 137 diffs.append((self.DIFF_EQUAL, commonsuffix))
138 138 self.diff_cleanupMerge(diffs)
139 139 return diffs
140 140
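
A typical top-level call; the common suffix 'd dog' is trimmed before the core diff runs, and no semantic cleanup is applied:

    dmp = diff_match_patch()
    print dmp.diff_main("Good dog", "Bad dog")
    # -> [(-1, 'Goo'), (1, 'Ba'), (0, 'd dog')]
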
141 141 def diff_compute(self, text1, text2, checklines, deadline):
142 142 """Find the differences between two texts. Assumes that the texts do not
143 143 have any common prefix or suffix.
144 144
145 145 Args:
146 146 text1: Old string to be diffed.
147 147 text2: New string to be diffed.
148 148 checklines: Speedup flag. If false, then don't run a line-level diff
149 149 first to identify the changed areas.
150 150 If true, then run a faster, slightly less optimal diff.
151 151 deadline: Time when the diff should be complete by.
152 152
153 153 Returns:
154 154 Array of changes.
155 155 """
156 156 if not text1:
157 157 # Just add some text (speedup).
158 158 return [(self.DIFF_INSERT, text2)]
159 159
160 160 if not text2:
161 161 # Just delete some text (speedup).
162 162 return [(self.DIFF_DELETE, text1)]
163 163
164 164 if len(text1) > len(text2):
165 165 (longtext, shorttext) = (text1, text2)
166 166 else:
167 167 (shorttext, longtext) = (text1, text2)
168 168 i = longtext.find(shorttext)
169 169 if i != -1:
170 170 # Shorter text is inside the longer text (speedup).
171 171 diffs = [(self.DIFF_INSERT, longtext[:i]), (self.DIFF_EQUAL, shorttext),
172 172 (self.DIFF_INSERT, longtext[i + len(shorttext):])]
173 173 # Swap insertions for deletions if diff is reversed.
174 174 if len(text1) > len(text2):
175 175 diffs[0] = (self.DIFF_DELETE, diffs[0][1])
176 176 diffs[2] = (self.DIFF_DELETE, diffs[2][1])
177 177 return diffs
178 178
179 179 if len(shorttext) == 1:
180 180 # Single character string.
181 181 # After the previous speedup, the character can't be an equality.
182 182 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
183 183
184 184 # Check to see if the problem can be split in two.
185 185 hm = self.diff_halfMatch(text1, text2)
186 186 if hm:
187 187 # A half-match was found, sort out the return data.
188 188 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
189 189 # Send both pairs off for separate processing.
190 190 diffs_a = self.diff_main(text1_a, text2_a, checklines, deadline)
191 191 diffs_b = self.diff_main(text1_b, text2_b, checklines, deadline)
192 192 # Merge the results.
193 193 return diffs_a + [(self.DIFF_EQUAL, mid_common)] + diffs_b
194 194
195 195 if checklines and len(text1) > 100 and len(text2) > 100:
196 196 return self.diff_lineMode(text1, text2, deadline)
197 197
198 198 return self.diff_bisect(text1, text2, deadline)
199 199
200 200 def diff_lineMode(self, text1, text2, deadline):
201 201 """Do a quick line-level diff on both strings, then rediff the parts for
202 202 greater accuracy.
203 203 This speedup can produce non-minimal diffs.
204 204
205 205 Args:
206 206 text1: Old string to be diffed.
207 207 text2: New string to be diffed.
208 208 deadline: Time when the diff should be complete by.
209 209
210 210 Returns:
211 211 Array of changes.
212 212 """
213 213
214 214 # Scan the text on a line-by-line basis first.
215 215 (text1, text2, linearray) = self.diff_linesToChars(text1, text2)
216 216
217 217 diffs = self.diff_main(text1, text2, False, deadline)
218 218
219 219 # Convert the diff back to original text.
220 220 self.diff_charsToLines(diffs, linearray)
221 221 # Eliminate freak matches (e.g. blank lines)
222 222 self.diff_cleanupSemantic(diffs)
223 223
224 224 # Rediff any replacement blocks, this time character-by-character.
225 225 # Add a dummy entry at the end.
226 226 diffs.append((self.DIFF_EQUAL, ''))
227 227 pointer = 0
228 228 count_delete = 0
229 229 count_insert = 0
230 230 text_delete = ''
231 231 text_insert = ''
232 232 while pointer < len(diffs):
233 233 if diffs[pointer][0] == self.DIFF_INSERT:
234 234 count_insert += 1
235 235 text_insert += diffs[pointer][1]
236 236 elif diffs[pointer][0] == self.DIFF_DELETE:
237 237 count_delete += 1
238 238 text_delete += diffs[pointer][1]
239 239 elif diffs[pointer][0] == self.DIFF_EQUAL:
240 240 # Upon reaching an equality, check for prior redundancies.
241 241 if count_delete >= 1 and count_insert >= 1:
242 242 # Delete the offending records and add the merged ones.
243 243 a = self.diff_main(text_delete, text_insert, False, deadline)
244 244 diffs[pointer - count_delete - count_insert : pointer] = a
245 245 pointer = pointer - count_delete - count_insert + len(a)
246 246 count_insert = 0
247 247 count_delete = 0
248 248 text_delete = ''
249 249 text_insert = ''
250 250
251 251 pointer += 1
252 252
253 253 diffs.pop() # Remove the dummy entry at the end.
254 254
255 255 return diffs
256 256
257 257 def diff_bisect(self, text1, text2, deadline):
258 258 """Find the 'middle snake' of a diff, split the problem in two
259 259 and return the recursively constructed diff.
260 260 See Myers 1986 paper: An O(ND) Difference Algorithm and Its Variations.
261 261
262 262 Args:
263 263 text1: Old string to be diffed.
264 264 text2: New string to be diffed.
265 265 deadline: Time at which to bail if not yet complete.
266 266
267 267 Returns:
268 268 Array of diff tuples.
269 269 """
270 270
271 271 # Cache the text lengths to prevent multiple calls.
272 272 text1_length = len(text1)
273 273 text2_length = len(text2)
274 274 max_d = (text1_length + text2_length + 1) // 2
275 275 v_offset = max_d
276 276 v_length = 2 * max_d
277 277 v1 = [-1] * v_length
278 278 v1[v_offset + 1] = 0
279 279 v2 = v1[:]
280 280 delta = text1_length - text2_length
281 281 # If the total number of characters is odd, then the front path will
282 282 # collide with the reverse path.
283 283 front = (delta % 2 != 0)
284 284 # Offsets for start and end of k loop.
285 285 # Prevents mapping of space beyond the grid.
286 286 k1start = 0
287 287 k1end = 0
288 288 k2start = 0
289 289 k2end = 0
290 290 for d in xrange(max_d):
291 291 # Bail out if deadline is reached.
292 292 if time.time() > deadline:
293 293 break
294 294
295 295 # Walk the front path one step.
296 296 for k1 in xrange(-d + k1start, d + 1 - k1end, 2):
297 297 k1_offset = v_offset + k1
298 298 if k1 == -d or (k1 != d and
299 299 v1[k1_offset - 1] < v1[k1_offset + 1]):
300 300 x1 = v1[k1_offset + 1]
301 301 else:
302 302 x1 = v1[k1_offset - 1] + 1
303 303 y1 = x1 - k1
304 304 while (x1 < text1_length and y1 < text2_length and
305 305 text1[x1] == text2[y1]):
306 306 x1 += 1
307 307 y1 += 1
308 308 v1[k1_offset] = x1
309 309 if x1 > text1_length:
310 310 # Ran off the right of the graph.
311 311 k1end += 2
312 312 elif y1 > text2_length:
313 313 # Ran off the bottom of the graph.
314 314 k1start += 2
315 315 elif front:
316 316 k2_offset = v_offset + delta - k1
317 317 if k2_offset >= 0 and k2_offset < v_length and v2[k2_offset] != -1:
318 318 # Mirror x2 onto top-left coordinate system.
319 319 x2 = text1_length - v2[k2_offset]
320 320 if x1 >= x2:
321 321 # Overlap detected.
322 322 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
323 323
324 324 # Walk the reverse path one step.
325 325 for k2 in xrange(-d + k2start, d + 1 - k2end, 2):
326 326 k2_offset = v_offset + k2
327 327 if k2 == -d or (k2 != d and
328 328 v2[k2_offset - 1] < v2[k2_offset + 1]):
329 329 x2 = v2[k2_offset + 1]
330 330 else:
331 331 x2 = v2[k2_offset - 1] + 1
332 332 y2 = x2 - k2
333 333 while (x2 < text1_length and y2 < text2_length and
334 334 text1[-x2 - 1] == text2[-y2 - 1]):
335 335 x2 += 1
336 336 y2 += 1
337 337 v2[k2_offset] = x2
338 338 if x2 > text1_length:
339 339 # Ran off the left of the graph.
340 340 k2end += 2
341 341 elif y2 > text2_length:
342 342 # Ran off the top of the graph.
343 343 k2start += 2
344 344 elif not front:
345 345 k1_offset = v_offset + delta - k2
346 346 if k1_offset >= 0 and k1_offset < v_length and v1[k1_offset] != -1:
347 347 x1 = v1[k1_offset]
348 348 y1 = v_offset + x1 - k1_offset
349 349 # Mirror x2 onto top-left coordinate system.
350 350 x2 = text1_length - x2
351 351 if x1 >= x2:
352 352 # Overlap detected.
353 353 return self.diff_bisectSplit(text1, text2, x1, y1, deadline)
354 354
355 355 # Diff took too long and hit the deadline or
356 356 # number of diffs equals number of characters, no commonality at all.
357 357 return [(self.DIFF_DELETE, text1), (self.DIFF_INSERT, text2)]
358 358
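
diff_bisect on its own, mirroring the upstream test suite (no common prefix or suffix here, so the Myers walk does all the work):

    import sys
    dmp = diff_match_patch()
    print dmp.diff_bisect('cat', 'map', sys.maxint)
    # -> [(-1, 'c'), (1, 'm'), (0, 'a'), (-1, 't'), (1, 'p')]
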
359 359 def diff_bisectSplit(self, text1, text2, x, y, deadline):
360 360 """Given the location of the 'middle snake', split the diff in two parts
361 361 and recurse.
362 362
363 363 Args:
364 364 text1: Old string to be diffed.
365 365 text2: New string to be diffed.
366 366 x: Index of split point in text1.
367 367 y: Index of split point in text2.
368 368 deadline: Time at which to bail if not yet complete.
369 369
370 370 Returns:
371 371 Array of diff tuples.
372 372 """
373 373 text1a = text1[:x]
374 374 text2a = text2[:y]
375 375 text1b = text1[x:]
376 376 text2b = text2[y:]
377 377
378 378 # Compute both diffs serially.
379 379 diffs = self.diff_main(text1a, text2a, False, deadline)
380 380 diffsb = self.diff_main(text1b, text2b, False, deadline)
381 381
382 382 return diffs + diffsb
383 383
384 384 def diff_linesToChars(self, text1, text2):
385 385 """Split two texts into an array of strings. Reduce the texts to a string
386 386 of hashes where each Unicode character represents one line.
387 387
388 388 Args:
389 389 text1: First string.
390 390 text2: Second string.
391 391
392 392 Returns:
393 393 Three element tuple, containing the encoded text1, the encoded text2 and
394 394 the array of unique strings. The zeroth element of the array of unique
395 395 strings is intentionally blank.
396 396 """
397 397 lineArray = [] # e.g. lineArray[4] == "Hello\n"
398 398 lineHash = {} # e.g. lineHash["Hello\n"] == 4
399 399
400 400 # "\x00" is a valid character, but various debuggers don't like it.
401 401 # So we'll insert a junk entry to avoid generating a null character.
402 402 lineArray.append('')
403 403
404 404 def diff_linesToCharsMunge(text):
405 405 """Split a text into an array of strings. Reduce the texts to a string
406 406 of hashes where each Unicode character represents one line.
407 407 Modifies linearray and linehash through being a closure.
408 408
409 409 Args:
410 410 text: String to encode.
411 411
412 412 Returns:
413 413 Encoded string.
414 414 """
415 415 chars = []
416 416 # Walk the text, pulling out a substring for each line.
417 417 # text.split('\n') would temporarily double our memory footprint.
418 418 # Modifying text would create many large strings to garbage collect.
419 419 lineStart = 0
420 420 lineEnd = -1
421 421 while lineEnd < len(text) - 1:
422 422 lineEnd = text.find('\n', lineStart)
423 423 if lineEnd == -1:
424 424 lineEnd = len(text) - 1
425 425 line = text[lineStart:lineEnd + 1]
426 426 lineStart = lineEnd + 1
427 427
428 428 if line in lineHash:
429 429 chars.append(unichr(lineHash[line]))
430 430 else:
431 431 lineArray.append(line)
432 432 lineHash[line] = len(lineArray) - 1
433 433 chars.append(unichr(len(lineArray) - 1))
434 434 return "".join(chars)
435 435
436 436 chars1 = diff_linesToCharsMunge(text1)
437 437 chars2 = diff_linesToCharsMunge(text2)
438 438 return (chars1, chars2, lineArray)
439 439
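
The encoding is easiest to see on inputs that share lines; mirroring the upstream tests:

    dmp = diff_match_patch()
    print dmp.diff_linesToChars("alpha\nbeta\nalpha\n", "beta\nalpha\nbeta\n")
    # -> ('\x01\x02\x01', '\x02\x01\x02', ['', 'alpha\n', 'beta\n'])
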
440 440 def diff_charsToLines(self, diffs, lineArray):
441 441 """Rehydrate the text in a diff from a string of line hashes to real lines
442 442 of text.
443 443
444 444 Args:
445 445 diffs: Array of diff tuples.
446 446 lineArray: Array of unique strings.
447 447 """
448 448 for x in xrange(len(diffs)):
449 449 text = []
450 450 for char in diffs[x][1]:
451 451 text.append(lineArray[ord(char)])
452 452 diffs[x] = (diffs[x][0], "".join(text))
453 453
454 454 def diff_commonPrefix(self, text1, text2):
455 455 """Determine the common prefix of two strings.
456 456
457 457 Args:
458 458 text1: First string.
459 459 text2: Second string.
460 460
461 461 Returns:
462 462 The number of characters common to the start of each string.
463 463 """
464 464 # Quick check for common null cases.
465 465 if not text1 or not text2 or text1[0] != text2[0]:
466 466 return 0
467 467 # Binary search.
468 468 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
469 469 pointermin = 0
470 470 pointermax = min(len(text1), len(text2))
471 471 pointermid = pointermax
472 472 pointerstart = 0
473 473 while pointermin < pointermid:
474 474 if text1[pointerstart:pointermid] == text2[pointerstart:pointermid]:
475 475 pointermin = pointermid
476 476 pointerstart = pointermin
477 477 else:
478 478 pointermax = pointermid
479 479 pointermid = (pointermax - pointermin) // 2 + pointermin
480 480 return pointermid
481 481
482 482 def diff_commonSuffix(self, text1, text2):
483 483 """Determine the common suffix of two strings.
484 484
485 485 Args:
486 486 text1: First string.
487 487 text2: Second string.
488 488
489 489 Returns:
490 490 The number of characters common to the end of each string.
491 491 """
492 492 # Quick check for common null cases.
493 493 if not text1 or not text2 or text1[-1] != text2[-1]:
494 494 return 0
495 495 # Binary search.
496 496 # Performance analysis: http://neil.fraser.name/news/2007/10/09/
497 497 pointermin = 0
498 498 pointermax = min(len(text1), len(text2))
499 499 pointermid = pointermax
500 500 pointerend = 0
501 501 while pointermin < pointermid:
502 502 if (text1[-pointermid:len(text1) - pointerend] ==
503 503 text2[-pointermid:len(text2) - pointerend]):
504 504 pointermin = pointermid
505 505 pointerend = pointermin
506 506 else:
507 507 pointermax = pointermid
508 508 pointermid = (pointermax - pointermin) // 2 + pointermin
509 509 return pointermid
510 510
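
Both helpers return a character count found by the binary search above; for instance:

    dmp = diff_match_patch()
    print dmp.diff_commonPrefix("1234abcdef", "1234xyz")  # -> 4
    print dmp.diff_commonSuffix("abcdef1234", "xyz1234")  # -> 4
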
511 511 def diff_commonOverlap(self, text1, text2):
512 512 """Determine if the suffix of one string is the prefix of another.
513 513
514 514 Args:
515 515 text1: First string.
516 516 text2: Second string.
517 517
518 518 Returns:
519 519 The number of characters common to the end of the first
520 520 string and the start of the second string.
521 521 """
522 522 # Cache the text lengths to prevent multiple calls.
523 523 text1_length = len(text1)
524 524 text2_length = len(text2)
525 525 # Eliminate the null case.
526 526 if text1_length == 0 or text2_length == 0:
527 527 return 0
528 528 # Truncate the longer string.
529 529 if text1_length > text2_length:
530 530 text1 = text1[-text2_length:]
531 531 elif text1_length < text2_length:
532 532 text2 = text2[:text1_length]
533 533 text_length = min(text1_length, text2_length)
534 534 # Quick check for the worst case.
535 535 if text1 == text2:
536 536 return text_length
537 537
538 538 # Start by looking for a single character match
539 539 # and increase length until no match is found.
540 540 # Performance analysis: http://neil.fraser.name/news/2010/11/04/
541 541 best = 0
542 542 length = 1
543 543 while True:
544 544 pattern = text1[-length:]
545 545 found = text2.find(pattern)
546 546 if found == -1:
547 547 return best
548 548 length += found
549 549 if found == 0 or text1[-length:] == text2[:length]:
550 550 best = length
551 551 length += 1
552 552
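
For example, the suffix 'xxx' of the first string is the prefix of the second:

    dmp = diff_match_patch()
    print dmp.diff_commonOverlap("123456xxx", "xxxabcd")  # -> 3
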
553 553 def diff_halfMatch(self, text1, text2):
554 554 """Do the two texts share a substring which is at least half the length of
555 555 the longer text?
556 556 This speedup can produce non-minimal diffs.
557 557
558 558 Args:
559 559 text1: First string.
560 560 text2: Second string.
561 561
562 562 Returns:
563 563 Five element Array, containing the prefix of text1, the suffix of text1,
564 564 the prefix of text2, the suffix of text2 and the common middle. Or None
565 565 if there was no match.
566 566 """
567 567 if self.Diff_Timeout <= 0:
568 568 # Don't risk returning a non-optimal diff if we have unlimited time.
569 569 return None
570 570 if len(text1) > len(text2):
571 571 (longtext, shorttext) = (text1, text2)
572 572 else:
573 573 (shorttext, longtext) = (text1, text2)
574 574 if len(longtext) < 4 or len(shorttext) * 2 < len(longtext):
575 575 return None # Pointless.
576 576
577 577 def diff_halfMatchI(longtext, shorttext, i):
578 578 """Does a substring of shorttext exist within longtext such that the
579 579 substring is at least half the length of longtext?
580 580 Closure, but does not reference any external variables.
581 581
582 582 Args:
583 583 longtext: Longer string.
584 584 shorttext: Shorter string.
585 585 i: Start index of quarter length substring within longtext.
586 586
587 587 Returns:
588 588 Five element Array, containing the prefix of longtext, the suffix of
589 589 longtext, the prefix of shorttext, the suffix of shorttext and the
590 590 common middle. Or None if there was no match.
591 591 """
592 592 seed = longtext[i:i + len(longtext) // 4]
593 593 best_common = ''
594 594 j = shorttext.find(seed)
595 595 while j != -1:
596 596 prefixLength = self.diff_commonPrefix(longtext[i:], shorttext[j:])
597 597 suffixLength = self.diff_commonSuffix(longtext[:i], shorttext[:j])
598 598 if len(best_common) < suffixLength + prefixLength:
599 599 best_common = (shorttext[j - suffixLength:j] +
600 600 shorttext[j:j + prefixLength])
601 601 best_longtext_a = longtext[:i - suffixLength]
602 602 best_longtext_b = longtext[i + prefixLength:]
603 603 best_shorttext_a = shorttext[:j - suffixLength]
604 604 best_shorttext_b = shorttext[j + prefixLength:]
605 605 j = shorttext.find(seed, j + 1)
606 606
607 607 if len(best_common) * 2 >= len(longtext):
608 608 return (best_longtext_a, best_longtext_b,
609 609 best_shorttext_a, best_shorttext_b, best_common)
610 610 else:
611 611 return None
612 612
613 613 # First check if the second quarter is the seed for a half-match.
614 614 hm1 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 3) // 4)
615 615 # Check again based on the third quarter.
616 616 hm2 = diff_halfMatchI(longtext, shorttext, (len(longtext) + 1) // 2)
617 617 if not hm1 and not hm2:
618 618 return None
619 619 elif not hm2:
620 620 hm = hm1
621 621 elif not hm1:
622 622 hm = hm2
623 623 else:
624 624 # Both matched. Select the longest.
625 625 if len(hm1[4]) > len(hm2[4]):
626 626 hm = hm1
627 627 else:
628 628 hm = hm2
629 629
630 630 # A half-match was found, sort out the return data.
631 631 if len(text1) > len(text2):
632 632 (text1_a, text1_b, text2_a, text2_b, mid_common) = hm
633 633 else:
634 634 (text2_a, text2_b, text1_a, text1_b, mid_common) = hm
635 635 return (text1_a, text1_b, text2_a, text2_b, mid_common)
636 636
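
An example mirroring the upstream tests; note the half-match path only runs when a timeout is set, since it trades optimality for speed:

    dmp = diff_match_patch()
    dmp.Diff_Timeout = 1
    print dmp.diff_halfMatch("1234567890", "a345678z")
    # -> ('12', '90', 'a', 'z', '345678')
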
637 637 def diff_cleanupSemantic(self, diffs):
638 638 """Reduce the number of edits by eliminating semantically trivial
639 639 equalities.
640 640
641 641 Args:
642 642 diffs: Array of diff tuples.
643 643 """
644 644 changes = False
645 645 equalities = [] # Stack of indices where equalities are found.
646 646 lastequality = None # Always equal to diffs[equalities[-1]][1]
647 647 pointer = 0 # Index of current position.
648 648 # Number of chars that changed prior to the equality.
649 649 length_insertions1, length_deletions1 = 0, 0
650 650 # Number of chars that changed after the equality.
651 651 length_insertions2, length_deletions2 = 0, 0
652 652 while pointer < len(diffs):
653 653 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
654 654 equalities.append(pointer)
655 655 length_insertions1, length_insertions2 = length_insertions2, 0
656 656 length_deletions1, length_deletions2 = length_deletions2, 0
657 657 lastequality = diffs[pointer][1]
658 658 else: # An insertion or deletion.
659 659 if diffs[pointer][0] == self.DIFF_INSERT:
660 660 length_insertions2 += len(diffs[pointer][1])
661 661 else:
662 662 length_deletions2 += len(diffs[pointer][1])
663 663 # Eliminate an equality that is smaller or equal to the edits on both
664 664 # sides of it.
665 665 if (lastequality and (len(lastequality) <=
666 666 max(length_insertions1, length_deletions1)) and
667 667 (len(lastequality) <= max(length_insertions2, length_deletions2))):
668 668 # Duplicate record.
669 669 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
670 670 # Change second copy to insert.
671 671 diffs[equalities[-1] + 1] = (self.DIFF_INSERT,
672 672 diffs[equalities[-1] + 1][1])
673 673 # Throw away the equality we just deleted.
674 674 equalities.pop()
675 675 # Throw away the previous equality (it needs to be reevaluated).
676 676 if len(equalities):
677 677 equalities.pop()
678 678 if len(equalities):
679 679 pointer = equalities[-1]
680 680 else:
681 681 pointer = -1
682 682 # Reset the counters.
683 683 length_insertions1, length_deletions1 = 0, 0
684 684 length_insertions2, length_deletions2 = 0, 0
685 685 lastequality = None
686 686 changes = True
687 687 pointer += 1
688 688
689 689 # Normalize the diff.
690 690 if changes:
691 691 self.diff_cleanupMerge(diffs)
692 692 self.diff_cleanupSemanticLossless(diffs)
693 693
694 694 # Find any overlaps between deletions and insertions.
695 695 # e.g: <del>abcxxx</del><ins>xxxdef</ins>
696 696 # -> <del>abc</del>xxx<ins>def</ins>
697 697 # e.g: <del>xxxabc</del><ins>defxxx</ins>
698 698 # -> <ins>def</ins>xxx<del>abc</del>
699 699 # Only extract an overlap if it is as big as the edit ahead or behind it.
700 700 pointer = 1
701 701 while pointer < len(diffs):
702 702 if (diffs[pointer - 1][0] == self.DIFF_DELETE and
703 703 diffs[pointer][0] == self.DIFF_INSERT):
704 704 deletion = diffs[pointer - 1][1]
705 705 insertion = diffs[pointer][1]
706 706 overlap_length1 = self.diff_commonOverlap(deletion, insertion)
707 707 overlap_length2 = self.diff_commonOverlap(insertion, deletion)
708 708 if overlap_length1 >= overlap_length2:
709 709 if (overlap_length1 >= len(deletion) / 2.0 or
710 710 overlap_length1 >= len(insertion) / 2.0):
711 711 # Overlap found. Insert an equality and trim the surrounding edits.
712 712 diffs.insert(pointer, (self.DIFF_EQUAL,
713 713 insertion[:overlap_length1]))
714 714 diffs[pointer - 1] = (self.DIFF_DELETE,
715 715 deletion[:len(deletion) - overlap_length1])
716 716 diffs[pointer + 1] = (self.DIFF_INSERT,
717 717 insertion[overlap_length1:])
718 718 pointer += 1
719 719 else:
720 720 if (overlap_length2 >= len(deletion) / 2.0 or
721 721 overlap_length2 >= len(insertion) / 2.0):
722 722 # Reverse overlap found.
723 723 # Insert an equality and swap and trim the surrounding edits.
724 724 diffs.insert(pointer, (self.DIFF_EQUAL, deletion[:overlap_length2]))
725 725 diffs[pointer - 1] = (self.DIFF_INSERT,
726 726 insertion[:len(insertion) - overlap_length2])
727 727 diffs[pointer + 1] = (self.DIFF_DELETE, deletion[overlap_length2:])
728 728 pointer += 1
729 729 pointer += 1
730 730 pointer += 1
731 731
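# Example (editor's sketch, not part of the upstream diff): diff_cleanupSemantic
# folds an equality that is dominated by the edits around it back into those
# edits, trading one small "keep" for a cheaper delete/insert pair.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_DELETE, "a"), (dmp.DIFF_EQUAL, "b"), (dmp.DIFF_DELETE, "c")]
dmp.diff_cleanupSemantic(diffs)
assert diffs == [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "b")]
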
732 732 def diff_cleanupSemanticLossless(self, diffs):
733 733 """Look for single edits surrounded on both sides by equalities
734 734 which can be shifted sideways to align the edit to a word boundary.
735 735 e.g: The c<ins>at c</ins>ame. -> The <ins>cat </ins>came.
736 736
737 737 Args:
738 738 diffs: Array of diff tuples.
739 739 """
740 740
741 741 def diff_cleanupSemanticScore(one, two):
742 742 """Given two strings, compute a score representing whether the
743 743 internal boundary falls on logical boundaries.
744 744 Scores range from 6 (best) to 0 (worst).
745 745           Closure, but does not reference any external variables apart from self.
746 746
747 747 Args:
748 748 one: First string.
749 749 two: Second string.
750 750
751 751 Returns:
752 752 The score.
753 753 """
754 754 if not one or not two:
755 755 # Edges are the best.
756 756 return 6
757 757
758 758 # Each port of this function behaves slightly differently due to
759 759 # subtle differences in each language's definition of things like
760 760 # 'whitespace'. Since this function's purpose is largely cosmetic,
761 761 # the choice has been made to use each language's native features
762 762 # rather than force total conformity.
763 763 char1 = one[-1]
764 764 char2 = two[0]
765 765 nonAlphaNumeric1 = not char1.isalnum()
766 766 nonAlphaNumeric2 = not char2.isalnum()
767 767 whitespace1 = nonAlphaNumeric1 and char1.isspace()
768 768 whitespace2 = nonAlphaNumeric2 and char2.isspace()
769 769 lineBreak1 = whitespace1 and (char1 == "\r" or char1 == "\n")
770 770 lineBreak2 = whitespace2 and (char2 == "\r" or char2 == "\n")
771 771 blankLine1 = lineBreak1 and self.BLANKLINEEND.search(one)
772 772 blankLine2 = lineBreak2 and self.BLANKLINESTART.match(two)
773 773
774 774 if blankLine1 or blankLine2:
775 775 # Five points for blank lines.
776 776 return 5
777 777 elif lineBreak1 or lineBreak2:
778 778 # Four points for line breaks.
779 779 return 4
780 780 elif nonAlphaNumeric1 and not whitespace1 and whitespace2:
781 781 # Three points for end of sentences.
782 782 return 3
783 783 elif whitespace1 or whitespace2:
784 784 # Two points for whitespace.
785 785 return 2
786 786 elif nonAlphaNumeric1 or nonAlphaNumeric2:
787 787 # One point for non-alphanumeric.
788 788 return 1
789 789 return 0
790 790
791 791 pointer = 1
792 792 # Intentionally ignore the first and last element (don't need checking).
793 793 while pointer < len(diffs) - 1:
794 794 if (diffs[pointer - 1][0] == self.DIFF_EQUAL and
795 795 diffs[pointer + 1][0] == self.DIFF_EQUAL):
796 796 # This is a single edit surrounded by equalities.
797 797 equality1 = diffs[pointer - 1][1]
798 798 edit = diffs[pointer][1]
799 799 equality2 = diffs[pointer + 1][1]
800 800
801 801 # First, shift the edit as far left as possible.
802 802 commonOffset = self.diff_commonSuffix(equality1, edit)
803 803 if commonOffset:
804 804 commonString = edit[-commonOffset:]
805 805 equality1 = equality1[:-commonOffset]
806 806 edit = commonString + edit[:-commonOffset]
807 807 equality2 = commonString + equality2
808 808
809 809 # Second, step character by character right, looking for the best fit.
810 810 bestEquality1 = equality1
811 811 bestEdit = edit
812 812 bestEquality2 = equality2
813 813 bestScore = (diff_cleanupSemanticScore(equality1, edit) +
814 814 diff_cleanupSemanticScore(edit, equality2))
815 815 while edit and equality2 and edit[0] == equality2[0]:
816 816 equality1 += edit[0]
817 817 edit = edit[1:] + equality2[0]
818 818 equality2 = equality2[1:]
819 819 score = (diff_cleanupSemanticScore(equality1, edit) +
820 820 diff_cleanupSemanticScore(edit, equality2))
821 821 # The >= encourages trailing rather than leading whitespace on edits.
822 822 if score >= bestScore:
823 823 bestScore = score
824 824 bestEquality1 = equality1
825 825 bestEdit = edit
826 826 bestEquality2 = equality2
827 827
828 828 if diffs[pointer - 1][1] != bestEquality1:
829 829 # We have an improvement, save it back to the diff.
830 830 if bestEquality1:
831 831 diffs[pointer - 1] = (diffs[pointer - 1][0], bestEquality1)
832 832 else:
833 833 del diffs[pointer - 1]
834 834 pointer -= 1
835 835 diffs[pointer] = (diffs[pointer][0], bestEdit)
836 836 if bestEquality2:
837 837 diffs[pointer + 1] = (diffs[pointer + 1][0], bestEquality2)
838 838 else:
839 839 del diffs[pointer + 1]
840 840 pointer -= 1
841 841 pointer += 1
842 842
843 843 # Define some regex patterns for matching boundaries.
844 844   BLANKLINEEND = re.compile(r"\n\r?\n$")
845 845   BLANKLINESTART = re.compile(r"^\r?\n\r?\n")
846 846
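# Example (editor's sketch): diff_cleanupSemanticLossless slides the edit from
# the docstring example above onto the nearest word boundary.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "The c"), (dmp.DIFF_INSERT, "at c"), (dmp.DIFF_EQUAL, "ame.")]
dmp.diff_cleanupSemanticLossless(diffs)
assert diffs == [(dmp.DIFF_EQUAL, "The "), (dmp.DIFF_INSERT, "cat "), (dmp.DIFF_EQUAL, "came.")]
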
847 847 def diff_cleanupEfficiency(self, diffs):
848 848 """Reduce the number of edits by eliminating operationally trivial
849 849 equalities.
850 850
851 851 Args:
852 852 diffs: Array of diff tuples.
853 853 """
854 854 changes = False
855 855 equalities = [] # Stack of indices where equalities are found.
856 856 lastequality = None # Always equal to diffs[equalities[-1]][1]
857 857 pointer = 0 # Index of current position.
858 858 pre_ins = False # Is there an insertion operation before the last equality.
859 859 pre_del = False # Is there a deletion operation before the last equality.
860 860 post_ins = False # Is there an insertion operation after the last equality.
861 861 post_del = False # Is there a deletion operation after the last equality.
862 862 while pointer < len(diffs):
863 863 if diffs[pointer][0] == self.DIFF_EQUAL: # Equality found.
864 864 if (len(diffs[pointer][1]) < self.Diff_EditCost and
865 865 (post_ins or post_del)):
866 866 # Candidate found.
867 867 equalities.append(pointer)
868 868 pre_ins = post_ins
869 869 pre_del = post_del
870 870 lastequality = diffs[pointer][1]
871 871 else:
872 872 # Not a candidate, and can never become one.
873 873 equalities = []
874 874 lastequality = None
875 875
876 876 post_ins = post_del = False
877 877 else: # An insertion or deletion.
878 878 if diffs[pointer][0] == self.DIFF_DELETE:
879 879 post_del = True
880 880 else:
881 881 post_ins = True
882 882
883 883 # Five types to be split:
884 884 # <ins>A</ins><del>B</del>XY<ins>C</ins><del>D</del>
885 885 # <ins>A</ins>X<ins>C</ins><del>D</del>
886 886 # <ins>A</ins><del>B</del>X<ins>C</ins>
887 887           # <ins>A</ins>X<ins>C</ins><del>D</del>
888 888 # <ins>A</ins><del>B</del>X<del>C</del>
889 889
890 890 if lastequality and ((pre_ins and pre_del and post_ins and post_del) or
891 891 ((len(lastequality) < self.Diff_EditCost / 2) and
892 892 (pre_ins + pre_del + post_ins + post_del) == 3)):
893 893 # Duplicate record.
894 894 diffs.insert(equalities[-1], (self.DIFF_DELETE, lastequality))
895 895 # Change second copy to insert.
896 896 diffs[equalities[-1] + 1] = (self.DIFF_INSERT,
897 897 diffs[equalities[-1] + 1][1])
898 898 equalities.pop() # Throw away the equality we just deleted.
899 899 lastequality = None
900 900 if pre_ins and pre_del:
901 901 # No changes made which could affect previous entry, keep going.
902 902 post_ins = post_del = True
903 903 equalities = []
904 904 else:
905 905 if len(equalities):
906 906 equalities.pop() # Throw away the previous equality.
907 907 if len(equalities):
908 908 pointer = equalities[-1]
909 909 else:
910 910 pointer = -1
911 911 post_ins = post_del = False
912 912 changes = True
913 913 pointer += 1
914 914
915 915 if changes:
916 916 self.diff_cleanupMerge(diffs)
917 917
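# Example (editor's sketch): with the default Diff_EditCost of 4, a short
# equality flanked by insert+delete pairs is cheaper to re-type than to keep.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_DELETE, "ab"), (dmp.DIFF_INSERT, "12"), (dmp.DIFF_EQUAL, "xyz"),
         (dmp.DIFF_DELETE, "cd"), (dmp.DIFF_INSERT, "34")]
dmp.diff_cleanupEfficiency(diffs)
assert diffs == [(dmp.DIFF_DELETE, "abxyzcd"), (dmp.DIFF_INSERT, "12xyz34")]
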
918 918 def diff_cleanupMerge(self, diffs):
919 919 """Reorder and merge like edit sections. Merge equalities.
920 920 Any edit section can move as long as it doesn't cross an equality.
921 921
922 922 Args:
923 923 diffs: Array of diff tuples.
924 924 """
925 925 diffs.append((self.DIFF_EQUAL, '')) # Add a dummy entry at the end.
926 926 pointer = 0
927 927 count_delete = 0
928 928 count_insert = 0
929 929 text_delete = ''
930 930 text_insert = ''
931 931 while pointer < len(diffs):
932 932 if diffs[pointer][0] == self.DIFF_INSERT:
933 933 count_insert += 1
934 934 text_insert += diffs[pointer][1]
935 935 pointer += 1
936 936 elif diffs[pointer][0] == self.DIFF_DELETE:
937 937 count_delete += 1
938 938 text_delete += diffs[pointer][1]
939 939 pointer += 1
940 940 elif diffs[pointer][0] == self.DIFF_EQUAL:
941 941 # Upon reaching an equality, check for prior redundancies.
942 942 if count_delete + count_insert > 1:
943 943 if count_delete != 0 and count_insert != 0:
944 944             # Factor out any common prefixes.
945 945 commonlength = self.diff_commonPrefix(text_insert, text_delete)
946 946 if commonlength != 0:
947 947 x = pointer - count_delete - count_insert - 1
948 948 if x >= 0 and diffs[x][0] == self.DIFF_EQUAL:
949 949 diffs[x] = (diffs[x][0], diffs[x][1] +
950 950 text_insert[:commonlength])
951 951 else:
952 952 diffs.insert(0, (self.DIFF_EQUAL, text_insert[:commonlength]))
953 953 pointer += 1
954 954 text_insert = text_insert[commonlength:]
955 955 text_delete = text_delete[commonlength:]
956 956             # Factor out any common suffixes.
957 957 commonlength = self.diff_commonSuffix(text_insert, text_delete)
958 958 if commonlength != 0:
959 959 diffs[pointer] = (diffs[pointer][0], text_insert[-commonlength:] +
960 960 diffs[pointer][1])
961 961 text_insert = text_insert[:-commonlength]
962 962 text_delete = text_delete[:-commonlength]
963 963 # Delete the offending records and add the merged ones.
964 964 if count_delete == 0:
965 965 diffs[pointer - count_insert : pointer] = [
966 966 (self.DIFF_INSERT, text_insert)]
967 967 elif count_insert == 0:
968 968 diffs[pointer - count_delete : pointer] = [
969 969 (self.DIFF_DELETE, text_delete)]
970 970 else:
971 971 diffs[pointer - count_delete - count_insert : pointer] = [
972 972 (self.DIFF_DELETE, text_delete),
973 973 (self.DIFF_INSERT, text_insert)]
974 974 pointer = pointer - count_delete - count_insert + 1
975 975 if count_delete != 0:
976 976 pointer += 1
977 977 if count_insert != 0:
978 978 pointer += 1
979 979 elif pointer != 0 and diffs[pointer - 1][0] == self.DIFF_EQUAL:
980 980 # Merge this equality with the previous one.
981 981 diffs[pointer - 1] = (diffs[pointer - 1][0],
982 982 diffs[pointer - 1][1] + diffs[pointer][1])
983 983 del diffs[pointer]
984 984 else:
985 985 pointer += 1
986 986
987 987 count_insert = 0
988 988 count_delete = 0
989 989 text_delete = ''
990 990 text_insert = ''
991 991
992 992 if diffs[-1][1] == '':
993 993 diffs.pop() # Remove the dummy entry at the end.
994 994
995 995 # Second pass: look for single edits surrounded on both sides by equalities
996 996 # which can be shifted sideways to eliminate an equality.
997 997 # e.g: A<ins>BA</ins>C -> <ins>AB</ins>AC
998 998 changes = False
999 999 pointer = 1
1000 1000 # Intentionally ignore the first and last element (don't need checking).
1001 1001 while pointer < len(diffs) - 1:
1002 1002 if (diffs[pointer - 1][0] == self.DIFF_EQUAL and
1003 1003 diffs[pointer + 1][0] == self.DIFF_EQUAL):
1004 1004 # This is a single edit surrounded by equalities.
1005 1005 if diffs[pointer][1].endswith(diffs[pointer - 1][1]):
1006 1006 # Shift the edit over the previous equality.
1007 1007 diffs[pointer] = (diffs[pointer][0],
1008 1008 diffs[pointer - 1][1] +
1009 1009 diffs[pointer][1][:-len(diffs[pointer - 1][1])])
1010 1010 diffs[pointer + 1] = (diffs[pointer + 1][0],
1011 1011 diffs[pointer - 1][1] + diffs[pointer + 1][1])
1012 1012 del diffs[pointer - 1]
1013 1013 changes = True
1014 1014 elif diffs[pointer][1].startswith(diffs[pointer + 1][1]):
1015 1015 # Shift the edit over the next equality.
1016 1016 diffs[pointer - 1] = (diffs[pointer - 1][0],
1017 1017 diffs[pointer - 1][1] + diffs[pointer + 1][1])
1018 1018 diffs[pointer] = (diffs[pointer][0],
1019 1019 diffs[pointer][1][len(diffs[pointer + 1][1]):] +
1020 1020 diffs[pointer + 1][1])
1021 1021 del diffs[pointer + 1]
1022 1022 changes = True
1023 1023 pointer += 1
1024 1024
1025 1025 # If shifts were made, the diff needs reordering and another shift sweep.
1026 1026 if changes:
1027 1027 self.diff_cleanupMerge(diffs)
1028 1028
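# Example (editor's sketch): the second pass of diff_cleanupMerge performs the
# A<ins>BA</ins>C -> <ins>AB</ins>AC shift described in the comment above.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "a"), (dmp.DIFF_INSERT, "ba"), (dmp.DIFF_EQUAL, "c")]
dmp.diff_cleanupMerge(diffs)
assert diffs == [(dmp.DIFF_INSERT, "ab"), (dmp.DIFF_EQUAL, "ac")]
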
1029 1029 def diff_xIndex(self, diffs, loc):
1030 1030 """loc is a location in text1, compute and return the equivalent location
1031 1031 in text2. e.g. "The cat" vs "The big cat", 1->1, 5->8
1032 1032
1033 1033 Args:
1034 1034 diffs: Array of diff tuples.
1035 1035 loc: Location within text1.
1036 1036
1037 1037 Returns:
1038 1038 Location within text2.
1039 1039 """
1040 1040 chars1 = 0
1041 1041 chars2 = 0
1042 1042 last_chars1 = 0
1043 1043 last_chars2 = 0
1044 1044 for x in xrange(len(diffs)):
1045 1045 (op, text) = diffs[x]
1046 1046 if op != self.DIFF_INSERT: # Equality or deletion.
1047 1047 chars1 += len(text)
1048 1048 if op != self.DIFF_DELETE: # Equality or insertion.
1049 1049 chars2 += len(text)
1050 1050 if chars1 > loc: # Overshot the location.
1051 1051 break
1052 1052 last_chars1 = chars1
1053 1053 last_chars2 = chars2
1054 1054
1055 1055 if len(diffs) != x and diffs[x][0] == self.DIFF_DELETE:
1056 1056 # The location was deleted.
1057 1057 return last_chars2
1058 1058     # Add the remaining character length.
1059 1059 return last_chars2 + (loc - last_chars1)
1060 1060
1061 1061 def diff_prettyHtml(self, diffs):
1062 1062 """Convert a diff array into a pretty HTML report.
1063 1063
1064 1064 Args:
1065 1065 diffs: Array of diff tuples.
1066 1066
1067 1067 Returns:
1068 1068 HTML representation.
1069 1069 """
1070 1070 html = []
1071 1071 for (op, data) in diffs:
1072 1072 text = (data.replace("&", "&amp;").replace("<", "&lt;")
1073 1073 .replace(">", "&gt;").replace("\n", "&para;<br>"))
1074 1074 if op == self.DIFF_INSERT:
1075 1075 html.append("<ins style=\"background:#e6ffe6;\">%s</ins>" % text)
1076 1076 elif op == self.DIFF_DELETE:
1077 1077 html.append("<del style=\"background:#ffe6e6;\">%s</del>" % text)
1078 1078 elif op == self.DIFF_EQUAL:
1079 1079 html.append("<span>%s</span>" % text)
1080 1080 return "".join(html)
1081 1081
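# Example (editor's sketch): diff_prettyHtml escapes markup and renders each
# operation with inline styles, so the expected output here follows directly
# from the method body above.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "a\n"), (dmp.DIFF_DELETE, "<B>b</B>"), (dmp.DIFF_INSERT, "c&d")]
html = dmp.diff_prettyHtml(diffs)
# '<span>a&para;<br></span><del style="background:#ffe6e6;">&lt;B&gt;b&lt;/B&gt;</del>'
# '<ins style="background:#e6ffe6;">c&amp;d</ins>'
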
1082 1082 def diff_text1(self, diffs):
1083 1083 """Compute and return the source text (all equalities and deletions).
1084 1084
1085 1085 Args:
1086 1086 diffs: Array of diff tuples.
1087 1087
1088 1088 Returns:
1089 1089 Source text.
1090 1090 """
1091 1091 text = []
1092 1092 for (op, data) in diffs:
1093 1093 if op != self.DIFF_INSERT:
1094 1094 text.append(data)
1095 1095 return "".join(text)
1096 1096
1097 1097 def diff_text2(self, diffs):
1098 1098 """Compute and return the destination text (all equalities and insertions).
1099 1099
1100 1100 Args:
1101 1101 diffs: Array of diff tuples.
1102 1102
1103 1103 Returns:
1104 1104 Destination text.
1105 1105 """
1106 1106 text = []
1107 1107 for (op, data) in diffs:
1108 1108 if op != self.DIFF_DELETE:
1109 1109 text.append(data)
1110 1110 return "".join(text)
1111 1111
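# Example (editor's sketch): diff_text1 and diff_text2 reconstruct the two
# sides of the same diff.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"), (dmp.DIFF_INSERT, "ed"),
         (dmp.DIFF_EQUAL, " over the lazy")]
assert dmp.diff_text1(diffs) == "jumps over the lazy"
assert dmp.diff_text2(diffs) == "jumped over the lazy"
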
1112 1112 def diff_levenshtein(self, diffs):
1113 1113 """Compute the Levenshtein distance; the number of inserted, deleted or
1114 1114 substituted characters.
1115 1115
1116 1116 Args:
1117 1117 diffs: Array of diff tuples.
1118 1118
1119 1119 Returns:
1120 1120 Number of changes.
1121 1121 """
1122 1122 levenshtein = 0
1123 1123 insertions = 0
1124 1124 deletions = 0
1125 1125 for (op, data) in diffs:
1126 1126 if op == self.DIFF_INSERT:
1127 1127 insertions += len(data)
1128 1128 elif op == self.DIFF_DELETE:
1129 1129 deletions += len(data)
1130 1130 elif op == self.DIFF_EQUAL:
1131 1131 # A deletion and an insertion is one substitution.
1132 1132 levenshtein += max(insertions, deletions)
1133 1133 insertions = 0
1134 1134 deletions = 0
1135 1135 levenshtein += max(insertions, deletions)
1136 1136 return levenshtein
1137 1137
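# Example (editor's sketch): an adjacent delete/insert pair counts as
# substitutions (max, not sum), while an equality between them keeps the two
# edits separate so their lengths add up.
dmp = diff_match_patch()
assert dmp.diff_levenshtein(
    [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_INSERT, "1234"), (dmp.DIFF_EQUAL, "xyz")]) == 4
assert dmp.diff_levenshtein(
    [(dmp.DIFF_DELETE, "abc"), (dmp.DIFF_EQUAL, "xyz"), (dmp.DIFF_INSERT, "1234")]) == 7
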
1138 1138 def diff_toDelta(self, diffs):
1139 1139 """Crush the diff into an encoded string which describes the operations
1140 1140 required to transform text1 into text2.
1141 1141 E.g. =3\t-2\t+ing -> Keep 3 chars, delete 2 chars, insert 'ing'.
1142 1142 Operations are tab-separated. Inserted text is escaped using %xx notation.
1143 1143
1144 1144 Args:
1145 1145 diffs: Array of diff tuples.
1146 1146
1147 1147 Returns:
1148 1148 Delta text.
1149 1149 """
1150 1150 text = []
1151 1151 for (op, data) in diffs:
1152 1152 if op == self.DIFF_INSERT:
1153 1153 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1154 1154 data = data.encode("utf-8")
1155 1155 text.append("+" + urllib.quote(data, "!~*'();/?:@&=+$,# "))
1156 1156 elif op == self.DIFF_DELETE:
1157 1157 text.append("-%d" % len(data))
1158 1158 elif op == self.DIFF_EQUAL:
1159 1159 text.append("=%d" % len(data))
1160 1160 return "\t".join(text)
1161 1161
1162 1162 def diff_fromDelta(self, text1, delta):
1163 1163 """Given the original text1, and an encoded string which describes the
1164 1164 operations required to transform text1 into text2, compute the full diff.
1165 1165
1166 1166 Args:
1167 1167 text1: Source string for the diff.
1168 1168 delta: Delta text.
1169 1169
1170 1170 Returns:
1171 1171 Array of diff tuples.
1172 1172
1173 1173 Raises:
1174 1174 ValueError: If invalid input.
1175 1175 """
1176 1176 if type(delta) == unicode:
1177 1177 # Deltas should be composed of a subset of ascii chars, Unicode not
1178 1178 # required. If this encode raises UnicodeEncodeError, delta is invalid.
1179 1179 delta = delta.encode("ascii")
1180 1180 diffs = []
1181 1181 pointer = 0 # Cursor in text1
1182 1182 tokens = delta.split("\t")
1183 1183 for token in tokens:
1184 1184 if token == "":
1185 1185 # Blank tokens are ok (from a trailing \t).
1186 1186 continue
1187 1187 # Each token begins with a one character parameter which specifies the
1188 1188 # operation of this token (delete, insert, equality).
1189 1189 param = token[1:]
1190 1190 if token[0] == "+":
1191 1191 param = urllib.unquote(param).decode("utf-8")
1192 1192 diffs.append((self.DIFF_INSERT, param))
1193 1193 elif token[0] == "-" or token[0] == "=":
1194 1194 try:
1195 1195 n = int(param)
1196 1196 except ValueError:
1197 1197 raise ValueError("Invalid number in diff_fromDelta: " + param)
1198 1198 if n < 0:
1199 1199 raise ValueError("Negative number in diff_fromDelta: " + param)
1200 1200 text = text1[pointer : pointer + n]
1201 1201 pointer += n
1202 1202 if token[0] == "=":
1203 1203 diffs.append((self.DIFF_EQUAL, text))
1204 1204 else:
1205 1205 diffs.append((self.DIFF_DELETE, text))
1206 1206 else:
1207 1207 # Anything else is an error.
1208 1208 raise ValueError("Invalid diff operation in diff_fromDelta: " +
1209 1209 token[0])
1210 1210 if pointer != len(text1):
1211 1211 raise ValueError(
1212 1212 "Delta length (%d) does not equal source text length (%d)." %
1213 1213 (pointer, len(text1)))
1214 1214 return diffs
1215 1215
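# Example (editor's sketch): a delta is a compact wire format; together with
# the source text it expands back into the full diff.
dmp = diff_match_patch()
diffs = [(dmp.DIFF_EQUAL, "jump"), (dmp.DIFF_DELETE, "s"), (dmp.DIFF_INSERT, "ed")]
delta = dmp.diff_toDelta(diffs)
assert delta == "=4\t-1\t+ed"
assert dmp.diff_fromDelta(dmp.diff_text1(diffs), delta) == diffs
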
1216 1216 # MATCH FUNCTIONS
1217 1217
1218 1218 def match_main(self, text, pattern, loc):
1219 1219 """Locate the best instance of 'pattern' in 'text' near 'loc'.
1220 1220
1221 1221 Args:
1222 1222 text: The text to search.
1223 1223 pattern: The pattern to search for.
1224 1224 loc: The location to search around.
1225 1225
1226 1226 Returns:
1227 1227 Best match index or -1.
1228 1228 """
1229 1229 # Check for null inputs.
1230 if text == None or pattern == None:
1230 if text is None or pattern is None:
1231 1231 raise ValueError("Null inputs. (match_main)")
1232 1232
1233 1233 loc = max(0, min(loc, len(text)))
1234 1234 if text == pattern:
1235 1235 # Shortcut (potentially not guaranteed by the algorithm)
1236 1236 return 0
1237 1237 elif not text:
1238 1238 # Nothing to match.
1239 1239 return -1
1240 1240 elif text[loc:loc + len(pattern)] == pattern:
1241 1241 # Perfect match at the perfect spot! (Includes case of null pattern)
1242 1242 return loc
1243 1243 else:
1244 1244 # Do a fuzzy compare.
1245 1245 match = self.match_bitap(text, pattern, loc)
1246 1246 return match
1247 1247
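# Example (editor's sketch): exact matches short-circuit; otherwise match_bitap
# does a fuzzy, location-biased search. Values follow the upstream test suite.
dmp = diff_match_patch()
assert dmp.match_main("abcdef", "de", 3) == 3    # perfect match at the expected spot
assert dmp.match_main("abcdef", "defy", 4) == 3  # fuzzy match with one error
assert dmp.match_main("abcdef", "xyz", 3) == -1  # nothing within Match_Threshold
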
1248 1248 def match_bitap(self, text, pattern, loc):
1249 1249 """Locate the best instance of 'pattern' in 'text' near 'loc' using the
1250 1250 Bitap algorithm.
1251 1251
1252 1252 Args:
1253 1253 text: The text to search.
1254 1254 pattern: The pattern to search for.
1255 1255 loc: The location to search around.
1256 1256
1257 1257 Returns:
1258 1258 Best match index or -1.
1259 1259 """
1260 1260 # Python doesn't have a maxint limit, so ignore this check.
1261 1261 #if self.Match_MaxBits != 0 and len(pattern) > self.Match_MaxBits:
1262 1262 # raise ValueError("Pattern too long for this application.")
1263 1263
1264 1264 # Initialise the alphabet.
1265 1265 s = self.match_alphabet(pattern)
1266 1266
1267 1267 def match_bitapScore(e, x):
1268 1268 """Compute and return the score for a match with e errors and x location.
1269 1269 Accesses loc and pattern through being a closure.
1270 1270
1271 1271 Args:
1272 1272 e: Number of errors in match.
1273 1273 x: Location of match.
1274 1274
1275 1275 Returns:
1276 1276 Overall score for match (0.0 = good, 1.0 = bad).
1277 1277 """
1278 1278 accuracy = float(e) / len(pattern)
1279 1279 proximity = abs(loc - x)
1280 1280 if not self.Match_Distance:
1281 1281 # Dodge divide by zero error.
1282 1282 return proximity and 1.0 or accuracy
1283 1283 return accuracy + (proximity / float(self.Match_Distance))
1284 1284
1285 1285 # Highest score beyond which we give up.
1286 1286 score_threshold = self.Match_Threshold
1287 1287 # Is there a nearby exact match? (speedup)
1288 1288 best_loc = text.find(pattern, loc)
1289 1289 if best_loc != -1:
1290 1290 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1291 1291 # What about in the other direction? (speedup)
1292 1292 best_loc = text.rfind(pattern, loc + len(pattern))
1293 1293 if best_loc != -1:
1294 1294 score_threshold = min(match_bitapScore(0, best_loc), score_threshold)
1295 1295
1296 1296 # Initialise the bit arrays.
1297 1297 matchmask = 1 << (len(pattern) - 1)
1298 1298 best_loc = -1
1299 1299
1300 1300 bin_max = len(pattern) + len(text)
1301 1301 # Empty initialization added to appease pychecker.
1302 1302 last_rd = None
1303 1303 for d in xrange(len(pattern)):
1304 1304 # Scan for the best match each iteration allows for one more error.
1305 1305 # Run a binary search to determine how far from 'loc' we can stray at
1306 1306 # this error level.
1307 1307 bin_min = 0
1308 1308 bin_mid = bin_max
1309 1309 while bin_min < bin_mid:
1310 1310 if match_bitapScore(d, loc + bin_mid) <= score_threshold:
1311 1311 bin_min = bin_mid
1312 1312 else:
1313 1313 bin_max = bin_mid
1314 1314 bin_mid = (bin_max - bin_min) // 2 + bin_min
1315 1315
1316 1316 # Use the result from this iteration as the maximum for the next.
1317 1317 bin_max = bin_mid
1318 1318 start = max(1, loc - bin_mid + 1)
1319 1319 finish = min(loc + bin_mid, len(text)) + len(pattern)
1320 1320
1321 1321 rd = [0] * (finish + 2)
1322 1322 rd[finish + 1] = (1 << d) - 1
1323 1323 for j in xrange(finish, start - 1, -1):
1324 1324 if len(text) <= j - 1:
1325 1325 # Out of range.
1326 1326 charMatch = 0
1327 1327 else:
1328 1328 charMatch = s.get(text[j - 1], 0)
1329 1329 if d == 0: # First pass: exact match.
1330 1330 rd[j] = ((rd[j + 1] << 1) | 1) & charMatch
1331 1331 else: # Subsequent passes: fuzzy match.
1332 1332 rd[j] = (((rd[j + 1] << 1) | 1) & charMatch) | (
1333 1333 ((last_rd[j + 1] | last_rd[j]) << 1) | 1) | last_rd[j + 1]
1334 1334 if rd[j] & matchmask:
1335 1335 score = match_bitapScore(d, j - 1)
1336 1336 # This match will almost certainly be better than any existing match.
1337 1337 # But check anyway.
1338 1338 if score <= score_threshold:
1339 1339 # Told you so.
1340 1340 score_threshold = score
1341 1341 best_loc = j - 1
1342 1342 if best_loc > loc:
1343 1343 # When passing loc, don't exceed our current distance from loc.
1344 1344 start = max(1, 2 * loc - best_loc)
1345 1345 else:
1346 1346 # Already passed loc, downhill from here on in.
1347 1347 break
1348 1348 # No hope for a (better) match at greater error levels.
1349 1349 if match_bitapScore(d + 1, loc) > score_threshold:
1350 1350 break
1351 1351 last_rd = rd
1352 1352 return best_loc
1353 1353
1354 1354 def match_alphabet(self, pattern):
1355 1355 """Initialise the alphabet for the Bitap algorithm.
1356 1356
1357 1357 Args:
1358 1358 pattern: The text to encode.
1359 1359
1360 1360 Returns:
1361 1361 Hash of character locations.
1362 1362 """
1363 1363 s = {}
1364 1364 for char in pattern:
1365 1365 s[char] = 0
1366 1366 for i in xrange(len(pattern)):
1367 1367 s[pattern[i]] |= 1 << (len(pattern) - i - 1)
1368 1368 return s
1369 1369
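# Example (editor's sketch): one bitmask per character; the set bit marks the
# character's position, counted from the end of the pattern.
dmp = diff_match_patch()
assert dmp.match_alphabet("abc") == {"a": 4, "b": 2, "c": 1}
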
1370 1370 # PATCH FUNCTIONS
1371 1371
1372 1372 def patch_addContext(self, patch, text):
1373 1373 """Increase the context until it is unique,
1374 1374 but don't let the pattern expand beyond Match_MaxBits.
1375 1375
1376 1376 Args:
1377 1377 patch: The patch to grow.
1378 1378 text: Source text.
1379 1379 """
1380 1380 if len(text) == 0:
1381 1381 return
1382 1382 pattern = text[patch.start2 : patch.start2 + patch.length1]
1383 1383 padding = 0
1384 1384
1385 1385 # Look for the first and last matches of pattern in text. If two different
1386 1386 # matches are found, increase the pattern length.
1387 1387 while (text.find(pattern) != text.rfind(pattern) and (self.Match_MaxBits ==
1388 1388 0 or len(pattern) < self.Match_MaxBits - self.Patch_Margin -
1389 1389 self.Patch_Margin)):
1390 1390 padding += self.Patch_Margin
1391 1391 pattern = text[max(0, patch.start2 - padding) :
1392 1392 patch.start2 + patch.length1 + padding]
1393 1393 # Add one chunk for good luck.
1394 1394 padding += self.Patch_Margin
1395 1395
1396 1396 # Add the prefix.
1397 1397 prefix = text[max(0, patch.start2 - padding) : patch.start2]
1398 1398 if prefix:
1399 1399 patch.diffs[:0] = [(self.DIFF_EQUAL, prefix)]
1400 1400 # Add the suffix.
1401 1401 suffix = text[patch.start2 + patch.length1 :
1402 1402 patch.start2 + patch.length1 + padding]
1403 1403 if suffix:
1404 1404 patch.diffs.append((self.DIFF_EQUAL, suffix))
1405 1405
1406 1406 # Roll back the start points.
1407 1407 patch.start1 -= len(prefix)
1408 1408 patch.start2 -= len(prefix)
1409 1409 # Extend lengths.
1410 1410 patch.length1 += len(prefix) + len(suffix)
1411 1411 patch.length2 += len(prefix) + len(suffix)
1412 1412
1413 1413 def patch_make(self, a, b=None, c=None):
1414 1414 """Compute a list of patches to turn text1 into text2.
1415 1415 Use diffs if provided, otherwise compute it ourselves.
1416 1416 There are four ways to call this function, depending on what data is
1417 1417 available to the caller:
1418 1418 Method 1:
1419 1419 a = text1, b = text2
1420 1420 Method 2:
1421 1421 a = diffs
1422 1422 Method 3 (optimal):
1423 1423 a = text1, b = diffs
1424 1424 Method 4 (deprecated, use method 3):
1425 1425 a = text1, b = text2, c = diffs
1426 1426
1427 1427 Args:
1428 1428 a: text1 (methods 1,3,4) or Array of diff tuples for text1 to
1429 1429 text2 (method 2).
1430 1430 b: text2 (methods 1,4) or Array of diff tuples for text1 to
1431 1431 text2 (method 3) or undefined (method 2).
1432 1432 c: Array of diff tuples for text1 to text2 (method 4) or
1433 1433 undefined (methods 1,2,3).
1434 1434
1435 1435 Returns:
1436 1436 Array of Patch objects.
1437 1437 """
1438 1438 text1 = None
1439 1439 diffs = None
1440 1440 # Note that texts may arrive as 'str' or 'unicode'.
1441 1441 if isinstance(a, basestring) and isinstance(b, basestring) and c is None:
1442 1442 # Method 1: text1, text2
1443 1443 # Compute diffs from text1 and text2.
1444 1444 text1 = a
1445 1445 diffs = self.diff_main(text1, b, True)
1446 1446 if len(diffs) > 2:
1447 1447 self.diff_cleanupSemantic(diffs)
1448 1448 self.diff_cleanupEfficiency(diffs)
1449 1449 elif isinstance(a, list) and b is None and c is None:
1450 1450 # Method 2: diffs
1451 1451 # Compute text1 from diffs.
1452 1452 diffs = a
1453 1453 text1 = self.diff_text1(diffs)
1454 1454 elif isinstance(a, basestring) and isinstance(b, list) and c is None:
1455 1455 # Method 3: text1, diffs
1456 1456 text1 = a
1457 1457 diffs = b
1458 1458 elif (isinstance(a, basestring) and isinstance(b, basestring) and
1459 1459 isinstance(c, list)):
1460 1460 # Method 4: text1, text2, diffs
1461 1461 # text2 is not used.
1462 1462 text1 = a
1463 1463 diffs = c
1464 1464 else:
1465 1465 raise ValueError("Unknown call format to patch_make.")
1466 1466
1467 1467 if not diffs:
1468 1468 return [] # Get rid of the None case.
1469 1469 patches = []
1470 1470 patch = patch_obj()
1471 1471 char_count1 = 0 # Number of characters into the text1 string.
1472 1472 char_count2 = 0 # Number of characters into the text2 string.
1473 1473 prepatch_text = text1 # Recreate the patches to determine context info.
1474 1474 postpatch_text = text1
1475 1475 for x in xrange(len(diffs)):
1476 1476 (diff_type, diff_text) = diffs[x]
1477 1477 if len(patch.diffs) == 0 and diff_type != self.DIFF_EQUAL:
1478 1478 # A new patch starts here.
1479 1479 patch.start1 = char_count1
1480 1480 patch.start2 = char_count2
1481 1481 if diff_type == self.DIFF_INSERT:
1482 1482 # Insertion
1483 1483 patch.diffs.append(diffs[x])
1484 1484 patch.length2 += len(diff_text)
1485 1485 postpatch_text = (postpatch_text[:char_count2] + diff_text +
1486 1486 postpatch_text[char_count2:])
1487 1487 elif diff_type == self.DIFF_DELETE:
1488 1488 # Deletion.
1489 1489 patch.length1 += len(diff_text)
1490 1490 patch.diffs.append(diffs[x])
1491 1491 postpatch_text = (postpatch_text[:char_count2] +
1492 1492 postpatch_text[char_count2 + len(diff_text):])
1493 1493 elif (diff_type == self.DIFF_EQUAL and
1494 1494 len(diff_text) <= 2 * self.Patch_Margin and
1495 1495 len(patch.diffs) != 0 and len(diffs) != x + 1):
1496 1496 # Small equality inside a patch.
1497 1497 patch.diffs.append(diffs[x])
1498 1498 patch.length1 += len(diff_text)
1499 1499 patch.length2 += len(diff_text)
1500 1500
1501 1501 if (diff_type == self.DIFF_EQUAL and
1502 1502 len(diff_text) >= 2 * self.Patch_Margin):
1503 1503 # Time for a new patch.
1504 1504 if len(patch.diffs) != 0:
1505 1505 self.patch_addContext(patch, prepatch_text)
1506 1506 patches.append(patch)
1507 1507 patch = patch_obj()
1508 1508 # Unlike Unidiff, our patch lists have a rolling context.
1509 1509 # http://code.google.com/p/google-diff-match-patch/wiki/Unidiff
1510 1510 # Update prepatch text & pos to reflect the application of the
1511 1511 # just completed patch.
1512 1512 prepatch_text = postpatch_text
1513 1513 char_count1 = char_count2
1514 1514
1515 1515 # Update the current character count.
1516 1516 if diff_type != self.DIFF_INSERT:
1517 1517 char_count1 += len(diff_text)
1518 1518 if diff_type != self.DIFF_DELETE:
1519 1519 char_count2 += len(diff_text)
1520 1520
1521 1521 # Pick up the leftover patch if not empty.
1522 1522 if len(patch.diffs) != 0:
1523 1523 self.patch_addContext(patch, prepatch_text)
1524 1524 patches.append(patch)
1525 1525 return patches
1526 1526
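# Example (editor's sketch): Method 1 of patch_make, building patches from two
# plain texts and serializing them (see patch_toText below).
dmp = diff_match_patch()
text1 = "The quick brown fox jumps over the lazy dog."
text2 = "That quick brown fox jumped over a lazy dog."
patches = dmp.patch_make(text1, text2)
print dmp.patch_toText(patches)  # GNU-diff-like textual form
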
1527 1527 def patch_deepCopy(self, patches):
1528 1528 """Given an array of patches, return another array that is identical.
1529 1529
1530 1530 Args:
1531 1531 patches: Array of Patch objects.
1532 1532
1533 1533 Returns:
1534 1534 Array of Patch objects.
1535 1535 """
1536 1536 patchesCopy = []
1537 1537 for patch in patches:
1538 1538 patchCopy = patch_obj()
1539 1539 # No need to deep copy the tuples since they are immutable.
1540 1540 patchCopy.diffs = patch.diffs[:]
1541 1541 patchCopy.start1 = patch.start1
1542 1542 patchCopy.start2 = patch.start2
1543 1543 patchCopy.length1 = patch.length1
1544 1544 patchCopy.length2 = patch.length2
1545 1545 patchesCopy.append(patchCopy)
1546 1546 return patchesCopy
1547 1547
1548 1548 def patch_apply(self, patches, text):
1549 1549 """Merge a set of patches onto the text. Return a patched text, as well
1550 1550 as a list of true/false values indicating which patches were applied.
1551 1551
1552 1552 Args:
1553 1553 patches: Array of Patch objects.
1554 1554 text: Old text.
1555 1555
1556 1556 Returns:
1557 1557 Two element Array, containing the new text and an array of boolean values.
1558 1558 """
1559 1559 if not patches:
1560 1560 return (text, [])
1561 1561
1562 1562 # Deep copy the patches so that no changes are made to originals.
1563 1563 patches = self.patch_deepCopy(patches)
1564 1564
1565 1565 nullPadding = self.patch_addPadding(patches)
1566 1566 text = nullPadding + text + nullPadding
1567 1567 self.patch_splitMax(patches)
1568 1568
1569 1569 # delta keeps track of the offset between the expected and actual location
1570 1570 # of the previous patch. If there are patches expected at positions 10 and
1571 1571 # 20, but the first patch was found at 12, delta is 2 and the second patch
1572 1572 # has an effective expected position of 22.
1573 1573 delta = 0
1574 1574 results = []
1575 1575 for patch in patches:
1576 1576 expected_loc = patch.start2 + delta
1577 1577 text1 = self.diff_text1(patch.diffs)
1578 1578 end_loc = -1
1579 1579 if len(text1) > self.Match_MaxBits:
1580 1580 # patch_splitMax will only provide an oversized pattern in the case of
1581 1581 # a monster delete.
1582 1582 start_loc = self.match_main(text, text1[:self.Match_MaxBits],
1583 1583 expected_loc)
1584 1584 if start_loc != -1:
1585 1585 end_loc = self.match_main(text, text1[-self.Match_MaxBits:],
1586 1586 expected_loc + len(text1) - self.Match_MaxBits)
1587 1587 if end_loc == -1 or start_loc >= end_loc:
1588 1588 # Can't find valid trailing context. Drop this patch.
1589 1589 start_loc = -1
1590 1590 else:
1591 1591 start_loc = self.match_main(text, text1, expected_loc)
1592 1592 if start_loc == -1:
1593 1593 # No match found. :(
1594 1594 results.append(False)
1595 1595 # Subtract the delta for this failed patch from subsequent patches.
1596 1596 delta -= patch.length2 - patch.length1
1597 1597 else:
1598 1598 # Found a match. :)
1599 1599 results.append(True)
1600 1600 delta = start_loc - expected_loc
1601 1601 if end_loc == -1:
1602 1602 text2 = text[start_loc : start_loc + len(text1)]
1603 1603 else:
1604 1604 text2 = text[start_loc : end_loc + self.Match_MaxBits]
1605 1605 if text1 == text2:
1606 1606 # Perfect match, just shove the replacement text in.
1607 1607 text = (text[:start_loc] + self.diff_text2(patch.diffs) +
1608 1608 text[start_loc + len(text1):])
1609 1609 else:
1610 1610 # Imperfect match.
1611 1611 # Run a diff to get a framework of equivalent indices.
1612 1612 diffs = self.diff_main(text1, text2, False)
1613 1613 if (len(text1) > self.Match_MaxBits and
1614 1614 self.diff_levenshtein(diffs) / float(len(text1)) >
1615 1615 self.Patch_DeleteThreshold):
1616 1616 # The end points match, but the content is unacceptably bad.
1617 1617 results[-1] = False
1618 1618 else:
1619 1619 self.diff_cleanupSemanticLossless(diffs)
1620 1620 index1 = 0
1621 1621 for (op, data) in patch.diffs:
1622 1622 if op != self.DIFF_EQUAL:
1623 1623 index2 = self.diff_xIndex(diffs, index1)
1624 1624 if op == self.DIFF_INSERT: # Insertion
1625 1625 text = text[:start_loc + index2] + data + text[start_loc +
1626 1626 index2:]
1627 1627 elif op == self.DIFF_DELETE: # Deletion
1628 1628 text = text[:start_loc + index2] + text[start_loc +
1629 1629 self.diff_xIndex(diffs, index1 + len(data)):]
1630 1630 if op != self.DIFF_DELETE:
1631 1631 index1 += len(data)
1632 1632 # Strip the padding off.
1633 1633 text = text[len(nullPadding):-len(nullPadding)]
1634 1634 return (text, results)
1635 1635
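# Example (editor's sketch): fuzzy matching lets patches land even when the
# base text has drifted. Values follow the upstream test suite.
dmp = diff_match_patch()
patches = dmp.patch_make("The quick brown fox jumps over the lazy dog.",
                         "That quick brown fox jumped over a lazy dog.")
new_text, results = dmp.patch_apply(
    patches, "The quick red rabbit jumps over the tired tiger.")
# new_text == "That quick red rabbit jumped over a tired tiger."
# results == [True, True]
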
1636 1636 def patch_addPadding(self, patches):
1637 1637 """Add some padding on text start and end so that edges can match
1638 1638 something. Intended to be called only from within patch_apply.
1639 1639
1640 1640 Args:
1641 1641 patches: Array of Patch objects.
1642 1642
1643 1643 Returns:
1644 1644 The padding string added to each side.
1645 1645 """
1646 1646 paddingLength = self.Patch_Margin
1647 1647 nullPadding = ""
1648 1648 for x in xrange(1, paddingLength + 1):
1649 1649 nullPadding += chr(x)
1650 1650
1651 1651 # Bump all the patches forward.
1652 1652 for patch in patches:
1653 1653 patch.start1 += paddingLength
1654 1654 patch.start2 += paddingLength
1655 1655
1656 1656 # Add some padding on start of first diff.
1657 1657 patch = patches[0]
1658 1658 diffs = patch.diffs
1659 1659 if not diffs or diffs[0][0] != self.DIFF_EQUAL:
1660 1660 # Add nullPadding equality.
1661 1661 diffs.insert(0, (self.DIFF_EQUAL, nullPadding))
1662 1662 patch.start1 -= paddingLength # Should be 0.
1663 1663 patch.start2 -= paddingLength # Should be 0.
1664 1664 patch.length1 += paddingLength
1665 1665 patch.length2 += paddingLength
1666 1666 elif paddingLength > len(diffs[0][1]):
1667 1667 # Grow first equality.
1668 1668 extraLength = paddingLength - len(diffs[0][1])
1669 1669 newText = nullPadding[len(diffs[0][1]):] + diffs[0][1]
1670 1670 diffs[0] = (diffs[0][0], newText)
1671 1671 patch.start1 -= extraLength
1672 1672 patch.start2 -= extraLength
1673 1673 patch.length1 += extraLength
1674 1674 patch.length2 += extraLength
1675 1675
1676 1676 # Add some padding on end of last diff.
1677 1677 patch = patches[-1]
1678 1678 diffs = patch.diffs
1679 1679 if not diffs or diffs[-1][0] != self.DIFF_EQUAL:
1680 1680 # Add nullPadding equality.
1681 1681 diffs.append((self.DIFF_EQUAL, nullPadding))
1682 1682 patch.length1 += paddingLength
1683 1683 patch.length2 += paddingLength
1684 1684 elif paddingLength > len(diffs[-1][1]):
1685 1685 # Grow last equality.
1686 1686 extraLength = paddingLength - len(diffs[-1][1])
1687 1687 newText = diffs[-1][1] + nullPadding[:extraLength]
1688 1688 diffs[-1] = (diffs[-1][0], newText)
1689 1689 patch.length1 += extraLength
1690 1690 patch.length2 += extraLength
1691 1691
1692 1692 return nullPadding
1693 1693
1694 1694 def patch_splitMax(self, patches):
1695 1695 """Look through the patches and break up any which are longer than the
1696 1696 maximum limit of the match algorithm.
1697 1697 Intended to be called only from within patch_apply.
1698 1698
1699 1699 Args:
1700 1700 patches: Array of Patch objects.
1701 1701 """
1702 1702 patch_size = self.Match_MaxBits
1703 1703 if patch_size == 0:
1704 1704 # Python has the option of not splitting strings due to its ability
1705 1705 # to handle integers of arbitrary precision.
1706 1706 return
1707 1707 for x in xrange(len(patches)):
1708 1708 if patches[x].length1 <= patch_size:
1709 1709 continue
1710 1710 bigpatch = patches[x]
1711 1711 # Remove the big old patch.
1712 1712 del patches[x]
1713 1713 x -= 1
1714 1714 start1 = bigpatch.start1
1715 1715 start2 = bigpatch.start2
1716 1716 precontext = ''
1717 1717 while len(bigpatch.diffs) != 0:
1718 1718 # Create one of several smaller patches.
1719 1719 patch = patch_obj()
1720 1720 empty = True
1721 1721 patch.start1 = start1 - len(precontext)
1722 1722 patch.start2 = start2 - len(precontext)
1723 1723 if precontext:
1724 1724 patch.length1 = patch.length2 = len(precontext)
1725 1725 patch.diffs.append((self.DIFF_EQUAL, precontext))
1726 1726
1727 1727 while (len(bigpatch.diffs) != 0 and
1728 1728 patch.length1 < patch_size - self.Patch_Margin):
1729 1729 (diff_type, diff_text) = bigpatch.diffs[0]
1730 1730 if diff_type == self.DIFF_INSERT:
1731 1731 # Insertions are harmless.
1732 1732 patch.length2 += len(diff_text)
1733 1733 start2 += len(diff_text)
1734 1734 patch.diffs.append(bigpatch.diffs.pop(0))
1735 1735 empty = False
1736 1736 elif (diff_type == self.DIFF_DELETE and len(patch.diffs) == 1 and
1737 1737 patch.diffs[0][0] == self.DIFF_EQUAL and
1738 1738 len(diff_text) > 2 * patch_size):
1739 1739 # This is a large deletion. Let it pass in one chunk.
1740 1740 patch.length1 += len(diff_text)
1741 1741 start1 += len(diff_text)
1742 1742 empty = False
1743 1743 patch.diffs.append((diff_type, diff_text))
1744 1744 del bigpatch.diffs[0]
1745 1745 else:
1746 1746 # Deletion or equality. Only take as much as we can stomach.
1747 1747 diff_text = diff_text[:patch_size - patch.length1 -
1748 1748 self.Patch_Margin]
1749 1749 patch.length1 += len(diff_text)
1750 1750 start1 += len(diff_text)
1751 1751 if diff_type == self.DIFF_EQUAL:
1752 1752 patch.length2 += len(diff_text)
1753 1753 start2 += len(diff_text)
1754 1754 else:
1755 1755 empty = False
1756 1756
1757 1757 patch.diffs.append((diff_type, diff_text))
1758 1758 if diff_text == bigpatch.diffs[0][1]:
1759 1759 del bigpatch.diffs[0]
1760 1760 else:
1761 1761 bigpatch.diffs[0] = (bigpatch.diffs[0][0],
1762 1762 bigpatch.diffs[0][1][len(diff_text):])
1763 1763
1764 1764 # Compute the head context for the next patch.
1765 1765 precontext = self.diff_text2(patch.diffs)
1766 1766 precontext = precontext[-self.Patch_Margin:]
1767 1767 # Append the end context for this patch.
1768 1768 postcontext = self.diff_text1(bigpatch.diffs)[:self.Patch_Margin]
1769 1769 if postcontext:
1770 1770 patch.length1 += len(postcontext)
1771 1771 patch.length2 += len(postcontext)
1772 1772 if len(patch.diffs) != 0 and patch.diffs[-1][0] == self.DIFF_EQUAL:
1773 1773 patch.diffs[-1] = (self.DIFF_EQUAL, patch.diffs[-1][1] +
1774 1774 postcontext)
1775 1775 else:
1776 1776 patch.diffs.append((self.DIFF_EQUAL, postcontext))
1777 1777
1778 1778 if not empty:
1779 1779 x += 1
1780 1780 patches.insert(x, patch)
1781 1781
1782 1782 def patch_toText(self, patches):
1783 1783 """Take a list of patches and return a textual representation.
1784 1784
1785 1785 Args:
1786 1786 patches: Array of Patch objects.
1787 1787
1788 1788 Returns:
1789 1789 Text representation of patches.
1790 1790 """
1791 1791 text = []
1792 1792 for patch in patches:
1793 1793 text.append(str(patch))
1794 1794 return "".join(text)
1795 1795
1796 1796 def patch_fromText(self, textline):
1797 1797 """Parse a textual representation of patches and return a list of patch
1798 1798 objects.
1799 1799
1800 1800 Args:
1801 1801 textline: Text representation of patches.
1802 1802
1803 1803 Returns:
1804 1804 Array of Patch objects.
1805 1805
1806 1806 Raises:
1807 1807 ValueError: If invalid input.
1808 1808 """
1809 1809 if type(textline) == unicode:
1810 1810 # Patches should be composed of a subset of ascii chars, Unicode not
1811 1811 # required. If this encode raises UnicodeEncodeError, patch is invalid.
1812 1812 textline = textline.encode("ascii")
1813 1813 patches = []
1814 1814 if not textline:
1815 1815 return patches
1816 1816 text = textline.split('\n')
1817 1817 while len(text) != 0:
1818 1818       m = re.match(r"^@@ -(\d+),?(\d*) \+(\d+),?(\d*) @@$", text[0])
1819 1819 if not m:
1820 1820 raise ValueError("Invalid patch string: " + text[0])
1821 1821 patch = patch_obj()
1822 1822 patches.append(patch)
1823 1823 patch.start1 = int(m.group(1))
1824 1824 if m.group(2) == '':
1825 1825 patch.start1 -= 1
1826 1826 patch.length1 = 1
1827 1827 elif m.group(2) == '0':
1828 1828 patch.length1 = 0
1829 1829 else:
1830 1830 patch.start1 -= 1
1831 1831 patch.length1 = int(m.group(2))
1832 1832
1833 1833 patch.start2 = int(m.group(3))
1834 1834 if m.group(4) == '':
1835 1835 patch.start2 -= 1
1836 1836 patch.length2 = 1
1837 1837 elif m.group(4) == '0':
1838 1838 patch.length2 = 0
1839 1839 else:
1840 1840 patch.start2 -= 1
1841 1841 patch.length2 = int(m.group(4))
1842 1842
1843 1843 del text[0]
1844 1844
1845 1845 while len(text) != 0:
1846 1846 if text[0]:
1847 1847 sign = text[0][0]
1848 1848 else:
1849 1849 sign = ''
1850 1850 line = urllib.unquote(text[0][1:])
1851 1851 line = line.decode("utf-8")
1852 1852 if sign == '+':
1853 1853 # Insertion.
1854 1854 patch.diffs.append((self.DIFF_INSERT, line))
1855 1855 elif sign == '-':
1856 1856 # Deletion.
1857 1857 patch.diffs.append((self.DIFF_DELETE, line))
1858 1858 elif sign == ' ':
1859 1859 # Minor equality.
1860 1860 patch.diffs.append((self.DIFF_EQUAL, line))
1861 1861 elif sign == '@':
1862 1862 # Start of next patch.
1863 1863 break
1864 1864 elif sign == '':
1865 1865 # Blank line? Whatever.
1866 1866 pass
1867 1867 else:
1868 1868           # Unrecognized sign character; the patch text is malformed.
1869 1869 raise ValueError("Invalid patch mode: '%s'\n%s" % (sign, line))
1870 1870 del text[0]
1871 1871 return patches
1872 1872
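# Example (editor's sketch): the textual representation survives a
# parse/serialize roundtrip through patch_fromText and patch_toText.
dmp = diff_match_patch()
patches = dmp.patch_make("abcdef", "abchef")
text = dmp.patch_toText(patches)
assert dmp.patch_toText(dmp.patch_fromText(text)) == text
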
1873 1873
1874 1874 class patch_obj:
1875 1875 """Class representing one patch operation.
1876 1876 """
1877 1877
1878 1878 def __init__(self):
1879 1879 """Initializes with an empty list of diffs.
1880 1880 """
1881 1881 self.diffs = []
1882 1882 self.start1 = None
1883 1883 self.start2 = None
1884 1884 self.length1 = 0
1885 1885 self.length2 = 0
1886 1886
1887 1887 def __str__(self):
1888 1888     """Emulate GNU diff's format.
1889 1889 Header: @@ -382,8 +481,9 @@
1890 1890     Indices are printed as 1-based, not 0-based.
1891 1891
1892 1892 Returns:
1893 1893 The GNU diff string.
1894 1894 """
1895 1895 if self.length1 == 0:
1896 1896 coords1 = str(self.start1) + ",0"
1897 1897 elif self.length1 == 1:
1898 1898 coords1 = str(self.start1 + 1)
1899 1899 else:
1900 1900 coords1 = str(self.start1 + 1) + "," + str(self.length1)
1901 1901 if self.length2 == 0:
1902 1902 coords2 = str(self.start2) + ",0"
1903 1903 elif self.length2 == 1:
1904 1904 coords2 = str(self.start2 + 1)
1905 1905 else:
1906 1906 coords2 = str(self.start2 + 1) + "," + str(self.length2)
1907 1907 text = ["@@ -", coords1, " +", coords2, " @@\n"]
1908 1908 # Escape the body of the patch with %xx notation.
1909 1909 for (op, data) in self.diffs:
1910 1910 if op == diff_match_patch.DIFF_INSERT:
1911 1911 text.append("+")
1912 1912 elif op == diff_match_patch.DIFF_DELETE:
1913 1913 text.append("-")
1914 1914 elif op == diff_match_patch.DIFF_EQUAL:
1915 1915 text.append(" ")
1916 1916 # High ascii will raise UnicodeDecodeError. Use Unicode instead.
1917 1917 data = data.encode("utf-8")
1918 1918 text.append(urllib.quote(data, "!~*'();/?:@&=+$,# ") + "\n")
1919 1919     return "".join(text)
\ No newline at end of file
@@ -1,1739 +1,1737 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2012-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21
22 22 """
23 23 pull request model for RhodeCode
24 24 """
25 25
26 26
27 27 import json
28 28 import logging
29 29 import datetime
30 30 import urllib
31 31 import collections
32 32
33 33 from pyramid.threadlocal import get_current_request
34 34
35 35 from rhodecode import events
36 36 from rhodecode.translation import lazy_ugettext#, _
37 37 from rhodecode.lib import helpers as h, hooks_utils, diffs
38 38 from rhodecode.lib import audit_logger
39 39 from rhodecode.lib.compat import OrderedDict
40 40 from rhodecode.lib.hooks_daemon import prepare_callback_daemon
41 41 from rhodecode.lib.markup_renderer import (
42 42 DEFAULT_COMMENTS_RENDERER, RstTemplateRenderer)
43 43 from rhodecode.lib.utils2 import safe_unicode, safe_str, md5_safe
44 44 from rhodecode.lib.vcs.backends.base import (
45 45 Reference, MergeResponse, MergeFailureReason, UpdateFailureReason)
46 46 from rhodecode.lib.vcs.conf import settings as vcs_settings
47 47 from rhodecode.lib.vcs.exceptions import (
48 48 CommitDoesNotExistError, EmptyRepositoryError)
49 49 from rhodecode.model import BaseModel
50 50 from rhodecode.model.changeset_status import ChangesetStatusModel
51 51 from rhodecode.model.comment import CommentsModel
52 52 from rhodecode.model.db import (
53 53 or_, PullRequest, PullRequestReviewers, ChangesetStatus,
54 54 PullRequestVersion, ChangesetComment, Repository, RepoReviewRule)
55 55 from rhodecode.model.meta import Session
56 56 from rhodecode.model.notification import NotificationModel, \
57 57 EmailNotificationModel
58 58 from rhodecode.model.scm import ScmModel
59 59 from rhodecode.model.settings import VcsSettingsModel
60 60
61 61
62 62 log = logging.getLogger(__name__)
63 63
64 64
65 65 # Data structure to hold the response data when updating commits during a pull
66 66 # request update.
67 67 UpdateResponse = collections.namedtuple('UpdateResponse', [
68 68 'executed', 'reason', 'new', 'old', 'changes',
69 69 'source_changed', 'target_changed'])
70 70
71 71
72 72 class PullRequestModel(BaseModel):
73 73
74 74 cls = PullRequest
75 75
76 76 DIFF_CONTEXT = diffs.DEFAULT_CONTEXT
77 77
78 78 MERGE_STATUS_MESSAGES = {
79 79 MergeFailureReason.NONE: lazy_ugettext(
80 80 'This pull request can be automatically merged.'),
81 81 MergeFailureReason.UNKNOWN: lazy_ugettext(
82 82 'This pull request cannot be merged because of an unhandled'
83 83 ' exception.'),
84 84 MergeFailureReason.MERGE_FAILED: lazy_ugettext(
85 85 'This pull request cannot be merged because of merge conflicts.'),
86 86 MergeFailureReason.PUSH_FAILED: lazy_ugettext(
87 87 'This pull request could not be merged because push to target'
88 88 ' failed.'),
89 89 MergeFailureReason.TARGET_IS_NOT_HEAD: lazy_ugettext(
90 90 'This pull request cannot be merged because the target is not a'
91 91 ' head.'),
92 92 MergeFailureReason.HG_SOURCE_HAS_MORE_BRANCHES: lazy_ugettext(
93 93 'This pull request cannot be merged because the source contains'
94 94 ' more branches than the target.'),
95 95 MergeFailureReason.HG_TARGET_HAS_MULTIPLE_HEADS: lazy_ugettext(
96 96 'This pull request cannot be merged because the target has'
97 97 ' multiple heads.'),
98 98 MergeFailureReason.TARGET_IS_LOCKED: lazy_ugettext(
99 99 'This pull request cannot be merged because the target repository'
100 100 ' is locked.'),
101 101 MergeFailureReason._DEPRECATED_MISSING_COMMIT: lazy_ugettext(
102 102 'This pull request cannot be merged because the target or the '
103 103 'source reference is missing.'),
104 104 MergeFailureReason.MISSING_TARGET_REF: lazy_ugettext(
105 105 'This pull request cannot be merged because the target '
106 106 'reference is missing.'),
107 107 MergeFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
108 108 'This pull request cannot be merged because the source '
109 109 'reference is missing.'),
110 110 MergeFailureReason.SUBREPO_MERGE_FAILED: lazy_ugettext(
111 111 'This pull request cannot be merged because of conflicts related '
112 112 'to sub repositories.'),
113 113 }
114 114
115 115 UPDATE_STATUS_MESSAGES = {
116 116 UpdateFailureReason.NONE: lazy_ugettext(
117 117 'Pull request update successful.'),
118 118 UpdateFailureReason.UNKNOWN: lazy_ugettext(
119 119 'Pull request update failed because of an unknown error.'),
120 120 UpdateFailureReason.NO_CHANGE: lazy_ugettext(
121 121 'No update needed because the source and target have not changed.'),
122 122 UpdateFailureReason.WRONG_REF_TYPE: lazy_ugettext(
123 123 'Pull request cannot be updated because the reference type is '
124 124 'not supported for an update. Only Branch, Tag or Bookmark is allowed.'),
125 125 UpdateFailureReason.MISSING_TARGET_REF: lazy_ugettext(
126 126 'This pull request cannot be updated because the target '
127 127 'reference is missing.'),
128 128 UpdateFailureReason.MISSING_SOURCE_REF: lazy_ugettext(
129 129 'This pull request cannot be updated because the source '
130 130 'reference is missing.'),
131 131 }
132 132
133 133 def __get_pull_request(self, pull_request):
134 134 return self._get_instance((
135 135 PullRequest, PullRequestVersion), pull_request)
136 136
137 137 def _check_perms(self, perms, pull_request, user, api=False):
138 138 if not api:
139 139 return h.HasRepoPermissionAny(*perms)(
140 140 user=user, repo_name=pull_request.target_repo.repo_name)
141 141 else:
142 142 return h.HasRepoPermissionAnyApi(*perms)(
143 143 user=user, repo_name=pull_request.target_repo.repo_name)
144 144
145 145 def check_user_read(self, pull_request, user, api=False):
146 146 _perms = ('repository.admin', 'repository.write', 'repository.read',)
147 147 return self._check_perms(_perms, pull_request, user, api)
148 148
149 149 def check_user_merge(self, pull_request, user, api=False):
150 150 _perms = ('repository.admin', 'repository.write', 'hg.admin',)
151 151 return self._check_perms(_perms, pull_request, user, api)
152 152
153 153 def check_user_update(self, pull_request, user, api=False):
154 154 owner = user.user_id == pull_request.user_id
155 155 return self.check_user_merge(pull_request, user, api) or owner
156 156
157 157 def check_user_delete(self, pull_request, user):
158 158 owner = user.user_id == pull_request.user_id
159 159 _perms = ('repository.admin',)
160 160 return self._check_perms(_perms, pull_request, user) or owner
161 161
162 162 def check_user_change_status(self, pull_request, user, api=False):
163 163 reviewer = user.user_id in [x.user_id for x in
164 164 pull_request.reviewers]
165 165 return self.check_user_update(pull_request, user, api) or reviewer
166 166
167 167 def check_user_comment(self, pull_request, user):
168 168 owner = user.user_id == pull_request.user_id
169 169 return self.check_user_read(pull_request, user) or owner
170 170
171 171 def get(self, pull_request):
172 172 return self.__get_pull_request(pull_request)
173 173
174 174 def _prepare_get_all_query(self, repo_name, source=False, statuses=None,
175 175 opened_by=None, order_by=None,
176 176 order_dir='desc'):
177 177 repo = None
178 178 if repo_name:
179 179 repo = self._get_repo(repo_name)
180 180
181 181 q = PullRequest.query()
182 182
183 183 # source or target
184 184 if repo and source:
185 185 q = q.filter(PullRequest.source_repo == repo)
186 186 elif repo:
187 187 q = q.filter(PullRequest.target_repo == repo)
188 188
189 189         # closed, opened
190 190 if statuses:
191 191 q = q.filter(PullRequest.status.in_(statuses))
192 192
193 193 # opened by filter
194 194 if opened_by:
195 195 q = q.filter(PullRequest.user_id.in_(opened_by))
196 196
197 197 if order_by:
198 198 order_map = {
199 199 'name_raw': PullRequest.pull_request_id,
200 200 'title': PullRequest.title,
201 201 'updated_on_raw': PullRequest.updated_on,
202 202 'target_repo': PullRequest.target_repo_id
203 203 }
204 204 if order_dir == 'asc':
205 205 q = q.order_by(order_map[order_by].asc())
206 206 else:
207 207 q = q.order_by(order_map[order_by].desc())
208 208
209 209 return q
210 210
211 211 def count_all(self, repo_name, source=False, statuses=None,
212 212 opened_by=None):
213 213 """
214 214 Count the number of pull requests for a specific repository.
215 215
216 216 :param repo_name: target or source repo
217 217 :param source: boolean flag to specify if repo_name refers to source
218 218 :param statuses: list of pull request statuses
219 219 :param opened_by: author user of the pull request
220 220 :returns: int number of pull requests
221 221 """
222 222 q = self._prepare_get_all_query(
223 223 repo_name, source=source, statuses=statuses, opened_by=opened_by)
224 224
225 225 return q.count()
226 226
227 227 def get_all(self, repo_name, source=False, statuses=None, opened_by=None,
228 228 offset=0, length=None, order_by=None, order_dir='desc'):
229 229 """
230 230 Get all pull requests for a specific repository.
231 231
232 232 :param repo_name: target or source repo
233 233 :param source: boolean flag to specify if repo_name refers to source
234 234 :param statuses: list of pull request statuses
235 235 :param opened_by: author user of the pull request
236 236 :param offset: pagination offset
237 237 :param length: length of returned list
238 238 :param order_by: order of the returned list
239 239 :param order_dir: 'asc' or 'desc' ordering direction
240 240 :returns: list of pull requests
241 241 """
242 242 q = self._prepare_get_all_query(
243 243 repo_name, source=source, statuses=statuses, opened_by=opened_by,
244 244 order_by=order_by, order_dir=order_dir)
245 245
246 246 if length:
247 247 pull_requests = q.limit(length).offset(offset).all()
248 248 else:
249 249 pull_requests = q.all()
250 250
251 251 return pull_requests
252 252
253 253 def count_awaiting_review(self, repo_name, source=False, statuses=None,
254 254 opened_by=None):
255 255 """
256 256 Count the number of pull requests for a specific repository that are
257 257 awaiting review.
258 258
259 259 :param repo_name: target or source repo
260 260 :param source: boolean flag to specify if repo_name refers to source
261 261 :param statuses: list of pull request statuses
262 262 :param opened_by: list of user ids of pull request authors
263 263 :returns: int number of pull requests
264 264 """
265 265 pull_requests = self.get_awaiting_review(
266 266 repo_name, source=source, statuses=statuses, opened_by=opened_by)
267 267
268 268 return len(pull_requests)
269 269
270 270 def get_awaiting_review(self, repo_name, source=False, statuses=None,
271 271 opened_by=None, offset=0, length=None,
272 272 order_by=None, order_dir='desc'):
273 273 """
274 274 Get all pull requests for a specific repository that are awaiting
275 275 review.
276 276
277 277 :param repo_name: target or source repo
278 278 :param source: boolean flag to specify if repo_name refers to source
279 279 :param statuses: list of pull request statuses
280 280 :param opened_by: list of user ids of pull request authors
281 281 :param offset: pagination offset
282 282 :param length: length of returned list
283 283 :param order_by: order of the returned list
284 284 :param order_dir: 'asc' or 'desc' ordering direction
285 285 :returns: list of pull requests
286 286 """
287 287 pull_requests = self.get_all(
288 288 repo_name, source=source, statuses=statuses, opened_by=opened_by,
289 289 order_by=order_by, order_dir=order_dir)
290 290
291 291 _filtered_pull_requests = []
292 292 for pr in pull_requests:
293 293 status = pr.calculated_review_status()
294 294 if status in [ChangesetStatus.STATUS_NOT_REVIEWED,
295 295 ChangesetStatus.STATUS_UNDER_REVIEW]:
296 296 _filtered_pull_requests.append(pr)
297 297 if length:
298 298 return _filtered_pull_requests[offset:offset+length]
299 299 else:
300 300 return _filtered_pull_requests
301 301
302 302 def count_awaiting_my_review(self, repo_name, source=False, statuses=None,
303 303 opened_by=None, user_id=None):
304 304 """
305 305 Count the number of pull requests for a specific repository that are
306 306 awaiting review from a specific user.
307 307
308 308 :param repo_name: target or source repo
309 309 :param source: boolean flag to specify if repo_name refers to source
310 310 :param statuses: list of pull request statuses
311 311 :param opened_by: list of user ids of pull request authors
312 312 :param user_id: reviewer user of the pull request
313 313 :returns: int number of pull requests
314 314 """
315 315 pull_requests = self.get_awaiting_my_review(
316 316 repo_name, source=source, statuses=statuses, opened_by=opened_by,
317 317 user_id=user_id)
318 318
319 319 return len(pull_requests)
320 320
321 321 def get_awaiting_my_review(self, repo_name, source=False, statuses=None,
322 322 opened_by=None, user_id=None, offset=0,
323 323 length=None, order_by=None, order_dir='desc'):
324 324 """
325 325 Get all pull requests for a specific repository that are awaiting
326 326 review from a specific user.
327 327
328 328 :param repo_name: target or source repo
329 329 :param source: boolean flag to specify if repo_name refers to source
330 330 :param statuses: list of pull request statuses
331 331 :param opened_by: list of user ids of pull request authors
332 332 :param user_id: reviewer user of the pull request
333 333 :param offset: pagination offset
334 334 :param length: length of returned list
335 335 :param order_by: order of the returned list
336 336 :param order_dir: 'asc' or 'desc' ordering direction
337 337 :returns: list of pull requests
338 338 """
339 339 pull_requests = self.get_all(
340 340 repo_name, source=source, statuses=statuses, opened_by=opened_by,
341 341 order_by=order_by, order_dir=order_dir)
342 342
343 343 _my = PullRequestModel().get_not_reviewed(user_id)
344 344 my_participation = []
345 345 for pr in pull_requests:
346 346 if pr in _my:
347 347 my_participation.append(pr)
348 348 _filtered_pull_requests = my_participation
349 349 if length:
350 350 return _filtered_pull_requests[offset:offset+length]
351 351 else:
352 352 return _filtered_pull_requests
353 353
354 354 def get_not_reviewed(self, user_id):
355 355 return [
356 356 x.pull_request for x in PullRequestReviewers.query().filter(
357 357 PullRequestReviewers.user_id == user_id).all()
358 358 ]
359 359
360 360 def _prepare_participating_query(self, user_id=None, statuses=None,
361 361 order_by=None, order_dir='desc'):
362 362 q = PullRequest.query()
363 363 if user_id:
364 364 reviewers_subquery = Session().query(
365 365 PullRequestReviewers.pull_request_id).filter(
366 366 PullRequestReviewers.user_id == user_id).subquery()
367 367 user_filter = or_(
368 368 PullRequest.user_id == user_id,
369 369 PullRequest.pull_request_id.in_(reviewers_subquery)
370 370 )
371 371 q = PullRequest.query().filter(user_filter)
372 372
373 373 # closed, opened
374 374 if statuses:
375 375 q = q.filter(PullRequest.status.in_(statuses))
376 376
377 377 if order_by:
378 378 order_map = {
379 379 'name_raw': PullRequest.pull_request_id,
380 380 'title': PullRequest.title,
381 381 'updated_on_raw': PullRequest.updated_on,
382 382 'target_repo': PullRequest.target_repo_id
383 383 }
384 384 if order_dir == 'asc':
385 385 q = q.order_by(order_map[order_by].asc())
386 386 else:
387 387 q = q.order_by(order_map[order_by].desc())
388 388
389 389 return q
390 390
391 391 def count_im_participating_in(self, user_id=None, statuses=None):
392 392 q = self._prepare_participating_query(user_id, statuses=statuses)
393 393 return q.count()
394 394
395 395 def get_im_participating_in(
396 396 self, user_id=None, statuses=None, offset=0,
397 397 length=None, order_by=None, order_dir='desc'):
398 398 """
399 399 Get all pull requests that I'm participating in or have opened
400 400 """
401 401
402 402 q = self._prepare_participating_query(
403 403 user_id, statuses=statuses, order_by=order_by,
404 404 order_dir=order_dir)
405 405
406 406 if length:
407 407 pull_requests = q.limit(length).offset(offset).all()
408 408 else:
409 409 pull_requests = q.all()
410 410
411 411 return pull_requests
412 412
413 413 def get_versions(self, pull_request):
414 414 """
415 415 returns versions of the pull request sorted by version ID ascending
416 416 """
417 417 return PullRequestVersion.query()\
418 418 .filter(PullRequestVersion.pull_request == pull_request)\
419 419 .order_by(PullRequestVersion.pull_request_version_id.asc())\
420 420 .all()
421 421
422 422 def get_pr_version(self, pull_request_id, version=None):
423 423 at_version = None
424 424
425 425 if version and version == 'latest':
426 426 pull_request_ver = PullRequest.get(pull_request_id)
427 427 pull_request_obj = pull_request_ver
428 428 _org_pull_request_obj = pull_request_obj
429 429 at_version = 'latest'
430 430 elif version:
431 431 pull_request_ver = PullRequestVersion.get_or_404(version)
432 432 pull_request_obj = pull_request_ver
433 433 _org_pull_request_obj = pull_request_ver.pull_request
434 434 at_version = pull_request_ver.pull_request_version_id
435 435 else:
436 436 _org_pull_request_obj = pull_request_obj = PullRequest.get_or_404(
437 437 pull_request_id)
438 438
439 439 pull_request_display_obj = PullRequest.get_pr_display_object(
440 440 pull_request_obj, _org_pull_request_obj)
441 441
442 442 return _org_pull_request_obj, pull_request_obj, \
443 443 pull_request_display_obj, at_version
444 444
445 445 def create(self, created_by, source_repo, source_ref, target_repo,
446 446 target_ref, revisions, reviewers, title, description=None,
447 447 description_renderer=None,
448 448 reviewer_data=None, translator=None, auth_user=None):
449 449 translator = translator or get_current_request().translate
450 450
451 451 created_by_user = self._get_user(created_by)
452 452 auth_user = auth_user or created_by_user.AuthUser()
453 453 source_repo = self._get_repo(source_repo)
454 454 target_repo = self._get_repo(target_repo)
455 455
456 456 pull_request = PullRequest()
457 457 pull_request.source_repo = source_repo
458 458 pull_request.source_ref = source_ref
459 459 pull_request.target_repo = target_repo
460 460 pull_request.target_ref = target_ref
461 461 pull_request.revisions = revisions
462 462 pull_request.title = title
463 463 pull_request.description = description
464 464 pull_request.description_renderer = description_renderer
465 465 pull_request.author = created_by_user
466 466 pull_request.reviewer_data = reviewer_data
467 467
468 468 Session().add(pull_request)
469 469 Session().flush()
470 470
471 471 reviewer_ids = set()
472 472 # members / reviewers
473 473 for reviewer_object in reviewers:
474 474 user_id, reasons, mandatory, rules = reviewer_object
475 475 user = self._get_user(user_id)
476 476
477 477 # skip duplicates
478 478 if user.user_id in reviewer_ids:
479 479 continue
480 480
481 481 reviewer_ids.add(user.user_id)
482 482
483 483 reviewer = PullRequestReviewers()
484 484 reviewer.user = user
485 485 reviewer.pull_request = pull_request
486 486 reviewer.reasons = reasons
487 487 reviewer.mandatory = mandatory
488 488
489 489 # NOTE(marcink): pick only first rule for now
490 490 rule_id = list(rules)[0] if rules else None
491 491 rule = RepoReviewRule.get(rule_id) if rule_id else None
492 492 if rule:
493 493 review_group = rule.user_group_vote_rule(user_id)
494 494 # we check if this particular reviewer is a member of a voting group
495 495 if review_group:
496 496 # NOTE(marcink):
497 497 # the user can be a member of more than one group, but we pick
498 498 # the first one, same as the default reviewers algorithm does
499 499 review_group = review_group[0]
500 500
501 501 rule_data = {
502 502 'rule_name':
503 503 rule.review_rule_name,
504 504 'rule_user_group_entry_id':
505 505 review_group.repo_review_rule_users_group_id,
506 506 'rule_user_group_name':
507 507 review_group.users_group.users_group_name,
508 508 'rule_user_group_members':
509 509 [x.user.username for x in review_group.users_group.members],
510 510 'rule_user_group_members_id':
511 511 [x.user.user_id for x in review_group.users_group.members],
512 512 }
513 513 # e.g. {'vote_rule': -1, 'mandatory': True}
514 514 rule_data.update(review_group.rule_data())
515 515
516 516 reviewer.rule_data = rule_data
517 517
518 518 Session().add(reviewer)
519 519 Session().flush()
520 520
521 521 # Set approval status to "Under Review" for all commits which are
522 522 # part of this pull request.
523 523 ChangesetStatusModel().set_status(
524 524 repo=target_repo,
525 525 status=ChangesetStatus.STATUS_UNDER_REVIEW,
526 526 user=created_by_user,
527 527 pull_request=pull_request
528 528 )
529 529 # we commit early at this point. The queries above take row locks,
530 530 # so we need to commit and finish the transaction before the
531 531 # validate call below, which for large repos can run long and
532 532 # would otherwise hold those row locks the whole time
533 533 Session().commit()
534 534
535 535 # prepare workspace, and run initial merge simulation
536 536 MergeCheck.validate(
537 537 pull_request, auth_user=auth_user, translator=translator)
538 538
539 539 self.notify_reviewers(pull_request, reviewer_ids)
540 540 self._trigger_pull_request_hook(
541 541 pull_request, created_by_user, 'create')
542 542
543 543 creation_data = pull_request.get_api_data(with_merge_state=False)
544 544 self._log_audit_action(
545 545 'repo.pull_request.create', {'data': creation_data},
546 546 auth_user, pull_request)
547 547
548 548 return pull_request
549 549
550 550 def _trigger_pull_request_hook(self, pull_request, user, action):
551 551 pull_request = self.__get_pull_request(pull_request)
552 552 target_scm = pull_request.target_repo.scm_instance()
553 553 if action == 'create':
554 554 trigger_hook = hooks_utils.trigger_log_create_pull_request_hook
555 555 elif action == 'merge':
556 556 trigger_hook = hooks_utils.trigger_log_merge_pull_request_hook
557 557 elif action == 'close':
558 558 trigger_hook = hooks_utils.trigger_log_close_pull_request_hook
559 559 elif action == 'review_status_change':
560 560 trigger_hook = hooks_utils.trigger_log_review_pull_request_hook
561 561 elif action == 'update':
562 562 trigger_hook = hooks_utils.trigger_log_update_pull_request_hook
563 563 else:
564 564 return
565 565
566 566 trigger_hook(
567 567 username=user.username,
568 568 repo_name=pull_request.target_repo.repo_name,
569 569 repo_alias=target_scm.alias,
570 570 pull_request=pull_request)
571 571
572 572 def _get_commit_ids(self, pull_request):
573 573 """
574 574 Return the commit ids of the merged pull request.
575 575
576 576 This method does not yet deal correctly with the lack of autoupdates
577 577 or with implicit target updates.
578 578 For example: if a commit in the source repo is already in the target,
579 579 it will still be reported.
580 580 """
581 581 merge_rev = pull_request.merge_rev
582 582 if merge_rev is None:
583 583 raise ValueError('This pull request was not merged yet')
584 584
585 585 commit_ids = list(pull_request.revisions)
586 586 if merge_rev not in commit_ids:
587 587 commit_ids.append(merge_rev)
588 588
589 589 return commit_ids
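# Illustrative example with hypothetical ids: for revisions ['b', 'a'] and
# a merge_rev 'm' not among them, this returns ['b', 'a', 'm'].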
590 590
591 591 def merge_repo(self, pull_request, user, extras):
592 592 log.debug("Merging pull request %s", pull_request.pull_request_id)
593 593 extras['user_agent'] = 'internal-merge'
594 594 merge_state = self._merge_pull_request(pull_request, user, extras)
595 595 if merge_state.executed:
596 596 log.debug(
597 597 "Merge was successful, updating the pull request comments.")
598 598 self._comment_and_close_pr(pull_request, user, merge_state)
599 599
600 600 self._log_audit_action(
601 601 'repo.pull_request.merge',
602 602 {'merge_state': merge_state.__dict__},
603 603 user, pull_request)
604 604
605 605 else:
606 606 log.warn("Merge failed, not updating the pull request.")
607 607 return merge_state
608 608
609 609 def _merge_pull_request(self, pull_request, user, extras, merge_msg=None):
610 610 target_vcs = pull_request.target_repo.scm_instance()
611 611 source_vcs = pull_request.source_repo.scm_instance()
612 612
613 613 message = safe_unicode(merge_msg or vcs_settings.MERGE_MESSAGE_TMPL).format(
614 614 pr_id=pull_request.pull_request_id,
615 615 pr_title=pull_request.title,
616 616 source_repo=source_vcs.name,
617 617 source_ref_name=pull_request.source_ref_parts.name,
618 618 target_repo=target_vcs.name,
619 619 target_ref_name=pull_request.target_ref_parts.name,
620 620 )
621 621
622 622 workspace_id = self._workspace_id(pull_request)
623 623 repo_id = pull_request.target_repo.repo_id
624 624 use_rebase = self._use_rebase_for_merging(pull_request)
625 625 close_branch = self._close_branch_before_merging(pull_request)
626 626
627 627 target_ref = self._refresh_reference(
628 628 pull_request.target_ref_parts, target_vcs)
629 629
630 630 callback_daemon, extras = prepare_callback_daemon(
631 631 extras, protocol=vcs_settings.HOOKS_PROTOCOL,
632 632 host=vcs_settings.HOOKS_HOST,
633 633 use_direct_calls=vcs_settings.HOOKS_DIRECT_CALLS)
634 634
635 635 with callback_daemon:
636 636 # TODO: johbo: Implement a clean way to run a config_override
637 637 # for a single call.
638 638 target_vcs.config.set(
639 639 'rhodecode', 'RC_SCM_DATA', json.dumps(extras))
640 640
641 641 user_name = user.short_contact
642 642 merge_state = target_vcs.merge(
643 643 repo_id, workspace_id, target_ref, source_vcs,
644 644 pull_request.source_ref_parts,
645 645 user_name=user_name, user_email=user.email,
646 646 message=message, use_rebase=use_rebase,
647 647 close_branch=close_branch)
648 648 return merge_state
649 649
650 650 def _comment_and_close_pr(self, pull_request, user, merge_state, close_msg=None):
651 651 pull_request.merge_rev = merge_state.merge_ref.commit_id
652 652 pull_request.updated_on = datetime.datetime.now()
653 653 close_msg = close_msg or 'Pull request merged and closed'
654 654
655 655 CommentsModel().create(
656 656 text=safe_unicode(close_msg),
657 657 repo=pull_request.target_repo.repo_id,
658 658 user=user.user_id,
659 659 pull_request=pull_request.pull_request_id,
660 660 f_path=None,
661 661 line_no=None,
662 662 closing_pr=True
663 663 )
664 664
665 665 Session().add(pull_request)
666 666 Session().flush()
667 667 # TODO: paris: replace invalidation with a less radical solution
668 668 ScmModel().mark_for_invalidation(
669 669 pull_request.target_repo.repo_name)
670 670 self._trigger_pull_request_hook(pull_request, user, 'merge')
671 671
672 672 def has_valid_update_type(self, pull_request):
673 673 source_ref_type = pull_request.source_ref_parts.type
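# 'book' stands for a Mercurial bookmark reference type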
674 674 return source_ref_type in ['book', 'branch', 'tag']
675 675
676 676 def update_commits(self, pull_request):
677 677 """
678 678 Get the updated list of commits for the pull request
679 679 and return the new pull request version and the list
680 680 of commits processed by this update action
681 681 """
682 682 pull_request = self.__get_pull_request(pull_request)
683 683 source_ref_type = pull_request.source_ref_parts.type
684 684 source_ref_name = pull_request.source_ref_parts.name
685 685 source_ref_id = pull_request.source_ref_parts.commit_id
686 686
687 687 target_ref_type = pull_request.target_ref_parts.type
688 688 target_ref_name = pull_request.target_ref_parts.name
689 689 target_ref_id = pull_request.target_ref_parts.commit_id
690 690
691 691 if not self.has_valid_update_type(pull_request):
692 692 log.debug(
693 693 "Skipping update of pull request %s due to ref type: %s",
694 694 pull_request, source_ref_type)
695 695 return UpdateResponse(
696 696 executed=False,
697 697 reason=UpdateFailureReason.WRONG_REF_TYPE,
698 698 old=pull_request, new=None, changes=None,
699 699 source_changed=False, target_changed=False)
700 700
701 701 # source repo
702 702 source_repo = pull_request.source_repo.scm_instance()
703 703 try:
704 704 source_commit = source_repo.get_commit(commit_id=source_ref_name)
705 705 except CommitDoesNotExistError:
706 706 return UpdateResponse(
707 707 executed=False,
708 708 reason=UpdateFailureReason.MISSING_SOURCE_REF,
709 709 old=pull_request, new=None, changes=None,
710 710 source_changed=False, target_changed=False)
711 711
712 712 source_changed = source_ref_id != source_commit.raw_id
713 713
714 714 # target repo
715 715 target_repo = pull_request.target_repo.scm_instance()
716 716 try:
717 717 target_commit = target_repo.get_commit(commit_id=target_ref_name)
718 718 except CommitDoesNotExistError:
719 719 return UpdateResponse(
720 720 executed=False,
721 721 reason=UpdateFailureReason.MISSING_TARGET_REF,
722 722 old=pull_request, new=None, changes=None,
723 723 source_changed=False, target_changed=False)
724 724 target_changed = target_ref_id != target_commit.raw_id
725 725
726 726 if not (source_changed or target_changed):
727 727 log.debug("Nothing changed in pull request %s", pull_request)
728 728 return UpdateResponse(
729 729 executed=False,
730 730 reason=UpdateFailureReason.NO_CHANGE,
731 731 old=pull_request, new=None, changes=None,
732 732 source_changed=source_changed, target_changed=target_changed)
733 733
734 734 change_in_found = 'target repo' if target_changed else 'source repo'
735 735 log.debug('Updating pull request because of change in %s detected',
736 736 change_in_found)
737 737
738 738 # Finally, an update is needed; in case of a source change we
739 739 # create a new version, otherwise we just update in place
740 740 if source_changed:
741 741 pull_request_version = self._create_version_from_snapshot(pull_request)
742 742 self._link_comments_to_version(pull_request_version)
743 743 else:
744 744 try:
745 745 ver = pull_request.versions[-1]
746 746 except IndexError:
747 747 ver = None
748 748
749 749 pull_request.pull_request_version_id = \
750 750 ver.pull_request_version_id if ver else None
751 751 pull_request_version = pull_request
752 752
753 753 try:
754 754 if target_ref_type in ('tag', 'branch', 'book'):
755 755 target_commit = target_repo.get_commit(target_ref_name)
756 756 else:
757 757 target_commit = target_repo.get_commit(target_ref_id)
758 758 except CommitDoesNotExistError:
759 759 return UpdateResponse(
760 760 executed=False,
761 761 reason=UpdateFailureReason.MISSING_TARGET_REF,
762 762 old=pull_request, new=None, changes=None,
763 763 source_changed=source_changed, target_changed=target_changed)
764 764
765 765 # re-compute commit ids
766 766 old_commit_ids = pull_request.revisions
767 767 pre_load = ["author", "branch", "date", "message"]
768 768 commit_ranges = target_repo.compare(
769 769 target_commit.raw_id, source_commit.raw_id, source_repo, merge=True,
770 770 pre_load=pre_load)
771 771
772 772 ancestor = target_repo.get_common_ancestor(
773 773 target_commit.raw_id, source_commit.raw_id, source_repo)
774 774
775 775 pull_request.source_ref = '%s:%s:%s' % (
776 776 source_ref_type, source_ref_name, source_commit.raw_id)
777 777 pull_request.target_ref = '%s:%s:%s' % (
778 778 target_ref_type, target_ref_name, ancestor)
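# the refs are stored as 'type:name:commit_id' strings, e.g. a
# hypothetical 'branch:feature-1:abcd1234'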
779 779
780 780 pull_request.revisions = [
781 781 commit.raw_id for commit in reversed(commit_ranges)]
782 782 pull_request.updated_on = datetime.datetime.now()
783 783 Session().add(pull_request)
784 784 new_commit_ids = pull_request.revisions
785 785
786 786 old_diff_data, new_diff_data = self._generate_update_diffs(
787 787 pull_request, pull_request_version)
788 788
789 789 # calculate commit and file changes
790 790 changes = self._calculate_commit_id_changes(
791 791 old_commit_ids, new_commit_ids)
792 792 file_changes = self._calculate_file_changes(
793 793 old_diff_data, new_diff_data)
794 794
795 795 # set comments as outdated if DIFFS changed
796 796 CommentsModel().outdate_comments(
797 797 pull_request, old_diff_data=old_diff_data,
798 798 new_diff_data=new_diff_data)
799 799
800 800 commit_changes = (changes.added or changes.removed)
801 801 file_node_changes = (
802 802 file_changes.added or file_changes.modified or file_changes.removed)
803 803 pr_has_changes = commit_changes or file_node_changes
804 804
805 805 # Add an automatic comment to the pull request, in case
806 806 # anything has changed
807 807 if pr_has_changes:
808 808 update_comment = CommentsModel().create(
809 809 text=self._render_update_message(changes, file_changes),
810 810 repo=pull_request.target_repo,
811 811 user=pull_request.author,
812 812 pull_request=pull_request,
813 813 send_email=False, renderer=DEFAULT_COMMENTS_RENDERER)
814 814
815 815 # Update status to "Under Review" for added commits
816 816 for commit_id in changes.added:
817 817 ChangesetStatusModel().set_status(
818 818 repo=pull_request.source_repo,
819 819 status=ChangesetStatus.STATUS_UNDER_REVIEW,
820 820 comment=update_comment,
821 821 user=pull_request.author,
822 822 pull_request=pull_request,
823 823 revision=commit_id)
824 824
825 825 log.debug(
826 826 'Updated pull request %s, added_ids: %s, common_ids: %s, '
827 827 'removed_ids: %s', pull_request.pull_request_id,
828 828 changes.added, changes.common, changes.removed)
829 829 log.debug(
830 830 'Updated pull request with the following file changes: %s',
831 831 file_changes)
832 832
833 833 log.info(
834 834 "Updated pull request %s from commit %s to commit %s, "
835 835 "stored new version %s of this pull request.",
836 836 pull_request.pull_request_id, source_ref_id,
837 837 pull_request.source_ref_parts.commit_id,
838 838 pull_request_version.pull_request_version_id)
839 839 Session().commit()
840 840 self._trigger_pull_request_hook(
841 841 pull_request, pull_request.author, 'update')
842 842
843 843 return UpdateResponse(
844 844 executed=True, reason=UpdateFailureReason.NONE,
845 845 old=pull_request, new=pull_request_version, changes=changes,
846 846 source_changed=source_changed, target_changed=target_changed)
847 847
848 848 def _create_version_from_snapshot(self, pull_request):
849 849 version = PullRequestVersion()
850 850 version.title = pull_request.title
851 851 version.description = pull_request.description
852 852 version.status = pull_request.status
853 853 version.created_on = datetime.datetime.now()
854 854 version.updated_on = pull_request.updated_on
855 855 version.user_id = pull_request.user_id
856 856 version.source_repo = pull_request.source_repo
857 857 version.source_ref = pull_request.source_ref
858 858 version.target_repo = pull_request.target_repo
859 859 version.target_ref = pull_request.target_ref
860 860
861 861 version._last_merge_source_rev = pull_request._last_merge_source_rev
862 862 version._last_merge_target_rev = pull_request._last_merge_target_rev
863 863 version.last_merge_status = pull_request.last_merge_status
864 864 version.shadow_merge_ref = pull_request.shadow_merge_ref
865 865 version.merge_rev = pull_request.merge_rev
866 866 version.reviewer_data = pull_request.reviewer_data
867 867
868 868 version.revisions = pull_request.revisions
869 869 version.pull_request = pull_request
870 870 Session().add(version)
871 871 Session().flush()
872 872
873 873 return version
874 874
875 875 def _generate_update_diffs(self, pull_request, pull_request_version):
876 876
877 877 diff_context = (
878 878 self.DIFF_CONTEXT +
879 879 CommentsModel.needed_extra_diff_context())
880 880 hide_whitespace_changes = False
881 881 source_repo = pull_request_version.source_repo
882 882 source_ref_id = pull_request_version.source_ref_parts.commit_id
883 883 target_ref_id = pull_request_version.target_ref_parts.commit_id
884 884 old_diff = self._get_diff_from_pr_or_version(
885 885 source_repo, source_ref_id, target_ref_id,
886 886 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
887 887
888 888 source_repo = pull_request.source_repo
889 889 source_ref_id = pull_request.source_ref_parts.commit_id
890 890 target_ref_id = pull_request.target_ref_parts.commit_id
891 891
892 892 new_diff = self._get_diff_from_pr_or_version(
893 893 source_repo, source_ref_id, target_ref_id,
894 894 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
895 895
896 896 old_diff_data = diffs.DiffProcessor(old_diff)
897 897 old_diff_data.prepare()
898 898 new_diff_data = diffs.DiffProcessor(new_diff)
899 899 new_diff_data.prepare()
900 900
901 901 return old_diff_data, new_diff_data
902 902
903 903 def _link_comments_to_version(self, pull_request_version):
904 904 """
905 905 Link all unlinked comments of this pull request to the given version.
906 906
907 907 :param pull_request_version: The `PullRequestVersion` to which
908 908 the comments shall be linked.
909 909
910 910 """
911 911 pull_request = pull_request_version.pull_request
912 912 comments = ChangesetComment.query()\
913 913 .filter(
914 914 # TODO: johbo: Should we query for the repo at all here?
915 915 # Pending decision on how comments of PRs are to be related
916 916 # to either the source repo, the target repo or no repo at all.
917 917 ChangesetComment.repo_id == pull_request.target_repo.repo_id,
918 918 ChangesetComment.pull_request == pull_request,
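# NOTE: '== None' is intentional here; SQLAlchemy compiles it to
# an IS NULL check, which a pythonic 'is None' would not produce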
919 919 ChangesetComment.pull_request_version == None)\
920 920 .order_by(ChangesetComment.comment_id.asc())
921 921
922 922 # TODO: johbo: Find out why this breaks if it is done in a bulk
923 923 # operation.
924 924 for comment in comments:
925 925 comment.pull_request_version_id = (
926 926 pull_request_version.pull_request_version_id)
927 927 Session().add(comment)
928 928
929 929 def _calculate_commit_id_changes(self, old_ids, new_ids):
930 930 added = [x for x in new_ids if x not in old_ids]
931 931 common = [x for x in new_ids if x in old_ids]
932 932 removed = [x for x in old_ids if x not in new_ids]
933 933 total = new_ids
934 934 return ChangeTuple(added, common, removed, total)
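# Illustrative example with hypothetical ids:
#   _calculate_commit_id_changes(old_ids=['a', 'b'], new_ids=['b', 'c'])
#   -> ChangeTuple(added=['c'], common=['b'], removed=['a'], total=['b', 'c'])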
935 935
936 936 def _calculate_file_changes(self, old_diff_data, new_diff_data):
937 937
938 938 old_files = OrderedDict()
939 939 for diff_data in old_diff_data.parsed_diff:
940 940 old_files[diff_data['filename']] = md5_safe(diff_data['raw_diff'])
941 941
942 942 added_files = []
943 943 modified_files = []
944 944 removed_files = []
945 945 for diff_data in new_diff_data.parsed_diff:
946 946 new_filename = diff_data['filename']
947 947 new_hash = md5_safe(diff_data['raw_diff'])
948 948
949 949 old_hash = old_files.get(new_filename)
950 950 if not old_hash:
951 951 # file is not present in old diff, means it's added
952 952 added_files.append(new_filename)
953 953 else:
954 954 if new_hash != old_hash:
955 955 modified_files.append(new_filename)
956 956 # now remove a file from old, since we have seen it already
957 957 del old_files[new_filename]
958 958
959 959 # removed files are those present in old but not in NEW; since we
960 960 # removed the old files that are present in the new diff, any
961 961 # left-overs are the removed files
962 962 removed_files.extend(old_files.keys())
963 963
964 964 return FileChangeTuple(added_files, modified_files, removed_files)
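# Illustrative example with hypothetical filenames: if the old diff
# contains files {'a': h1, 'b': h2} and the new diff {'b': h2x, 'c': h3},
# this returns FileChangeTuple(added=['c'], modified=['b'], removed=['a']).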
965 965
966 966 def _render_update_message(self, changes, file_changes):
967 967 """
968 968 render the message using DEFAULT_COMMENTS_RENDERER (RST renderer),
969 969 so it always looks the same regardless of which default
970 970 renderer the system is using.
971 971
972 972 :param changes: changes named tuple
973 973 :param file_changes: file changes named tuple
974 974
975 975 """
976 976 new_status = ChangesetStatus.get_status_lbl(
977 977 ChangesetStatus.STATUS_UNDER_REVIEW)
978 978
979 979 changed_files = (
980 980 file_changes.added + file_changes.modified + file_changes.removed)
981 981
982 982 params = {
983 983 'under_review_label': new_status,
984 984 'added_commits': changes.added,
985 985 'removed_commits': changes.removed,
986 986 'changed_files': changed_files,
987 987 'added_files': file_changes.added,
988 988 'modified_files': file_changes.modified,
989 989 'removed_files': file_changes.removed,
990 990 }
991 991 renderer = RstTemplateRenderer()
992 992 return renderer.render('pull_request_update.mako', **params)
993 993
994 994 def edit(self, pull_request, title, description, description_renderer, user):
995 995 pull_request = self.__get_pull_request(pull_request)
996 996 old_data = pull_request.get_api_data(with_merge_state=False)
997 997 if pull_request.is_closed():
998 998 raise ValueError('This pull request is closed')
999 999 if title:
1000 1000 pull_request.title = title
1001 1001 pull_request.description = description
1002 1002 pull_request.updated_on = datetime.datetime.now()
1003 1003 pull_request.description_renderer = description_renderer
1004 1004 Session().add(pull_request)
1005 1005 self._log_audit_action(
1006 1006 'repo.pull_request.edit', {'old_data': old_data},
1007 1007 user, pull_request)
1008 1008
1009 1009 def update_reviewers(self, pull_request, reviewer_data, user):
1010 1010 """
1011 1011 Update the reviewers in the pull request
1012 1012
1013 1013 :param pull_request: the pr to update
1014 1014 :param reviewer_data: list of tuples
1015 1015 [(user, ['reason1', 'reason2'], mandatory_flag, [rules])]
1016 1016 """
1017 1017 pull_request = self.__get_pull_request(pull_request)
1018 1018 if pull_request.is_closed():
1019 1019 raise ValueError('This pull request is closed')
1020 1020
1021 1021 reviewers = {}
1022 1022 for user_id, reasons, mandatory, rules in reviewer_data:
1023 1023 if isinstance(user_id, (int, basestring)):
1024 1024 user_id = self._get_user(user_id).user_id
1025 1025 reviewers[user_id] = {
1026 1026 'reasons': reasons, 'mandatory': mandatory}
1027 1027
1028 1028 reviewers_ids = set(reviewers.keys())
1029 1029 current_reviewers = PullRequestReviewers.query()\
1030 1030 .filter(PullRequestReviewers.pull_request ==
1031 1031 pull_request).all()
1032 1032 current_reviewers_ids = set([x.user.user_id for x in current_reviewers])
1033 1033
1034 1034 ids_to_add = reviewers_ids.difference(current_reviewers_ids)
1035 1035 ids_to_remove = current_reviewers_ids.difference(reviewers_ids)
1036 1036
1037 1037 log.debug("Adding %s reviewers", ids_to_add)
1038 1038 log.debug("Removing %s reviewers", ids_to_remove)
1039 1039 changed = False
1040 1040 for uid in ids_to_add:
1041 1041 changed = True
1042 1042 _usr = self._get_user(uid)
1043 1043 reviewer = PullRequestReviewers()
1044 1044 reviewer.user = _usr
1045 1045 reviewer.pull_request = pull_request
1046 1046 reviewer.reasons = reviewers[uid]['reasons']
1047 1047 # NOTE(marcink): mandatory shouldn't be changed now
1048 1048 # reviewer.mandatory = reviewers[uid]['mandatory']
1049 1049 Session().add(reviewer)
1050 1050 self._log_audit_action(
1051 1051 'repo.pull_request.reviewer.add', {'data': reviewer.get_dict()},
1052 1052 user, pull_request)
1053 1053
1054 1054 for uid in ids_to_remove:
1055 1055 changed = True
1056 1056 reviewers = PullRequestReviewers.query()\
1057 1057 .filter(PullRequestReviewers.user_id == uid,
1058 1058 PullRequestReviewers.pull_request == pull_request)\
1059 1059 .all()
1060 1060 # use .all() in case we accidentally added the same person twice
1061 1061 # this CAN happen due to the lack of DB checks
1062 1062 for obj in reviewers:
1063 1063 old_data = obj.get_dict()
1064 1064 Session().delete(obj)
1065 1065 self._log_audit_action(
1066 1066 'repo.pull_request.reviewer.delete',
1067 1067 {'old_data': old_data}, user, pull_request)
1068 1068
1069 1069 if changed:
1070 1070 pull_request.updated_on = datetime.datetime.now()
1071 1071 Session().add(pull_request)
1072 1072
1073 1073 self.notify_reviewers(pull_request, ids_to_add)
1074 1074 return ids_to_add, ids_to_remove
1075 1075
1076 1076 def get_url(self, pull_request, request=None, permalink=False):
1077 1077 if not request:
1078 1078 request = get_current_request()
1079 1079
1080 1080 if permalink:
1081 1081 return request.route_url(
1082 1082 'pull_requests_global',
1083 1083 pull_request_id=pull_request.pull_request_id,)
1084 1084 else:
1085 1085 return request.route_url('pullrequest_show',
1086 1086 repo_name=safe_str(pull_request.target_repo.repo_name),
1087 1087 pull_request_id=pull_request.pull_request_id,)
1088 1088
1089 1089 def get_shadow_clone_url(self, pull_request, request=None):
1090 1090 """
1091 1091 Returns qualified url pointing to the shadow repository. If this pull
1092 1092 request is closed there is no shadow repository and ``None`` will be
1093 1093 returned.
1094 1094 """
1095 1095 if pull_request.is_closed():
1096 1096 return None
1097 1097 else:
1098 1098 pr_url = urllib.unquote(self.get_url(pull_request, request=request))
1099 1099 return safe_unicode('{pr_url}/repository'.format(pr_url=pr_url))
1100 1100
1101 1101 def notify_reviewers(self, pull_request, reviewers_ids):
1102 1102 # notification to reviewers
1103 1103 if not reviewers_ids:
1104 1104 return
1105 1105
1106 1106 pull_request_obj = pull_request
1107 1107 # get the current participants of this pull request
1108 1108 recipients = reviewers_ids
1109 1109 notification_type = EmailNotificationModel.TYPE_PULL_REQUEST
1110 1110
1111 1111 pr_source_repo = pull_request_obj.source_repo
1112 1112 pr_target_repo = pull_request_obj.target_repo
1113 1113
1114 1114 pr_url = h.route_url('pullrequest_show',
1115 1115 repo_name=pr_target_repo.repo_name,
1116 1116 pull_request_id=pull_request_obj.pull_request_id,)
1117 1117
1118 1118 # set some variables for email notification
1119 1119 pr_target_repo_url = h.route_url(
1120 1120 'repo_summary', repo_name=pr_target_repo.repo_name)
1121 1121
1122 1122 pr_source_repo_url = h.route_url(
1123 1123 'repo_summary', repo_name=pr_source_repo.repo_name)
1124 1124
1125 1125 # pull request specifics
1126 1126 pull_request_commits = [
1127 1127 (x.raw_id, x.message)
1128 1128 for x in map(pr_source_repo.get_commit, pull_request.revisions)]
1129 1129
1130 1130 kwargs = {
1131 1131 'user': pull_request.author,
1132 1132 'pull_request': pull_request_obj,
1133 1133 'pull_request_commits': pull_request_commits,
1134 1134
1135 1135 'pull_request_target_repo': pr_target_repo,
1136 1136 'pull_request_target_repo_url': pr_target_repo_url,
1137 1137
1138 1138 'pull_request_source_repo': pr_source_repo,
1139 1139 'pull_request_source_repo_url': pr_source_repo_url,
1140 1140
1141 1141 'pull_request_url': pr_url,
1142 1142 }
1143 1143
1144 1144 # pre-generate the subject for notification itself
1145 1145 (subject,
1146 1146 _h, _e, # we don't care about those
1147 1147 body_plaintext) = EmailNotificationModel().render_email(
1148 1148 notification_type, **kwargs)
1149 1149
1150 1150 # create notification objects, and emails
1151 1151 NotificationModel().create(
1152 1152 created_by=pull_request.author,
1153 1153 notification_subject=subject,
1154 1154 notification_body=body_plaintext,
1155 1155 notification_type=notification_type,
1156 1156 recipients=recipients,
1157 1157 email_kwargs=kwargs,
1158 1158 )
1159 1159
1160 1160 def delete(self, pull_request, user):
1161 1161 pull_request = self.__get_pull_request(pull_request)
1162 1162 old_data = pull_request.get_api_data(with_merge_state=False)
1163 1163 self._cleanup_merge_workspace(pull_request)
1164 1164 self._log_audit_action(
1165 1165 'repo.pull_request.delete', {'old_data': old_data},
1166 1166 user, pull_request)
1167 1167 Session().delete(pull_request)
1168 1168
1169 1169 def close_pull_request(self, pull_request, user):
1170 1170 pull_request = self.__get_pull_request(pull_request)
1171 1171 self._cleanup_merge_workspace(pull_request)
1172 1172 pull_request.status = PullRequest.STATUS_CLOSED
1173 1173 pull_request.updated_on = datetime.datetime.now()
1174 1174 Session().add(pull_request)
1175 1175 self._trigger_pull_request_hook(
1176 1176 pull_request, pull_request.author, 'close')
1177 1177
1178 1178 pr_data = pull_request.get_api_data(with_merge_state=False)
1179 1179 self._log_audit_action(
1180 1180 'repo.pull_request.close', {'data': pr_data}, user, pull_request)
1181 1181
1182 1182 def close_pull_request_with_comment(
1183 1183 self, pull_request, user, repo, message=None, auth_user=None):
1184 1184
1185 1185 pull_request_review_status = pull_request.calculated_review_status()
1186 1186
1187 1187 if pull_request_review_status == ChangesetStatus.STATUS_APPROVED:
1188 1188 # approved only if we have voting consent
1189 1189 status = ChangesetStatus.STATUS_APPROVED
1190 1190 else:
1191 1191 status = ChangesetStatus.STATUS_REJECTED
1192 1192 status_lbl = ChangesetStatus.get_status_lbl(status)
1193 1193
1194 1194 default_message = (
1195 1195 'Closing with status change {transition_icon} {status}.'
1196 1196 ).format(transition_icon='>', status=status_lbl)
1197 1197 text = message or default_message
1198 1198
1199 1199 # create a comment, and link it to new status
1200 1200 comment = CommentsModel().create(
1201 1201 text=text,
1202 1202 repo=repo.repo_id,
1203 1203 user=user.user_id,
1204 1204 pull_request=pull_request.pull_request_id,
1205 1205 status_change=status_lbl,
1206 1206 status_change_type=status,
1207 1207 closing_pr=True,
1208 1208 auth_user=auth_user,
1209 1209 )
1210 1210
1211 1211 # calculate old status before we change it
1212 1212 old_calculated_status = pull_request.calculated_review_status()
1213 1213 ChangesetStatusModel().set_status(
1214 1214 repo.repo_id,
1215 1215 status,
1216 1216 user.user_id,
1217 1217 comment=comment,
1218 1218 pull_request=pull_request.pull_request_id
1219 1219 )
1220 1220
1221 1221 Session().flush()
1222 1222 events.trigger(events.PullRequestCommentEvent(pull_request, comment))
1223 1223 # we now calculate the status of the pull request again, and based
1224 1224 # on that calculation trigger a status change. This matters in cases
1225 1225 # where a non-reviewer admin closes a PR: their vote doesn't change
1226 1226 # the status, while a reviewer's vote might.
1227 1227 calculated_status = pull_request.calculated_review_status()
1228 1228 if old_calculated_status != calculated_status:
1229 1229 self._trigger_pull_request_hook(
1230 1230 pull_request, user, 'review_status_change')
1231 1231
1232 1232 # finally close the PR
1233 1233 PullRequestModel().close_pull_request(
1234 1234 pull_request.pull_request_id, user)
1235 1235
1236 1236 return comment, status
1237 1237
1238 1238 def merge_status(self, pull_request, translator=None,
1239 1239 force_shadow_repo_refresh=False):
1240 1240 _ = translator or get_current_request().translate
1241 1241
1242 1242 if not self._is_merge_enabled(pull_request):
1243 1243 return False, _('Server-side pull request merging is disabled.')
1244 1244 if pull_request.is_closed():
1245 1245 return False, _('This pull request is closed.')
1246 1246 merge_possible, msg = self._check_repo_requirements(
1247 1247 target=pull_request.target_repo, source=pull_request.source_repo,
1248 1248 translator=_)
1249 1249 if not merge_possible:
1250 1250 return merge_possible, msg
1251 1251
1252 1252 try:
1253 1253 resp = self._try_merge(
1254 1254 pull_request,
1255 1255 force_shadow_repo_refresh=force_shadow_repo_refresh)
1256 1256 log.debug("Merge response: %s", resp)
1257 1257 status = resp.possible, self.merge_status_message(
1258 1258 resp.failure_reason)
1259 1259 except NotImplementedError:
1260 1260 status = False, _('Pull request merging is not supported.')
1261 1261
1262 1262 return status
1263 1263
1264 1264 def _check_repo_requirements(self, target, source, translator):
1265 1265 """
1266 1266 Check if `target` and `source` have compatible requirements.
1267 1267
1268 1268 Currently this is just checking for largefiles.
1269 1269 """
1270 1270 _ = translator
1271 1271 target_has_largefiles = self._has_largefiles(target)
1272 1272 source_has_largefiles = self._has_largefiles(source)
1273 1273 merge_possible = True
1274 1274 message = u''
1275 1275
1276 1276 if target_has_largefiles != source_has_largefiles:
1277 1277 merge_possible = False
1278 1278 if source_has_largefiles:
1279 1279 message = _(
1280 1280 'Target repository large files support is disabled.')
1281 1281 else:
1282 1282 message = _(
1283 1283 'Source repository large files support is disabled.')
1284 1284
1285 1285 return merge_possible, message
1286 1286
1287 1287 def _has_largefiles(self, repo):
1288 1288 largefiles_ui = VcsSettingsModel(repo=repo).get_ui_settings(
1289 1289 'extensions', 'largefiles')
1290 1290 return largefiles_ui and largefiles_ui[0].active
1291 1291
1292 1292 def _try_merge(self, pull_request, force_shadow_repo_refresh=False):
1293 1293 """
1294 1294 Try to merge the pull request and return the merge status.
1295 1295 """
1296 1296 log.debug(
1297 1297 "Checking if pull request %s can be merged. force_refresh=%s",
1298 1298 pull_request.pull_request_id, force_shadow_repo_refresh)
1299 1299 target_vcs = pull_request.target_repo.scm_instance()
1300 1300
1301 1301 # Refresh the target reference.
1302 1302 try:
1303 1303 target_ref = self._refresh_reference(
1304 1304 pull_request.target_ref_parts, target_vcs)
1305 1305 except CommitDoesNotExistError:
1306 1306 merge_state = MergeResponse(
1307 1307 False, False, None, MergeFailureReason.MISSING_TARGET_REF)
1308 1308 return merge_state
1309 1309
1310 1310 target_locked = pull_request.target_repo.locked
1311 1311 if target_locked and target_locked[0]:
1312 1312 log.debug("The target repository is locked.")
1313 1313 merge_state = MergeResponse(
1314 1314 False, False, None, MergeFailureReason.TARGET_IS_LOCKED)
1315 1315 elif force_shadow_repo_refresh or self._needs_merge_state_refresh(
1316 1316 pull_request, target_ref):
1317 1317 log.debug("Refreshing the merge status of the repository.")
1318 1318 merge_state = self._refresh_merge_state(
1319 1319 pull_request, target_vcs, target_ref)
1320 1320 else:
1321 1321 possible = pull_request.\
1322 1322 last_merge_status == MergeFailureReason.NONE
1323 1323 merge_state = MergeResponse(
1324 1324 possible, False, None, pull_request.last_merge_status)
1325 1325
1326 1326 return merge_state
1327 1327
1328 1328 def _refresh_reference(self, reference, vcs_repository):
1329 1329 if reference.type in ('branch', 'book'):
1330 1330 name_or_id = reference.name
1331 1331 else:
1332 1332 name_or_id = reference.commit_id
1333 1333 refreshed_commit = vcs_repository.get_commit(name_or_id)
1334 1334 refreshed_reference = Reference(
1335 1335 reference.type, reference.name, refreshed_commit.raw_id)
1336 1336 return refreshed_reference
1337 1337
1338 1338 def _needs_merge_state_refresh(self, pull_request, target_reference):
1339 1339 return not (
1340 1340 pull_request.revisions and
1341 1341 pull_request.revisions[0] == pull_request._last_merge_source_rev and
1342 1342 target_reference.commit_id == pull_request._last_merge_target_rev)
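# in other words, a refresh is needed when the pull request has no
# revisions, when the source tip no longer matches the last checked
# source rev, or when the target reference has moved since the last check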
1343 1343
1344 1344 def _refresh_merge_state(self, pull_request, target_vcs, target_reference):
1345 1345 workspace_id = self._workspace_id(pull_request)
1346 1346 source_vcs = pull_request.source_repo.scm_instance()
1347 1347 repo_id = pull_request.target_repo.repo_id
1348 1348 use_rebase = self._use_rebase_for_merging(pull_request)
1349 1349 close_branch = self._close_branch_before_merging(pull_request)
1350 1350 merge_state = target_vcs.merge(
1351 1351 repo_id, workspace_id,
1352 1352 target_reference, source_vcs, pull_request.source_ref_parts,
1353 1353 dry_run=True, use_rebase=use_rebase,
1354 1354 close_branch=close_branch)
1355 1355
1356 1356 # Do not store the response if there was an unknown error.
1357 1357 if merge_state.failure_reason != MergeFailureReason.UNKNOWN:
1358 1358 pull_request._last_merge_source_rev = \
1359 1359 pull_request.source_ref_parts.commit_id
1360 1360 pull_request._last_merge_target_rev = target_reference.commit_id
1361 1361 pull_request.last_merge_status = merge_state.failure_reason
1362 1362 pull_request.shadow_merge_ref = merge_state.merge_ref
1363 1363 Session().add(pull_request)
1364 1364 Session().commit()
1365 1365
1366 1366 return merge_state
1367 1367
1368 1368 def _workspace_id(self, pull_request):
1369 1369 workspace_id = 'pr-%s' % pull_request.pull_request_id
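# e.g. a hypothetical pull request with id 42 yields workspace 'pr-42'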
1370 1370 return workspace_id
1371 1371
1372 1372 def merge_status_message(self, status_code):
1373 1373 """
1374 1374 Return a human friendly error message for the given merge status code.
1375 1375 """
1376 1376 return self.MERGE_STATUS_MESSAGES[status_code]
1377 1377
1378 1378 def generate_repo_data(self, repo, commit_id=None, branch=None,
1379 1379 bookmark=None, translator=None):
1380 1380 from rhodecode.model.repo import RepoModel
1381 1381
1382 1382 all_refs, selected_ref = \
1383 1383 self._get_repo_pullrequest_sources(
1384 1384 repo.scm_instance(), commit_id=commit_id,
1385 1385 branch=branch, bookmark=bookmark, translator=translator)
1386 1386
1387 1387 refs_select2 = []
1388 1388 for element in all_refs:
1389 1389 children = [{'id': x[0], 'text': x[1]} for x in element[0]]
1390 1390 refs_select2.append({'text': element[1], 'children': children})
1391 1391
1392 1392 return {
1393 1393 'user': {
1394 1394 'user_id': repo.user.user_id,
1395 1395 'username': repo.user.username,
1396 1396 'firstname': repo.user.first_name,
1397 1397 'lastname': repo.user.last_name,
1398 1398 'gravatar_link': h.gravatar_url(repo.user.email, 14),
1399 1399 },
1400 1400 'name': repo.repo_name,
1401 1401 'link': RepoModel().get_url(repo),
1402 1402 'description': h.chop_at_smart(repo.description_safe, '\n'),
1403 1403 'refs': {
1404 1404 'all_refs': all_refs,
1405 1405 'selected_ref': selected_ref,
1406 1406 'select2_refs': refs_select2
1407 1407 }
1408 1408 }
1409 1409
1410 1410 def generate_pullrequest_title(self, source, source_ref, target):
1411 1411 return u'{source}#{at_ref} to {target}'.format(
1412 1412 source=source,
1413 1413 at_ref=source_ref,
1414 1414 target=target,
1415 1415 )
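# Illustrative result, assuming hypothetical repo and ref names:
#   generate_pullrequest_title('my-repo', 'feature-1', 'upstream-repo')
#   -> u'my-repo#feature-1 to upstream-repo'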
1416 1416
1417 1417 def _cleanup_merge_workspace(self, pull_request):
1418 1418 # Merging related cleanup
1419 1419 repo_id = pull_request.target_repo.repo_id
1420 1420 target_scm = pull_request.target_repo.scm_instance()
1421 1421 workspace_id = self._workspace_id(pull_request)
1422 1422
1423 1423 try:
1424 1424 target_scm.cleanup_merge_workspace(repo_id, workspace_id)
1425 1425 except NotImplementedError:
1426 1426 pass
1427 1427
1428 1428 def _get_repo_pullrequest_sources(
1429 1429 self, repo, commit_id=None, branch=None, bookmark=None,
1430 1430 translator=None):
1431 1431 """
1432 1432 Return a structure with the repo's interesting commits, suitable
1433 1433 for the selectors in the pullrequest controller
1434 1434
1435 1435 :param commit_id: a commit that must be in the list somehow
1436 1436 and selected by default
1437 1437 :param branch: a branch that must be in the list and selected
1438 1438 by default - even if closed
1439 1439 :param bookmark: a bookmark that must be in the list and selected
1440 1440 """
1441 1441 _ = translator or get_current_request().translate
1442 1442
1443 1443 commit_id = safe_str(commit_id) if commit_id else None
1444 1444 branch = safe_str(branch) if branch else None
1445 1445 bookmark = safe_str(bookmark) if bookmark else None
1446 1446
1447 1447 selected = None
1448 1448
1449 1449 # order matters: first source that has commit_id in it will be selected
1450 1450 sources = []
1451 1451 sources.append(('book', repo.bookmarks.items(), _('Bookmarks'), bookmark))
1452 1452 sources.append(('branch', repo.branches.items(), _('Branches'), branch))
1453 1453
1454 1454 if commit_id:
1455 1455 ref_commit = (h.short_id(commit_id), commit_id)
1456 1456 sources.append(('rev', [ref_commit], _('Commit IDs'), commit_id))
1457 1457
1458 1458 sources.append(
1459 1459 ('branch', repo.branches_closed.items(), _('Closed Branches'), branch),
1460 1460 )
1461 1461
1462 1462 groups = []
1463 1463 for group_key, ref_list, group_name, match in sources:
1464 1464 group_refs = []
1465 1465 for ref_name, ref_id in ref_list:
1466 1466 ref_key = '%s:%s:%s' % (group_key, ref_name, ref_id)
1467 1467 group_refs.append((ref_key, ref_name))
1468 1468
1469 1469 if not selected:
1470 1470 if set([commit_id, match]) & set([ref_id, ref_name]):
1471 1471 selected = ref_key
1472 1472
1473 1473 if group_refs:
1474 1474 groups.append((group_refs, group_name))
1475 1475
1476 1476 if not selected:
1477 1477 ref = commit_id or branch or bookmark
1478 1478 if ref:
1479 1479 raise CommitDoesNotExistError(
1480 1480 'No commit refs could be found matching: %s' % ref)
1481 1481 elif repo.DEFAULT_BRANCH_NAME in repo.branches:
1482 1482 selected = 'branch:%s:%s' % (
1483 1483 repo.DEFAULT_BRANCH_NAME,
1484 1484 repo.branches[repo.DEFAULT_BRANCH_NAME]
1485 1485 )
1486 1486 elif repo.commit_ids:
1487 1487 # make the user select in this case
1488 1488 selected = None
1489 1489 else:
1490 1490 raise EmptyRepositoryError()
1491 1491 return groups, selected
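# Illustrative shape of the returned structure, with hypothetical refs:
#   groups   -> [([('branch:default:abcd1234', 'default')], u'Branches')]
#   selected -> 'branch:default:abcd1234'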
1492 1492
1493 1493 def get_diff(self, source_repo, source_ref_id, target_ref_id,
1494 1494 hide_whitespace_changes, diff_context):
1495 1495
1496 1496 return self._get_diff_from_pr_or_version(
1497 1497 source_repo, source_ref_id, target_ref_id,
1498 1498 hide_whitespace_changes=hide_whitespace_changes, diff_context=diff_context)
1499 1499
1500 1500 def _get_diff_from_pr_or_version(
1501 1501 self, source_repo, source_ref_id, target_ref_id,
1502 1502 hide_whitespace_changes, diff_context):
1503 1503
1504 1504 target_commit = source_repo.get_commit(
1505 1505 commit_id=safe_str(target_ref_id))
1506 1506 source_commit = source_repo.get_commit(
1507 1507 commit_id=safe_str(source_ref_id))
1508 1508 if isinstance(source_repo, Repository):
1509 1509 vcs_repo = source_repo.scm_instance()
1510 1510 else:
1511 1511 vcs_repo = source_repo
1512 1512
1513 1513 # TODO: johbo: In the context of an update, we cannot reach
1514 1514 # the old commit anymore with our normal mechanisms. It needs
1515 1515 # some sort of special support in the vcs layer to avoid this
1516 1516 # workaround.
1517 1517 if (source_commit.raw_id == vcs_repo.EMPTY_COMMIT_ID and
1518 1518 vcs_repo.alias == 'git'):
1519 1519 source_commit.raw_id = safe_str(source_ref_id)
1520 1520
1521 1521 log.debug('calculating diff between '
1522 1522 'source_ref:%s and target_ref:%s for repo `%s`',
1523 1523 target_ref_id, source_ref_id,
1524 1524 safe_unicode(vcs_repo.path))
1525 1525
1526 1526 vcs_diff = vcs_repo.get_diff(
1527 1527 commit1=target_commit, commit2=source_commit,
1528 1528 ignore_whitespace=hide_whitespace_changes, context=diff_context)
1529 1529 return vcs_diff
1530 1530
1531 1531 def _is_merge_enabled(self, pull_request):
1532 1532 return self._get_general_setting(
1533 1533 pull_request, 'rhodecode_pr_merge_enabled')
1534 1534
1535 1535 def _use_rebase_for_merging(self, pull_request):
1536 1536 repo_type = pull_request.target_repo.repo_type
1537 1537 if repo_type == 'hg':
1538 1538 return self._get_general_setting(
1539 1539 pull_request, 'rhodecode_hg_use_rebase_for_merging')
1540 1540 elif repo_type == 'git':
1541 1541 return self._get_general_setting(
1542 1542 pull_request, 'rhodecode_git_use_rebase_for_merging')
1543 1543
1544 1544 return False
1545 1545
1546 1546 def _close_branch_before_merging(self, pull_request):
1547 1547 repo_type = pull_request.target_repo.repo_type
1548 1548 if repo_type == 'hg':
1549 1549 return self._get_general_setting(
1550 1550 pull_request, 'rhodecode_hg_close_branch_before_merging')
1551 1551 elif repo_type == 'git':
1552 1552 return self._get_general_setting(
1553 1553 pull_request, 'rhodecode_git_close_branch_before_merging')
1554 1554
1555 1555 return False
1556 1556
1557 1557 def _get_general_setting(self, pull_request, settings_key, default=False):
1558 1558 settings_model = VcsSettingsModel(repo=pull_request.target_repo)
1559 1559 settings = settings_model.get_general_settings()
1560 1560 return settings.get(settings_key, default)
1561 1561
1562 1562 def _log_audit_action(self, action, action_data, user, pull_request):
1563 1563 audit_logger.store(
1564 1564 action=action,
1565 1565 action_data=action_data,
1566 1566 user=user,
1567 1567 repo=pull_request.target_repo)
1568 1568
1569 1569 def get_reviewer_functions(self):
1570 1570 """
1571 1571 Fetches functions for validation and fetching default reviewers.
1572 1572 If available we use the EE package, else we fall back to the CE
1573 1573 package functions
1574 1574 """
1575 1575 try:
1576 1576 from rc_reviewers.utils import get_default_reviewers_data
1577 1577 from rc_reviewers.utils import validate_default_reviewers
1578 1578 except ImportError:
1579 from rhodecode.apps.repository.utils import get_default_reviewers_data
1580 from rhodecode.apps.repository.utils import validate_default_reviewers
1583 1581
1584 1582 return get_default_reviewers_data, validate_default_reviewers
1585 1583
1586 1584
1587 1585 class MergeCheck(object):
1588 1586 """
1589 1587 Performs merge checks and returns a check object which stores
1590 1588 information about merge errors and merge conditions
1591 1589 """
1592 1590 TODO_CHECK = 'todo'
1593 1591 PERM_CHECK = 'perm'
1594 1592 REVIEW_CHECK = 'review'
1595 1593 MERGE_CHECK = 'merge'
1596 1594
1597 1595 def __init__(self):
1598 1596 self.review_status = None
1599 1597 self.merge_possible = None
1600 1598 self.merge_msg = ''
1601 1599 self.failed = None
1602 1600 self.errors = []
1603 1601 self.error_details = OrderedDict()
1604 1602
1605 1603 def push_error(self, error_type, message, error_key, details):
1606 1604 self.failed = True
1607 1605 self.errors.append([error_type, message])
1608 1606 self.error_details[error_key] = dict(
1609 1607 details=details,
1610 1608 error_type=error_type,
1611 1609 message=message
1612 1610 )
1613 1611
1614 1612 @classmethod
1615 1613 def validate(cls, pull_request, auth_user, translator, fail_early=False,
1616 1614 force_shadow_repo_refresh=False):
1617 1615 _ = translator
1618 1616 merge_check = cls()
1619 1617
1620 1618 # permissions to merge
1621 1619 user_allowed_to_merge = PullRequestModel().check_user_merge(
1622 1620 pull_request, auth_user)
1623 1621 if not user_allowed_to_merge:
1624 1622 log.debug("MergeCheck: cannot merge, user not allowed to merge.")
1625 1623
1626 1624 msg = _('User `{}` not allowed to perform merge.').format(auth_user.username)
1627 1625 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
1628 1626 if fail_early:
1629 1627 return merge_check
1630 1628
1631 1629 # permission to merge into the target branch
1632 1630 target_commit_id = pull_request.target_ref_parts.commit_id
1633 1631 if pull_request.target_ref_parts.type == 'branch':
1634 1632 branch_name = pull_request.target_ref_parts.name
1635 1633 else:
1636 1634 # for mercurial, in case of a bookmark, we can always figure out
1637 1635 # the branch from the commit
1638 1636 target_commit = pull_request.target_repo.get_commit(target_commit_id)
1639 1637 branch_name = target_commit.branch
1640 1638
1641 1639 rule, branch_perm = auth_user.get_rule_and_branch_permission(
1642 1640 pull_request.target_repo.repo_name, branch_name)
1643 1641 if branch_perm and branch_perm == 'branch.none':
1644 1642 msg = _('Target branch `{}` changes rejected by rule {}.').format(
1645 1643 branch_name, rule)
1646 1644 merge_check.push_error('error', msg, cls.PERM_CHECK, auth_user.username)
1647 1645 if fail_early:
1648 1646 return merge_check
1649 1647
1650 1648 # review status, must be always present
1651 1649 review_status = pull_request.calculated_review_status()
1652 1650 merge_check.review_status = review_status
1653 1651
1654 1652 status_approved = review_status == ChangesetStatus.STATUS_APPROVED
1655 1653 if not status_approved:
1656 1654 log.debug("MergeCheck: cannot merge, approval is pending.")
1657 1655
1658 1656 msg = _('Pull request reviewer approval is pending.')
1659 1657
1660 1658 merge_check.push_error(
1661 1659 'warning', msg, cls.REVIEW_CHECK, review_status)
1662 1660
1663 1661 if fail_early:
1664 1662 return merge_check
1665 1663
1666 1664 # left over TODOs
1667 1665 todos = CommentsModel().get_unresolved_todos(pull_request)
1668 1666 if todos:
1669 1667 log.debug("MergeCheck: cannot merge, {} "
1670 1668 "unresolved todos left.".format(len(todos)))
1671 1669
1672 1670 if len(todos) == 1:
1673 1671 msg = _('Cannot merge, {} TODO still not resolved.').format(
1674 1672 len(todos))
1675 1673 else:
1676 1674 msg = _('Cannot merge, {} TODOs still not resolved.').format(
1677 1675 len(todos))
1678 1676
1679 1677 merge_check.push_error('warning', msg, cls.TODO_CHECK, todos)
1680 1678
1681 1679 if fail_early:
1682 1680 return merge_check
1683 1681
1684 1682 # merge possible, here is the filesystem simulation + shadow repo
1685 1683 merge_status, msg = PullRequestModel().merge_status(
1686 1684 pull_request, translator=translator,
1687 1685 force_shadow_repo_refresh=force_shadow_repo_refresh)
1688 1686 merge_check.merge_possible = merge_status
1689 1687 merge_check.merge_msg = msg
1690 1688 if not merge_status:
1691 1689 log.debug(
1692 1690 "MergeCheck: cannot merge, pull request merge not possible.")
1693 1691 merge_check.push_error('warning', msg, cls.MERGE_CHECK, None)
1694 1692
1695 1693 if fail_early:
1696 1694 return merge_check
1697 1695
1698 1696 log.debug('MergeCheck: is failed: %s', merge_check.failed)
1699 1697 return merge_check
1700 1698
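A hedged sketch of how a caller might consume the returned check object; `pull_request`, `auth_user` and `translator` are assumed to already be in scope (e.g. from a Pyramid view):

# sketch only: assumes pull_request, auth_user and translator in scope
merge_check = MergeCheck.validate(
    pull_request, auth_user=auth_user, translator=translator,
    fail_early=True)
if merge_check.failed:
    for error_type, message in merge_check.errors:
        log.debug('merge blocked (%s): %s', error_type, message)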
1701 1699 @classmethod
1702 1700 def get_merge_conditions(cls, pull_request, translator):
1703 1701 _ = translator
1704 1702 merge_details = {}
1705 1703
1706 1704 model = PullRequestModel()
1707 1705 use_rebase = model._use_rebase_for_merging(pull_request)
1708 1706
1709 1707 if use_rebase:
1710 1708 merge_details['merge_strategy'] = dict(
1711 1709 details={},
1712 1710 message=_('Merge strategy: rebase')
1713 1711 )
1714 1712 else:
1715 1713 merge_details['merge_strategy'] = dict(
1716 1714 details={},
1717 1715 message=_('Merge strategy: explicit merge commit')
1718 1716 )
1719 1717
1720 1718 close_branch = model._close_branch_before_merging(pull_request)
1721 1719 if close_branch:
1722 1720 repo_type = pull_request.target_repo.repo_type
1723 1721 if repo_type == 'hg':
1724 1722 close_msg = _('Source branch will be closed after merge.')
1725 1723 elif repo_type == 'git':
1726 1724 close_msg = _('Source branch will be deleted after merge.')
1727 1725
1728 1726 merge_details['close_branch'] = dict(
1729 1727 details={},
1730 1728 message=close_msg
1731 1729 )
1732 1730
1733 1731 return merge_details
1734 1732
1735 1733 ChangeTuple = collections.namedtuple(
1736 1734 'ChangeTuple', ['added', 'common', 'removed', 'total'])
1737 1735
1738 1736 FileChangeTuple = collections.namedtuple(
1739 1737 'FileChangeTuple', ['added', 'modified', 'removed'])
@@ -1,838 +1,837 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import os
22 22 import hashlib
23 23 import logging
24 24 from collections import namedtuple
25 25 from functools import wraps
26 26 import bleach
27 27
28 28 from rhodecode.lib import rc_cache
29 29 from rhodecode.lib.utils2 import (
30 30 Optional, AttributeDict, safe_str, remove_prefix, str2bool)
31 31 from rhodecode.lib.vcs.backends import base
32 32 from rhodecode.model import BaseModel
33 33 from rhodecode.model.db import (
34 34 RepoRhodeCodeUi, RepoRhodeCodeSetting, RhodeCodeUi, RhodeCodeSetting, CacheKey)
35 35 from rhodecode.model.meta import Session
36 36
37 37
38 38 log = logging.getLogger(__name__)
39 39
40 40
41 41 UiSetting = namedtuple(
42 42 'UiSetting', ['section', 'key', 'value', 'active'])
43 43
44 44 SOCIAL_PLUGINS_LIST = ['github', 'bitbucket', 'twitter', 'google']
45 45
46 46
47 47 class SettingNotFound(Exception):
48 48 def __init__(self, setting_id):
49 49 msg = 'Setting `{}` is not found'.format(setting_id)
50 50 super(SettingNotFound, self).__init__(msg)
51 51
52 52
53 53 class SettingsModel(BaseModel):
54 54 BUILTIN_HOOKS = (
55 55 RhodeCodeUi.HOOK_REPO_SIZE, RhodeCodeUi.HOOK_PUSH,
56 56 RhodeCodeUi.HOOK_PRE_PUSH, RhodeCodeUi.HOOK_PRETX_PUSH,
57 57 RhodeCodeUi.HOOK_PULL, RhodeCodeUi.HOOK_PRE_PULL,
58 58 RhodeCodeUi.HOOK_PUSH_KEY,)
59 59 HOOKS_SECTION = 'hooks'
60 60
61 61 def __init__(self, sa=None, repo=None):
62 62 self.repo = repo
63 63 self.UiDbModel = RepoRhodeCodeUi if repo else RhodeCodeUi
64 64 self.SettingsDbModel = (
65 65 RepoRhodeCodeSetting if repo else RhodeCodeSetting)
66 66 super(SettingsModel, self).__init__(sa)
67 67
68 68 def get_ui_by_key(self, key):
69 69 q = self.UiDbModel.query()
70 70 q = q.filter(self.UiDbModel.ui_key == key)
71 71 q = self._filter_by_repo(RepoRhodeCodeUi, q)
72 72 return q.scalar()
73 73
74 74 def get_ui_by_section(self, section):
75 75 q = self.UiDbModel.query()
76 76 q = q.filter(self.UiDbModel.ui_section == section)
77 77 q = self._filter_by_repo(RepoRhodeCodeUi, q)
78 78 return q.all()
79 79
80 80 def get_ui_by_section_and_key(self, section, key):
81 81 q = self.UiDbModel.query()
82 82 q = q.filter(self.UiDbModel.ui_section == section)
83 83 q = q.filter(self.UiDbModel.ui_key == key)
84 84 q = self._filter_by_repo(RepoRhodeCodeUi, q)
85 85 return q.scalar()
86 86
87 87 def get_ui(self, section=None, key=None):
88 88 q = self.UiDbModel.query()
89 89 q = self._filter_by_repo(RepoRhodeCodeUi, q)
90 90
91 91 if section:
92 92 q = q.filter(self.UiDbModel.ui_section == section)
93 93 if key:
94 94 q = q.filter(self.UiDbModel.ui_key == key)
95 95
96 96 # TODO: mikhail: add caching
97 97 result = [
98 98 UiSetting(
99 99 section=safe_str(r.ui_section), key=safe_str(r.ui_key),
100 100 value=safe_str(r.ui_value), active=r.ui_active
101 101 )
102 102 for r in q.all()
103 103 ]
104 104 return result
105 105
106 106 def get_builtin_hooks(self):
107 107 q = self.UiDbModel.query()
108 108 q = q.filter(self.UiDbModel.ui_key.in_(self.BUILTIN_HOOKS))
109 109 return self._get_hooks(q)
110 110
111 111 def get_custom_hooks(self):
112 112 q = self.UiDbModel.query()
113 113 q = q.filter(~self.UiDbModel.ui_key.in_(self.BUILTIN_HOOKS))
114 114 return self._get_hooks(q)
115 115
116 116 def create_ui_section_value(self, section, val, key=None, active=True):
117 117 new_ui = self.UiDbModel()
118 118 new_ui.ui_section = section
119 119 new_ui.ui_value = val
120 120 new_ui.ui_active = active
121 121
122 122 if self.repo:
123 123 repo = self._get_repo(self.repo)
124 124 repository_id = repo.repo_id
125 125 new_ui.repository_id = repository_id
126 126
127 127 if not key:
128 128             # keys are unique, so they need appended info to stay unique
129 129 if self.repo:
130 130 key = hashlib.sha1(
131 131 '{}{}{}'.format(section, val, repository_id)).hexdigest()
132 132 else:
133 133 key = hashlib.sha1('{}{}'.format(section, val)).hexdigest()
134 134
135 135 new_ui.ui_key = key
136 136
137 137 Session().add(new_ui)
138 138 return new_ui
139 139
140 140 def create_or_update_hook(self, key, value):
141 141 ui = (
142 142 self.get_ui_by_section_and_key(self.HOOKS_SECTION, key) or
143 143 self.UiDbModel())
144 144 ui.ui_section = self.HOOKS_SECTION
145 145 ui.ui_active = True
146 146 ui.ui_key = key
147 147 ui.ui_value = value
148 148
149 149 if self.repo:
150 150 repo = self._get_repo(self.repo)
151 151 repository_id = repo.repo_id
152 152 ui.repository_id = repository_id
153 153
154 154 Session().add(ui)
155 155 return ui
156 156
157 157 def delete_ui(self, id_):
158 158 ui = self.UiDbModel.get(id_)
159 159 if not ui:
160 160 raise SettingNotFound(id_)
161 161 Session().delete(ui)
162 162
163 163 def get_setting_by_name(self, name):
164 164 q = self._get_settings_query()
165 165 q = q.filter(self.SettingsDbModel.app_settings_name == name)
166 166 return q.scalar()
167 167
168 168 def create_or_update_setting(
169 169 self, name, val=Optional(''), type_=Optional('unicode')):
170 170 """
171 171         Creates or updates a RhodeCode setting. If an update is triggered,
172 172         only parameters that are explicitly set are updated; any Optional
173 173         instance is skipped.
174 174
175 175 :param name:
176 176 :param val:
177 177 :param type_:
178 178 :return:
179 179 """
180 180
181 181 res = self.get_setting_by_name(name)
182 182 repo = self._get_repo(self.repo) if self.repo else None
183 183
184 184 if not res:
185 185 val = Optional.extract(val)
186 186 type_ = Optional.extract(type_)
187 187
188 188 args = (
189 189 (repo.repo_id, name, val, type_)
190 190 if repo else (name, val, type_))
191 191 res = self.SettingsDbModel(*args)
192 192
193 193 else:
194 194 if self.repo:
195 195 res.repository_id = repo.repo_id
196 196
197 197 res.app_settings_name = name
198 198 if not isinstance(type_, Optional):
199 199 # update if set
200 200 res.app_settings_type = type_
201 201 if not isinstance(val, Optional):
202 202 # update if set
203 203 res.app_settings_value = val
204 204
205 205 Session().add(res)
206 206 return res
207 207
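A short sketch of the partial-update semantics described above; the setting name is illustrative, not a real RhodeCode key:

# sketch: 'example_flag' is an illustrative setting name
model = SettingsModel()
model.create_or_update_setting('example_flag', 'True', 'bool')   # create
# update the value only; type_ stays Optional, so the stored type is kept
model.create_or_update_setting('example_flag', 'False')
Session().commit()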
208 208 def invalidate_settings_cache(self):
209 209 invalidation_namespace = CacheKey.SETTINGS_INVALIDATION_NAMESPACE
210 210 CacheKey.set_invalidate(invalidation_namespace)
211 211
212 212 def get_all_settings(self, cache=False):
213 213 region = rc_cache.get_or_create_region('sql_cache_short')
214 214 invalidation_namespace = CacheKey.SETTINGS_INVALIDATION_NAMESPACE
215 215
216 216 @region.conditional_cache_on_arguments(condition=cache)
217 217 def _get_all_settings(name, key):
218 218 q = self._get_settings_query()
219 219 if not q:
220 220 raise Exception('Could not get application settings !')
221 221
222 222 settings = {
223 223 'rhodecode_' + result.app_settings_name: result.app_settings_value
224 224 for result in q
225 225 }
226 226 return settings
227 227
228 228 repo = self._get_repo(self.repo) if self.repo else None
229 229 key = "settings_repo.{}".format(repo.repo_id) if repo else "settings_app"
230 230
231 231 inv_context_manager = rc_cache.InvalidationContext(
232 232 uid='cache_settings', invalidation_namespace=invalidation_namespace)
233 233 with inv_context_manager as invalidation_context:
234 234 # check for stored invalidation signal, and maybe purge the cache
235 235 # before computing it again
236 236 if invalidation_context.should_invalidate():
237 237 # NOTE:(marcink) we flush the whole sql_cache_short region, because it
238 238                 # reads different settings etc. It's a little too much, but those
239 239                 # caches are very short lived anyway and it's the safest way.
240 240 region = rc_cache.get_or_create_region('sql_cache_short')
241 241 region.invalidate()
242 242
243 243 result = _get_all_settings('rhodecode_settings', key)
244 244 log.debug('Fetching app settings for key: %s took: %.3fs', key,
245 245 inv_context_manager.compute_time)
246 246
247 247 return result
248 248
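The cached read / invalidate / recompute cycle that `get_all_settings` and `invalidate_settings_cache` imply, as a sketch ('example_flag' is again illustrative):

# sketch of the cache round-trip
model = SettingsModel()
cached = model.get_all_settings(cache=True)       # computes and caches
model.create_or_update_setting('example_flag', 'False', 'bool')
model.invalidate_settings_cache()                 # marks the namespace stale
fresh = model.get_all_settings(cache=True)        # recomputed on next read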
249 249 def get_auth_settings(self):
250 250 q = self._get_settings_query()
251 251 q = q.filter(
252 252 self.SettingsDbModel.app_settings_name.startswith('auth_'))
253 253 rows = q.all()
254 254 auth_settings = {
255 255 row.app_settings_name: row.app_settings_value for row in rows}
256 256 return auth_settings
257 257
258 258 def get_auth_plugins(self):
259 259 auth_plugins = self.get_setting_by_name("auth_plugins")
260 260 return auth_plugins.app_settings_value
261 261
262 262 def get_default_repo_settings(self, strip_prefix=False):
263 263 q = self._get_settings_query()
264 264 q = q.filter(
265 265 self.SettingsDbModel.app_settings_name.startswith('default_'))
266 266 rows = q.all()
267 267
268 268 result = {}
269 269 for row in rows:
270 270 key = row.app_settings_name
271 271 if strip_prefix:
272 272 key = remove_prefix(key, prefix='default_')
273 273 result.update({key: row.app_settings_value})
274 274 return result
275 275
276 276 def get_repo(self):
277 277 repo = self._get_repo(self.repo)
278 278 if not repo:
279 279 raise Exception(
280 280 'Repository `{}` cannot be found inside the database'.format(
281 281 self.repo))
282 282 return repo
283 283
284 284 def _filter_by_repo(self, model, query):
285 285 if self.repo:
286 286 repo = self.get_repo()
287 287 query = query.filter(model.repository_id == repo.repo_id)
288 288 return query
289 289
290 290 def _get_hooks(self, query):
291 291 query = query.filter(self.UiDbModel.ui_section == self.HOOKS_SECTION)
292 292 query = self._filter_by_repo(RepoRhodeCodeUi, query)
293 293 return query.all()
294 294
295 295 def _get_settings_query(self):
296 296 q = self.SettingsDbModel.query()
297 297 return self._filter_by_repo(RepoRhodeCodeSetting, q)
298 298
299 299 def list_enabled_social_plugins(self, settings):
300 300 enabled = []
301 301 for plug in SOCIAL_PLUGINS_LIST:
302 302             if str2bool(
303 303                     settings.get('rhodecode_auth_{}_enabled'.format(plug))):
304 304 enabled.append(plug)
305 305 return enabled
306 306
307 307
308 308 def assert_repo_settings(func):
309 309 @wraps(func)
310 310 def _wrapper(self, *args, **kwargs):
311 311 if not self.repo_settings:
312 312 raise Exception('Repository is not specified')
313 313 return func(self, *args, **kwargs)
314 314 return _wrapper
315 315
316 316
317 317 class IssueTrackerSettingsModel(object):
318 318 INHERIT_SETTINGS = 'inherit_issue_tracker_settings'
319 319 SETTINGS_PREFIX = 'issuetracker_'
320 320
321 321 def __init__(self, sa=None, repo=None):
322 322 self.global_settings = SettingsModel(sa=sa)
323 323 self.repo_settings = SettingsModel(sa=sa, repo=repo) if repo else None
324 324
325 325 @property
326 326 def inherit_global_settings(self):
327 327 if not self.repo_settings:
328 328 return True
329 329 setting = self.repo_settings.get_setting_by_name(self.INHERIT_SETTINGS)
330 330 return setting.app_settings_value if setting else True
331 331
332 332 @inherit_global_settings.setter
333 333 def inherit_global_settings(self, value):
334 334 if self.repo_settings:
335 335 settings = self.repo_settings.create_or_update_setting(
336 336 self.INHERIT_SETTINGS, value, type_='bool')
337 337 Session().add(settings)
338 338
339 339 def _get_keyname(self, key, uid, prefix=''):
340 340 return '{0}{1}{2}_{3}'.format(
341 341 prefix, self.SETTINGS_PREFIX, key, uid)
342 342
343 343 def _make_dict_for_settings(self, qs):
344 344 prefix_match = self._get_keyname('pat', '', 'rhodecode_')
345 345
346 346 issuetracker_entries = {}
347 347 # create keys
348 348 for k, v in qs.items():
349 349 if k.startswith(prefix_match):
350 350 uid = k[len(prefix_match):]
351 351 issuetracker_entries[uid] = None
352 352
353 353 # populate
354 354 for uid in issuetracker_entries:
355 355 issuetracker_entries[uid] = AttributeDict({
356 356 'pat': qs.get(
357 357 self._get_keyname('pat', uid, 'rhodecode_')),
358 358 'url': bleach.clean(
359 359 qs.get(self._get_keyname('url', uid, 'rhodecode_')) or ''),
360 360 'pref': bleach.clean(
361 361 qs.get(self._get_keyname('pref', uid, 'rhodecode_')) or ''),
362 362 'desc': qs.get(
363 363 self._get_keyname('desc', uid, 'rhodecode_')),
364 364 })
365 365 return issuetracker_entries
366 366
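For reference, the flat key layout `_make_dict_for_settings` expects; the uid and all values below are made up for illustration:

# illustrative input ('abc123' is a made-up uid) and the resulting shape
qs = {
    'rhodecode_issuetracker_pat_abc123': r'#(\d+)',
    'rhodecode_issuetracker_url_abc123': 'https://tracker.example.com/issue',
    'rhodecode_issuetracker_pref_abc123': '#',
    'rhodecode_issuetracker_desc_abc123': 'Example tracker',
}
# _make_dict_for_settings(qs) ->
#     {'abc123': AttributeDict(pat=..., url=..., pref=..., desc=...)}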
367 367 def get_global_settings(self, cache=False):
368 368 """
369 369         Returns a dict of global issue tracker settings
370 370 """
371 371 defaults = self.global_settings.get_all_settings(cache=cache)
372 372 settings = self._make_dict_for_settings(defaults)
373 373 return settings
374 374
375 375 def get_repo_settings(self, cache=False):
376 376 """
377 377         Returns a dict of per-repository issue tracker settings
378 378 """
379 379 if not self.repo_settings:
380 380 raise Exception('Repository is not specified')
381 381 all_settings = self.repo_settings.get_all_settings(cache=cache)
382 382 settings = self._make_dict_for_settings(all_settings)
383 383 return settings
384 384
385 385 def get_settings(self, cache=False):
386 386 if self.inherit_global_settings:
387 387 return self.get_global_settings(cache=cache)
388 388 else:
389 389 return self.get_repo_settings(cache=cache)
390 390
391 391 def delete_entries(self, uid):
392 392 if self.repo_settings:
393 393 all_patterns = self.get_repo_settings()
394 394 settings_model = self.repo_settings
395 395 else:
396 396 all_patterns = self.get_global_settings()
397 397 settings_model = self.global_settings
398 398 entries = all_patterns.get(uid, [])
399 399
400 400 for del_key in entries:
401 401 setting_name = self._get_keyname(del_key, uid)
402 402 entry = settings_model.get_setting_by_name(setting_name)
403 403 if entry:
404 404 Session().delete(entry)
405 405
406 406 Session().commit()
407 407
408 408 def create_or_update_setting(
409 409 self, name, val=Optional(''), type_=Optional('unicode')):
410 410 if self.repo_settings:
411 411 setting = self.repo_settings.create_or_update_setting(
412 412 name, val, type_)
413 413 else:
414 414 setting = self.global_settings.create_or_update_setting(
415 415 name, val, type_)
416 416 return setting
417 417
418 418
419 419 class VcsSettingsModel(object):
420 420
421 421 INHERIT_SETTINGS = 'inherit_vcs_settings'
422 422 GENERAL_SETTINGS = (
423 423 'use_outdated_comments',
424 424 'pr_merge_enabled',
425 425 'hg_use_rebase_for_merging',
426 426 'hg_close_branch_before_merging',
427 427 'git_use_rebase_for_merging',
428 428 'git_close_branch_before_merging',
429 429 'diff_cache',
430 430 )
431 431
432 432 HOOKS_SETTINGS = (
433 433 ('hooks', 'changegroup.repo_size'),
434 434 ('hooks', 'changegroup.push_logger'),
435 435 ('hooks', 'outgoing.pull_logger'),)
436 436 HG_SETTINGS = (
437 437 ('extensions', 'largefiles'),
438 438 ('phases', 'publish'),
439 439 ('extensions', 'evolve'),)
440 440 GIT_SETTINGS = (
441 441 ('vcs_git_lfs', 'enabled'),)
442 442 GLOBAL_HG_SETTINGS = (
443 443 ('extensions', 'largefiles'),
444 444 ('largefiles', 'usercache'),
445 445 ('phases', 'publish'),
446 446 ('extensions', 'hgsubversion'),
447 447 ('extensions', 'evolve'),)
448 448 GLOBAL_GIT_SETTINGS = (
449 449 ('vcs_git_lfs', 'enabled'),
450 450 ('vcs_git_lfs', 'store_location'))
451 451
452 452 GLOBAL_SVN_SETTINGS = (
453 453 ('vcs_svn_proxy', 'http_requests_enabled'),
454 454 ('vcs_svn_proxy', 'http_server_url'))
455 455
456 456 SVN_BRANCH_SECTION = 'vcs_svn_branch'
457 457 SVN_TAG_SECTION = 'vcs_svn_tag'
458 458 SSL_SETTING = ('web', 'push_ssl')
459 459 PATH_SETTING = ('paths', '/')
460 460
461 461 def __init__(self, sa=None, repo=None):
462 462 self.global_settings = SettingsModel(sa=sa)
463 463 self.repo_settings = SettingsModel(sa=sa, repo=repo) if repo else None
464 464 self._ui_settings = (
465 465 self.HG_SETTINGS + self.GIT_SETTINGS + self.HOOKS_SETTINGS)
466 466 self._svn_sections = (self.SVN_BRANCH_SECTION, self.SVN_TAG_SECTION)
467 467
468 468 @property
469 469 @assert_repo_settings
470 470 def inherit_global_settings(self):
471 471 setting = self.repo_settings.get_setting_by_name(self.INHERIT_SETTINGS)
472 472 return setting.app_settings_value if setting else True
473 473
474 474 @inherit_global_settings.setter
475 475 @assert_repo_settings
476 476 def inherit_global_settings(self, value):
477 477 self.repo_settings.create_or_update_setting(
478 478 self.INHERIT_SETTINGS, value, type_='bool')
479 479
480 480 def get_global_svn_branch_patterns(self):
481 481 return self.global_settings.get_ui_by_section(self.SVN_BRANCH_SECTION)
482 482
483 483 @assert_repo_settings
484 484 def get_repo_svn_branch_patterns(self):
485 485 return self.repo_settings.get_ui_by_section(self.SVN_BRANCH_SECTION)
486 486
487 487 def get_global_svn_tag_patterns(self):
488 488 return self.global_settings.get_ui_by_section(self.SVN_TAG_SECTION)
489 489
490 490 @assert_repo_settings
491 491 def get_repo_svn_tag_patterns(self):
492 492 return self.repo_settings.get_ui_by_section(self.SVN_TAG_SECTION)
493 493
494 494 def get_global_settings(self):
495 495 return self._collect_all_settings(global_=True)
496 496
497 497 @assert_repo_settings
498 498 def get_repo_settings(self):
499 499 return self._collect_all_settings(global_=False)
500 500
501 501 @assert_repo_settings
502 502 def create_or_update_repo_settings(
503 503 self, data, inherit_global_settings=False):
504 504 from rhodecode.model.scm import ScmModel
505 505
506 506 self.inherit_global_settings = inherit_global_settings
507 507
508 508 repo = self.repo_settings.get_repo()
509 509 if not inherit_global_settings:
510 510 if repo.repo_type == 'svn':
511 511 self.create_repo_svn_settings(data)
512 512 else:
513 513 self.create_or_update_repo_hook_settings(data)
514 514 self.create_or_update_repo_pr_settings(data)
515 515
516 516 if repo.repo_type == 'hg':
517 517 self.create_or_update_repo_hg_settings(data)
518 518
519 519 if repo.repo_type == 'git':
520 520 self.create_or_update_repo_git_settings(data)
521 521
522 522 ScmModel().mark_for_invalidation(repo.repo_name, delete=True)
523 523
524 524 @assert_repo_settings
525 525 def create_or_update_repo_hook_settings(self, data):
526 526 for section, key in self.HOOKS_SETTINGS:
527 527 data_key = self._get_form_ui_key(section, key)
528 528 if data_key not in data:
529 529 raise ValueError(
530 530 'The given data does not contain {} key'.format(data_key))
531 531
532 532 active = data.get(data_key)
533 533 repo_setting = self.repo_settings.get_ui_by_section_and_key(
534 534 section, key)
535 535 if not repo_setting:
536 536 global_setting = self.global_settings.\
537 537 get_ui_by_section_and_key(section, key)
538 538 self.repo_settings.create_ui_section_value(
539 539 section, global_setting.ui_value, key=key, active=active)
540 540 else:
541 541 repo_setting.ui_active = active
542 542 Session().add(repo_setting)
543 543
544 544 def update_global_hook_settings(self, data):
545 545 for section, key in self.HOOKS_SETTINGS:
546 546 data_key = self._get_form_ui_key(section, key)
547 547 if data_key not in data:
548 548 raise ValueError(
549 549 'The given data does not contain {} key'.format(data_key))
550 550 active = data.get(data_key)
551 551 repo_setting = self.global_settings.get_ui_by_section_and_key(
552 552 section, key)
553 553 repo_setting.ui_active = active
554 554 Session().add(repo_setting)
555 555
556 556 @assert_repo_settings
557 557 def create_or_update_repo_pr_settings(self, data):
558 558 return self._create_or_update_general_settings(
559 559 self.repo_settings, data)
560 560
561 561 def create_or_update_global_pr_settings(self, data):
562 562 return self._create_or_update_general_settings(
563 563 self.global_settings, data)
564 564
565 565 @assert_repo_settings
566 566 def create_repo_svn_settings(self, data):
567 567 return self._create_svn_settings(self.repo_settings, data)
568 568
569 569 @assert_repo_settings
570 570 def create_or_update_repo_hg_settings(self, data):
571 571 largefiles, phases, evolve = \
572 572 self.HG_SETTINGS
573 573 largefiles_key, phases_key, evolve_key = \
574 574 self._get_settings_keys(self.HG_SETTINGS, data)
575 575
576 576 self._create_or_update_ui(
577 577 self.repo_settings, *largefiles, value='',
578 578 active=data[largefiles_key])
579 579 self._create_or_update_ui(
580 580 self.repo_settings, *evolve, value='',
581 581 active=data[evolve_key])
582 582 self._create_or_update_ui(
583 583 self.repo_settings, *phases, value=safe_str(data[phases_key]))
584 584
585
586 585 def create_or_update_global_hg_settings(self, data):
587 586 largefiles, largefiles_store, phases, hgsubversion, evolve \
588 587 = self.GLOBAL_HG_SETTINGS
589 588 largefiles_key, largefiles_store_key, phases_key, subversion_key, evolve_key \
590 589 = self._get_settings_keys(self.GLOBAL_HG_SETTINGS, data)
591 590
592 591 self._create_or_update_ui(
593 592 self.global_settings, *largefiles, value='',
594 593 active=data[largefiles_key])
595 594 self._create_or_update_ui(
596 595 self.global_settings, *largefiles_store,
597 596 value=data[largefiles_store_key])
598 597 self._create_or_update_ui(
599 598 self.global_settings, *phases, value=safe_str(data[phases_key]))
600 599 self._create_or_update_ui(
601 600 self.global_settings, *hgsubversion, active=data[subversion_key])
602 601 self._create_or_update_ui(
603 602 self.global_settings, *evolve, value='',
604 603 active=data[evolve_key])
605 604
606 605 def create_or_update_repo_git_settings(self, data):
607 606         # NOTE(marcink): the trailing comma makes the unpack work properly
608 607 lfs_enabled, \
609 608 = self.GIT_SETTINGS
610 609
611 610 lfs_enabled_key, \
612 611 = self._get_settings_keys(self.GIT_SETTINGS, data)
613 612
614 613 self._create_or_update_ui(
615 614 self.repo_settings, *lfs_enabled, value=data[lfs_enabled_key],
616 615 active=data[lfs_enabled_key])
617 616
618 617 def create_or_update_global_git_settings(self, data):
619 618 lfs_enabled, lfs_store_location \
620 619 = self.GLOBAL_GIT_SETTINGS
621 620 lfs_enabled_key, lfs_store_location_key \
622 621 = self._get_settings_keys(self.GLOBAL_GIT_SETTINGS, data)
623 622
624 623 self._create_or_update_ui(
625 624 self.global_settings, *lfs_enabled, value=data[lfs_enabled_key],
626 625 active=data[lfs_enabled_key])
627 626 self._create_or_update_ui(
628 627 self.global_settings, *lfs_store_location,
629 628 value=data[lfs_store_location_key])
630 629
631 630 def create_or_update_global_svn_settings(self, data):
632 631 # branch/tags patterns
633 632 self._create_svn_settings(self.global_settings, data)
634 633
635 634 http_requests_enabled, http_server_url = self.GLOBAL_SVN_SETTINGS
636 635 http_requests_enabled_key, http_server_url_key = self._get_settings_keys(
637 636 self.GLOBAL_SVN_SETTINGS, data)
638 637
639 638 self._create_or_update_ui(
640 639 self.global_settings, *http_requests_enabled,
641 640 value=safe_str(data[http_requests_enabled_key]))
642 641 self._create_or_update_ui(
643 642 self.global_settings, *http_server_url,
644 643 value=data[http_server_url_key])
645 644
646 645 def update_global_ssl_setting(self, value):
647 646 self._create_or_update_ui(
648 647 self.global_settings, *self.SSL_SETTING, value=value)
649 648
650 649 def update_global_path_setting(self, value):
651 650 self._create_or_update_ui(
652 651 self.global_settings, *self.PATH_SETTING, value=value)
653 652
654 653 @assert_repo_settings
655 654 def delete_repo_svn_pattern(self, id_):
656 655 ui = self.repo_settings.UiDbModel.get(id_)
657 656 if ui and ui.repository.repo_name == self.repo_settings.repo:
658 657 # only delete if it's the same repo as initialized settings
659 658 self.repo_settings.delete_ui(id_)
660 659 else:
661 660             # raise the same error as if the option could not be found
662 661 self.repo_settings.delete_ui(-1)
663 662
664 663 def delete_global_svn_pattern(self, id_):
665 664 self.global_settings.delete_ui(id_)
666 665
667 666 @assert_repo_settings
668 667 def get_repo_ui_settings(self, section=None, key=None):
669 668 global_uis = self.global_settings.get_ui(section, key)
670 669 repo_uis = self.repo_settings.get_ui(section, key)
671 670 filtered_repo_uis = self._filter_ui_settings(repo_uis)
672 671 filtered_repo_uis_keys = [
673 672 (s.section, s.key) for s in filtered_repo_uis]
674 673
675 674 def _is_global_ui_filtered(ui):
676 675 return (
677 676 (ui.section, ui.key) in filtered_repo_uis_keys
678 677 or ui.section in self._svn_sections)
679 678
680 679 filtered_global_uis = [
681 680 ui for ui in global_uis if not _is_global_ui_filtered(ui)]
682 681
683 682 return filtered_global_uis + filtered_repo_uis
684 683
685 684 def get_global_ui_settings(self, section=None, key=None):
686 685 return self.global_settings.get_ui(section, key)
687 686
688 687 def get_ui_settings_as_config_obj(self, section=None, key=None):
689 688 config = base.Config()
690 689
691 690 ui_settings = self.get_ui_settings(section=section, key=key)
692 691
693 692 for entry in ui_settings:
694 693 config.set(entry.section, entry.key, entry.value)
695 694
696 695 return config
697 696
698 697 def get_ui_settings(self, section=None, key=None):
699 698 if not self.repo_settings or self.inherit_global_settings:
700 699 return self.get_global_ui_settings(section, key)
701 700 else:
702 701 return self.get_repo_ui_settings(section, key)
703 702
704 703 def get_svn_patterns(self, section=None):
705 704 if not self.repo_settings:
706 705 return self.get_global_ui_settings(section)
707 706 else:
708 707 return self.get_repo_ui_settings(section)
709 708
710 709 @assert_repo_settings
711 710 def get_repo_general_settings(self):
712 711 global_settings = self.global_settings.get_all_settings()
713 712 repo_settings = self.repo_settings.get_all_settings()
714 713 filtered_repo_settings = self._filter_general_settings(repo_settings)
715 714 global_settings.update(filtered_repo_settings)
716 715 return global_settings
717 716
718 717 def get_global_general_settings(self):
719 718 return self.global_settings.get_all_settings()
720 719
721 720 def get_general_settings(self):
722 721 if not self.repo_settings or self.inherit_global_settings:
723 722 return self.get_global_general_settings()
724 723 else:
725 724 return self.get_repo_general_settings()
726 725
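How the global/repo split above resolves, sketched; the repository name is illustrative:

# sketch: 'some/repo' is an illustrative repository name
vcs_model = VcsSettingsModel(repo='some/repo')
if vcs_model.inherit_global_settings:         # repo follows global settings
    settings = vcs_model.get_global_general_settings()
else:                                         # repo overrides are honoured
    settings = vcs_model.get_repo_general_settings()
# equivalent to vcs_model.get_general_settings()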
727 726 def get_repos_location(self):
728 727 return self.global_settings.get_ui_by_key('/').ui_value
729 728
730 729 def _filter_ui_settings(self, settings):
731 730 filtered_settings = [
732 731 s for s in settings if self._should_keep_setting(s)]
733 732 return filtered_settings
734 733
735 734 def _should_keep_setting(self, setting):
736 735 keep = (
737 736 (setting.section, setting.key) in self._ui_settings or
738 737 setting.section in self._svn_sections)
739 738 return keep
740 739
741 740 def _filter_general_settings(self, settings):
742 741 keys = ['rhodecode_{}'.format(key) for key in self.GENERAL_SETTINGS]
743 742 return {
744 743 k: settings[k]
745 744 for k in settings if k in keys}
746 745
747 746 def _collect_all_settings(self, global_=False):
748 747 settings = self.global_settings if global_ else self.repo_settings
749 748 result = {}
750 749
751 750 for section, key in self._ui_settings:
752 751 ui = settings.get_ui_by_section_and_key(section, key)
753 752 result_key = self._get_form_ui_key(section, key)
754 753
755 754 if ui:
756 755 if section in ('hooks', 'extensions'):
757 756 result[result_key] = ui.ui_active
758 757 elif result_key in ['vcs_git_lfs_enabled']:
759 758 result[result_key] = ui.ui_active
760 759 else:
761 760 result[result_key] = ui.ui_value
762 761
763 762 for name in self.GENERAL_SETTINGS:
764 763 setting = settings.get_setting_by_name(name)
765 764 if setting:
766 765 result_key = 'rhodecode_{}'.format(name)
767 766 result[result_key] = setting.app_settings_value
768 767
769 768 return result
770 769
771 770 def _get_form_ui_key(self, section, key):
772 771 return '{section}_{key}'.format(
773 772 section=section, key=key.replace('.', '_'))
774 773
775 774 def _create_or_update_ui(
776 775 self, settings, section, key, value=None, active=None):
777 776 ui = settings.get_ui_by_section_and_key(section, key)
778 777 if not ui:
779 778 active = True if active is None else active
780 779 settings.create_ui_section_value(
781 780 section, value, key=key, active=active)
782 781 else:
783 782 if active is not None:
784 783 ui.ui_active = active
785 784 if value is not None:
786 785 ui.ui_value = value
787 786 Session().add(ui)
788 787
789 788 def _create_svn_settings(self, settings, data):
790 789 svn_settings = {
791 790 'new_svn_branch': self.SVN_BRANCH_SECTION,
792 791 'new_svn_tag': self.SVN_TAG_SECTION
793 792 }
794 793 for key in svn_settings:
795 794 if data.get(key):
796 795 settings.create_ui_section_value(svn_settings[key], data[key])
797 796
798 797 def _create_or_update_general_settings(self, settings, data):
799 798 for name in self.GENERAL_SETTINGS:
800 799 data_key = 'rhodecode_{}'.format(name)
801 800 if data_key not in data:
802 801 raise ValueError(
803 802 'The given data does not contain {} key'.format(data_key))
804 803 setting = settings.create_or_update_setting(
805 804 name, data[data_key], 'bool')
806 805 Session().add(setting)
807 806
808 807 def _get_settings_keys(self, settings, data):
809 808 data_keys = [self._get_form_ui_key(*s) for s in settings]
810 809 for data_key in data_keys:
811 810 if data_key not in data:
812 811 raise ValueError(
813 812 'The given data does not contain {} key'.format(data_key))
814 813 return data_keys
815 814
816 815 def create_largeobjects_dirs_if_needed(self, repo_store_path):
817 816 """
818 817         This is subscribed to the `pyramid.events.ApplicationCreated` event.
819 818         It creates the largefiles/LFS store directories if they are missing.
820 819 """
821 820
822 821 from rhodecode.lib.vcs.backends.hg import largefiles_store
823 822 from rhodecode.lib.vcs.backends.git import lfs_store
824 823
825 824 paths = [
826 825 largefiles_store(repo_store_path),
827 826 lfs_store(repo_store_path)]
828 827
829 828 for path in paths:
830 829 if os.path.isdir(path):
831 830 continue
832 831 if os.path.isfile(path):
833 832 continue
834 833 # not a file nor dir, we try to create it
835 834 try:
836 835 os.makedirs(path)
837 836 except Exception:
838 837                 log.warning('Failed to create largefiles dir: %s', path)
@@ -1,553 +1,553 b''
1 1 # -*- coding: utf-8 -*-
2 2
3 3 # Copyright (C) 2010-2018 RhodeCode GmbH
4 4 #
5 5 # This program is free software: you can redistribute it and/or modify
6 6 # it under the terms of the GNU Affero General Public License, version 3
7 7 # (only), as published by the Free Software Foundation.
8 8 #
9 9 # This program is distributed in the hope that it will be useful,
10 10 # but WITHOUT ANY WARRANTY; without even the implied warranty of
11 11 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 12 # GNU General Public License for more details.
13 13 #
14 14 # You should have received a copy of the GNU Affero General Public License
15 15 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 16 #
17 17 # This program is dual-licensed. If you wish to learn more about the
18 18 # RhodeCode Enterprise Edition, including its added features, Support services,
19 19 # and proprietary license terms, please see https://rhodecode.com/licenses/
20 20
21 21 import datetime
22 22 from urllib2 import URLError
23 23
24 24 import mock
25 25 import pytest
26 26
27 27 from rhodecode.lib.vcs import backends
28 28 from rhodecode.lib.vcs.backends.base import (
29 29 Config, BaseInMemoryCommit, Reference, MergeResponse, MergeFailureReason)
30 30 from rhodecode.lib.vcs.exceptions import VCSError, RepositoryError
31 31 from rhodecode.lib.vcs.nodes import FileNode
32 32 from rhodecode.tests.vcs.conftest import BackendTestMixin
33 33 from rhodecode.tests import repo_id_generator
34 34
35 35
36 36 @pytest.mark.usefixtures("vcs_repository_support")
37 37 class TestRepositoryBase(BackendTestMixin):
38 38 recreate_repo_per_test = False
39 39
40 40 def test_init_accepts_unicode_path(self, tmpdir):
41 41 path = unicode(tmpdir.join(u'unicode Γ€'))
42 42 self.Backend(path, create=True)
43 43
44 44 def test_init_accepts_str_path(self, tmpdir):
45 45 path = str(tmpdir.join('str Γ€'))
46 46 self.Backend(path, create=True)
47 47
48 48 def test_init_fails_if_path_does_not_exist(self, tmpdir):
49 49 path = unicode(tmpdir.join('i-do-not-exist'))
50 50 with pytest.raises(VCSError):
51 51 self.Backend(path)
52 52
53 53 def test_init_fails_if_path_is_not_a_valid_repository(self, tmpdir):
54 54 path = unicode(tmpdir.mkdir(u'unicode Γ€'))
55 55 with pytest.raises(VCSError):
56 56 self.Backend(path)
57 57
58 58 def test_has_commits_attribute(self):
59 59 self.repo.commit_ids
60 60
61 61 def test_name(self):
62 62 assert self.repo.name.startswith('vcs-test')
63 63
64 64 @pytest.mark.backends("hg", "git")
65 65 def test_has_default_branch_name(self):
66 66 assert self.repo.DEFAULT_BRANCH_NAME is not None
67 67
68 68 @pytest.mark.backends("svn")
69 69 def test_has_no_default_branch_name(self):
70 70 assert self.repo.DEFAULT_BRANCH_NAME is None
71 71
72 72 def test_has_empty_commit(self):
73 73 assert self.repo.EMPTY_COMMIT_ID is not None
74 74 assert self.repo.EMPTY_COMMIT is not None
75 75
76 76 def test_empty_changeset_is_deprecated(self):
77 77 def get_empty_changeset(repo):
78 78 return repo.EMPTY_CHANGESET
79 79 pytest.deprecated_call(get_empty_changeset, self.repo)
80 80
81 81 def test_bookmarks(self):
82 82 assert len(self.repo.bookmarks) == 0
83 83
84 84 # TODO: Cover two cases: Local repo path, remote URL
85 85 def test_check_url(self):
86 86 config = Config()
87 87 assert self.Backend.check_url(self.repo.path, config)
88 88
89 89 def test_check_url_invalid(self):
90 90 config = Config()
91 91 with pytest.raises(URLError):
92 92 self.Backend.check_url(self.repo.path + "invalid", config)
93 93
94 94 def test_get_contact(self):
95 95 assert self.repo.contact
96 96
97 97 def test_get_description(self):
98 98 assert self.repo.description
99 99
100 100 def test_get_hook_location(self):
101 101 assert len(self.repo.get_hook_location()) != 0
102 102
103 103 def test_last_change(self, local_dt_to_utc):
104 104 assert self.repo.last_change >= local_dt_to_utc(
105 105 datetime.datetime(2010, 1, 1, 21, 0))
106 106
107 107 def test_last_change_in_empty_repository(self, vcsbackend, local_dt_to_utc):
108 108 delta = datetime.timedelta(seconds=1)
109 109
110 110 start = local_dt_to_utc(datetime.datetime.now())
111 111 empty_repo = vcsbackend.create_repo()
112 112 now = local_dt_to_utc(datetime.datetime.now())
113 113 assert empty_repo.last_change >= start - delta
114 114 assert empty_repo.last_change <= now + delta
115 115
116 116 def test_repo_equality(self):
117 117 assert self.repo == self.repo
118 118
119 119 def test_repo_equality_broken_object(self):
120 120 import copy
121 121 _repo = copy.copy(self.repo)
122 122 delattr(_repo, 'path')
123 123 assert self.repo != _repo
124 124
125 125 def test_repo_equality_other_object(self):
126 126 class dummy(object):
127 127 path = self.repo.path
128 128 assert self.repo != dummy()
129 129
130 130 def test_get_commit_is_implemented(self):
131 131 self.repo.get_commit()
132 132
133 133 def test_get_commits_is_implemented(self):
134 134 commit_iter = iter(self.repo.get_commits())
135 135 commit = next(commit_iter)
136 136 assert commit.idx == 0
137 137
138 138 def test_supports_iteration(self):
139 139 repo_iter = iter(self.repo)
140 140 commit = next(repo_iter)
141 141 assert commit.idx == 0
142 142
143 143 def test_in_memory_commit(self):
144 144 imc = self.repo.in_memory_commit
145 145 assert isinstance(imc, BaseInMemoryCommit)
146 146
147 147 @pytest.mark.backends("hg")
148 148 def test__get_url_unicode(self):
149 149 url = u'/home/repos/malmΓΆ'
150 150 assert self.repo._get_url(url)
151 151
152 152
153 153 @pytest.mark.usefixtures("vcs_repository_support")
154 154 class TestDeprecatedRepositoryAPI(BackendTestMixin):
155 155 recreate_repo_per_test = False
156 156
157 157 def test_revisions_is_deprecated(self):
158 158 def get_revisions(repo):
159 159 return repo.revisions
160 160 pytest.deprecated_call(get_revisions, self.repo)
161 161
162 162 def test_get_changeset_is_deprecated(self):
163 163 pytest.deprecated_call(self.repo.get_changeset)
164 164
165 165 def test_get_changesets_is_deprecated(self):
166 166 pytest.deprecated_call(self.repo.get_changesets)
167 167
168 168 def test_in_memory_changeset_is_deprecated(self):
169 169 def get_imc(repo):
170 170 return repo.in_memory_changeset
171 171 pytest.deprecated_call(get_imc, self.repo)
172 172
173 173
174 174 # TODO: these tests are incomplete, must check the resulting compare
175 175 # result for correctness
176 176 class TestRepositoryCompare:
177 177
178 178 @pytest.mark.parametrize('merge', [True, False])
179 179 def test_compare_commits_of_same_repository(self, vcsbackend, merge):
180 180 target_repo = vcsbackend.create_repo(number_of_commits=5)
181 181 target_repo.compare(
182 182 target_repo[1].raw_id, target_repo[3].raw_id, target_repo,
183 183 merge=merge)
184 184
185 185 @pytest.mark.xfail_backends('svn')
186 186 @pytest.mark.parametrize('merge', [True, False])
187 187 def test_compare_cloned_repositories(self, vcsbackend, merge):
188 188 target_repo = vcsbackend.create_repo(number_of_commits=5)
189 189 source_repo = vcsbackend.clone_repo(target_repo)
190 190 assert target_repo != source_repo
191 191
192 192 vcsbackend.add_file(source_repo, 'newfile', 'somecontent')
193 193 source_commit = source_repo.get_commit()
194 194
195 195 target_repo.compare(
196 196 target_repo[1].raw_id, source_repo[3].raw_id, source_repo,
197 197 merge=merge)
198 198
199 199 @pytest.mark.xfail_backends('svn')
200 200 @pytest.mark.parametrize('merge', [True, False])
201 201 def test_compare_unrelated_repositories(self, vcsbackend, merge):
202 202 orig = vcsbackend.create_repo(number_of_commits=5)
203 203 unrelated = vcsbackend.create_repo(number_of_commits=5)
204 204 assert orig != unrelated
205 205
206 206 orig.compare(
207 207 orig[1].raw_id, unrelated[3].raw_id, unrelated, merge=merge)
208 208
209 209
210 210 class TestRepositoryGetCommonAncestor:
211 211
212 212 def test_get_common_ancestor_from_same_repo_existing(self, vcsbackend):
213 213 target_repo = vcsbackend.create_repo(number_of_commits=5)
214 214
215 215 expected_ancestor = target_repo[2].raw_id
216 216
217 217 assert target_repo.get_common_ancestor(
218 218 commit_id1=target_repo[2].raw_id,
219 219 commit_id2=target_repo[4].raw_id,
220 220 repo2=target_repo
221 221 ) == expected_ancestor
222 222
223 223 assert target_repo.get_common_ancestor(
224 224 commit_id1=target_repo[4].raw_id,
225 225 commit_id2=target_repo[2].raw_id,
226 226 repo2=target_repo
227 227 ) == expected_ancestor
228 228
229 229 @pytest.mark.xfail_backends("svn")
230 230 def test_get_common_ancestor_from_cloned_repo_existing(self, vcsbackend):
231 231 target_repo = vcsbackend.create_repo(number_of_commits=5)
232 232 source_repo = vcsbackend.clone_repo(target_repo)
233 233 assert target_repo != source_repo
234 234
235 235 vcsbackend.add_file(source_repo, 'newfile', 'somecontent')
236 236 source_commit = source_repo.get_commit()
237 237
238 238 expected_ancestor = target_repo[4].raw_id
239 239
240 240 assert target_repo.get_common_ancestor(
241 241 commit_id1=target_repo[4].raw_id,
242 242 commit_id2=source_commit.raw_id,
243 243 repo2=source_repo
244 244 ) == expected_ancestor
245 245
246 246 assert target_repo.get_common_ancestor(
247 247 commit_id1=source_commit.raw_id,
248 248 commit_id2=target_repo[4].raw_id,
249 249 repo2=target_repo
250 250 ) == expected_ancestor
251 251
252 252 @pytest.mark.xfail_backends("svn")
253 253 def test_get_common_ancestor_from_unrelated_repo_missing(self, vcsbackend):
254 254 original = vcsbackend.create_repo(number_of_commits=5)
255 255 unrelated = vcsbackend.create_repo(number_of_commits=5)
256 256 assert original != unrelated
257 257
258 258 assert original.get_common_ancestor(
259 259 commit_id1=original[0].raw_id,
260 260 commit_id2=unrelated[0].raw_id,
261 261 repo2=unrelated
262 ) == None
262 ) is None
263 263
264 264 assert original.get_common_ancestor(
265 265 commit_id1=original[-1].raw_id,
266 266 commit_id2=unrelated[-1].raw_id,
267 267 repo2=unrelated
268 ) == None
268 ) is None
269 269
270 270
271 271 @pytest.mark.backends("git", "hg")
272 272 class TestRepositoryMerge(object):
273 273 def prepare_for_success(self, vcsbackend):
274 274 self.target_repo = vcsbackend.create_repo(number_of_commits=1)
275 275 self.source_repo = vcsbackend.clone_repo(self.target_repo)
276 276 vcsbackend.add_file(self.target_repo, 'README_MERGE1', 'Version 1')
277 277 vcsbackend.add_file(self.source_repo, 'README_MERGE2', 'Version 2')
278 278 imc = self.source_repo.in_memory_commit
279 279 imc.add(FileNode('file_x', content=self.source_repo.name))
280 280 imc.commit(
281 281 message=u'Automatic commit from repo merge test',
282 282 author=u'Automatic')
283 283 self.target_commit = self.target_repo.get_commit()
284 284 self.source_commit = self.source_repo.get_commit()
285 285 # This only works for Git and Mercurial
286 286 default_branch = self.target_repo.DEFAULT_BRANCH_NAME
287 287 self.target_ref = Reference(
288 288 'branch', default_branch, self.target_commit.raw_id)
289 289 self.source_ref = Reference(
290 290 'branch', default_branch, self.source_commit.raw_id)
291 291 self.workspace_id = 'test-merge'
292 292 self.repo_id = repo_id_generator(self.target_repo.path)
293 293
294 294 def prepare_for_conflict(self, vcsbackend):
295 295 self.target_repo = vcsbackend.create_repo(number_of_commits=1)
296 296 self.source_repo = vcsbackend.clone_repo(self.target_repo)
297 297 vcsbackend.add_file(self.target_repo, 'README_MERGE', 'Version 1')
298 298 vcsbackend.add_file(self.source_repo, 'README_MERGE', 'Version 2')
299 299 self.target_commit = self.target_repo.get_commit()
300 300 self.source_commit = self.source_repo.get_commit()
301 301 # This only works for Git and Mercurial
302 302 default_branch = self.target_repo.DEFAULT_BRANCH_NAME
303 303 self.target_ref = Reference(
304 304 'branch', default_branch, self.target_commit.raw_id)
305 305 self.source_ref = Reference(
306 306 'branch', default_branch, self.source_commit.raw_id)
307 307 self.workspace_id = 'test-merge'
308 308 self.repo_id = repo_id_generator(self.target_repo.path)
309 309
310 310 def test_merge_success(self, vcsbackend):
311 311 self.prepare_for_success(vcsbackend)
312 312
313 313 merge_response = self.target_repo.merge(
314 314 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
315 315 self.source_ref,
316 316 'test user', 'test@rhodecode.com', 'merge message 1',
317 317 dry_run=False)
318 318 expected_merge_response = MergeResponse(
319 319 True, True, merge_response.merge_ref,
320 320 MergeFailureReason.NONE)
321 321 assert merge_response == expected_merge_response
322 322
323 323 target_repo = backends.get_backend(vcsbackend.alias)(
324 324 self.target_repo.path)
325 325 target_commits = list(target_repo.get_commits())
326 326 commit_ids = [c.raw_id for c in target_commits[:-1]]
327 327 assert self.source_ref.commit_id in commit_ids
328 328 assert self.target_ref.commit_id in commit_ids
329 329
330 330 merge_commit = target_commits[-1]
331 331 assert merge_commit.raw_id == merge_response.merge_ref.commit_id
332 332 assert merge_commit.message.strip() == 'merge message 1'
333 333 assert merge_commit.author == 'test user <test@rhodecode.com>'
334 334
335 335         # We call it twice to make sure we can handle updates
336 336 target_ref = Reference(
337 337 self.target_ref.type, self.target_ref.name,
338 338 merge_response.merge_ref.commit_id)
339 339
340 340 merge_response = target_repo.merge(
341 341 self.repo_id, self.workspace_id, target_ref, self.source_repo, self.source_ref,
342 342 'test user', 'test@rhodecode.com', 'merge message 2',
343 343 dry_run=False)
344 344 expected_merge_response = MergeResponse(
345 345 True, True, merge_response.merge_ref,
346 346 MergeFailureReason.NONE)
347 347 assert merge_response == expected_merge_response
348 348
349 349 target_repo = backends.get_backend(
350 350 vcsbackend.alias)(self.target_repo.path)
351 351 merge_commit = target_repo.get_commit(
352 352 merge_response.merge_ref.commit_id)
353 353 assert merge_commit.message.strip() == 'merge message 1'
354 354 assert merge_commit.author == 'test user <test@rhodecode.com>'
355 355
356 356 def test_merge_success_dry_run(self, vcsbackend):
357 357 self.prepare_for_success(vcsbackend)
358 358
359 359 merge_response = self.target_repo.merge(
360 360 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
361 361 self.source_ref, dry_run=True)
362 362
363 363         # We call it twice to make sure we can handle updates
364 364 merge_response_update = self.target_repo.merge(
365 365 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
366 366 self.source_ref, dry_run=True)
367 367
368 368 # Multiple merges may differ in their commit id. Therefore we set the
369 369 # commit id to `None` before comparing the merge responses.
370 370 merge_response = merge_response._replace(
371 371 merge_ref=merge_response.merge_ref._replace(commit_id=None))
372 372 merge_response_update = merge_response_update._replace(
373 373 merge_ref=merge_response_update.merge_ref._replace(commit_id=None))
374 374
375 375 assert merge_response == merge_response_update
376 376 assert merge_response.possible is True
377 377 assert merge_response.executed is False
378 378 assert merge_response.merge_ref
379 379 assert merge_response.failure_reason is MergeFailureReason.NONE
380 380
381 381 @pytest.mark.parametrize('dry_run', [True, False])
382 382 def test_merge_conflict(self, vcsbackend, dry_run):
383 383 self.prepare_for_conflict(vcsbackend)
384 384 expected_merge_response = MergeResponse(
385 385 False, False, None, MergeFailureReason.MERGE_FAILED)
386 386
387 387 merge_response = self.target_repo.merge(
388 388 self.repo_id, self.workspace_id, self.target_ref,
389 389 self.source_repo, self.source_ref,
390 390 'test_user', 'test@rhodecode.com', 'test message', dry_run=dry_run)
391 391 assert merge_response == expected_merge_response
392 392
393 393         # We call it twice to make sure we can handle updates
394 394 merge_response = self.target_repo.merge(
395 395 self.repo_id, self.workspace_id, self.target_ref, self.source_repo,
396 396 self.source_ref,
397 397 'test_user', 'test@rhodecode.com', 'test message', dry_run=dry_run)
398 398 assert merge_response == expected_merge_response
399 399
400 400 def test_merge_target_is_not_head(self, vcsbackend):
401 401 self.prepare_for_success(vcsbackend)
402 402 expected_merge_response = MergeResponse(
403 403 False, False, None, MergeFailureReason.TARGET_IS_NOT_HEAD)
404 404
405 405 target_ref = Reference(
406 406 self.target_ref.type, self.target_ref.name, '0' * 40)
407 407
408 408 merge_response = self.target_repo.merge(
409 409 self.repo_id, self.workspace_id, target_ref, self.source_repo,
410 410 self.source_ref, dry_run=True)
411 411
412 412 assert merge_response == expected_merge_response
413 413
414 414 def test_merge_missing_source_reference(self, vcsbackend):
415 415 self.prepare_for_success(vcsbackend)
416 416 expected_merge_response = MergeResponse(
417 417 False, False, None, MergeFailureReason.MISSING_SOURCE_REF)
418 418
419 419 source_ref = Reference(
420 420 self.source_ref.type, 'not_existing', self.source_ref.commit_id)
421 421
422 422 merge_response = self.target_repo.merge(
423 423 self.repo_id, self.workspace_id, self.target_ref,
424 424 self.source_repo, source_ref,
425 425 dry_run=True)
426 426
427 427 assert merge_response == expected_merge_response
428 428
429 429 def test_merge_raises_exception(self, vcsbackend):
430 430 self.prepare_for_success(vcsbackend)
431 431 expected_merge_response = MergeResponse(
432 432 False, False, None, MergeFailureReason.UNKNOWN)
433 433
434 434 with mock.patch.object(self.target_repo, '_merge_repo',
435 435 side_effect=RepositoryError()):
436 436 merge_response = self.target_repo.merge(
437 437 self.repo_id, self.workspace_id, self.target_ref,
438 438 self.source_repo, self.source_ref,
439 439 dry_run=True)
440 440
441 441 assert merge_response == expected_merge_response
442 442
443 443 def test_merge_invalid_user_name(self, vcsbackend):
444 444 repo = vcsbackend.create_repo(number_of_commits=1)
445 445 ref = Reference('branch', 'master', 'not_used')
446 446 workspace_id = 'test-errors-in-merge'
447 447 repo_id = repo_id_generator(workspace_id)
448 448 with pytest.raises(ValueError):
449 449 repo.merge(repo_id, workspace_id, ref, self, ref)
450 450
451 451 def test_merge_invalid_user_email(self, vcsbackend):
452 452 repo = vcsbackend.create_repo(number_of_commits=1)
453 453 ref = Reference('branch', 'master', 'not_used')
454 454 workspace_id = 'test-errors-in-merge'
455 455 repo_id = repo_id_generator(workspace_id)
456 456 with pytest.raises(ValueError):
457 457 repo.merge(
458 458 repo_id, workspace_id, ref, self, ref, 'user name')
459 459
460 460 def test_merge_invalid_message(self, vcsbackend):
461 461 repo = vcsbackend.create_repo(number_of_commits=1)
462 462 ref = Reference('branch', 'master', 'not_used')
463 463 workspace_id = 'test-errors-in-merge'
464 464 repo_id = repo_id_generator(workspace_id)
465 465 with pytest.raises(ValueError):
466 466 repo.merge(
467 467 repo_id, workspace_id, ref, self, ref,
468 468 'user name', 'user@email.com')
469 469
470 470
471 471 @pytest.mark.usefixtures("vcs_repository_support")
472 472 class TestRepositoryStrip(BackendTestMixin):
473 473 recreate_repo_per_test = True
474 474
475 475 @classmethod
476 476 def _get_commits(cls):
477 477 commits = [
478 478 {
479 479 'message': 'Initial commit',
480 480 'author': 'Joe Doe <joe.doe@example.com>',
481 481 'date': datetime.datetime(2010, 1, 1, 20),
482 482 'branch': 'master',
483 483 'added': [
484 484 FileNode('foobar', content='foobar'),
485 485 FileNode('foobar2', content='foobar2'),
486 486 ],
487 487 },
488 488 ]
489 489 for x in xrange(10):
490 490 commit_data = {
491 491 'message': 'Changed foobar - commit%s' % x,
492 492 'author': 'Jane Doe <jane.doe@example.com>',
493 493 'date': datetime.datetime(2010, 1, 1, 21, x),
494 494 'branch': 'master',
495 495 'changed': [
496 496 FileNode('foobar', 'FOOBAR - %s' % x),
497 497 ],
498 498 }
499 499 commits.append(commit_data)
500 500 return commits
501 501
502 502 @pytest.mark.backends("git", "hg")
503 503 def test_strip_commit(self):
504 504 tip = self.repo.get_commit()
505 505 assert tip.idx == 10
506 506 self.repo.strip(tip.raw_id, self.repo.DEFAULT_BRANCH_NAME)
507 507
508 508 tip = self.repo.get_commit()
509 509 assert tip.idx == 9
510 510
511 511 @pytest.mark.backends("git", "hg")
512 512 def test_strip_multiple_commits(self):
513 513 tip = self.repo.get_commit()
514 514 assert tip.idx == 10
515 515
516 516 old = self.repo.get_commit(commit_idx=5)
517 517 self.repo.strip(old.raw_id, self.repo.DEFAULT_BRANCH_NAME)
518 518
519 519 tip = self.repo.get_commit()
520 520 assert tip.idx == 4
521 521
522 522
523 523 @pytest.mark.backends('hg', 'git')
524 524 class TestRepositoryPull(object):
525 525
526 526 def test_pull(self, vcsbackend):
527 527 source_repo = vcsbackend.repo
528 528 target_repo = vcsbackend.create_repo()
529 529 assert len(source_repo.commit_ids) > len(target_repo.commit_ids)
530 530
531 531 target_repo.pull(source_repo.path)
532 532 # Note: Get a fresh instance, avoids caching trouble
533 533 target_repo = vcsbackend.backend(target_repo.path)
534 534 assert len(source_repo.commit_ids) == len(target_repo.commit_ids)
535 535
536 536 def test_pull_wrong_path(self, vcsbackend):
537 537 target_repo = vcsbackend.create_repo()
538 538 with pytest.raises(RepositoryError):
539 539 target_repo.pull(target_repo.path + "wrong")
540 540
541 541 def test_pull_specific_commits(self, vcsbackend):
542 542 source_repo = vcsbackend.repo
543 543 target_repo = vcsbackend.create_repo()
544 544
545 545 second_commit = source_repo[1].raw_id
546 546 if vcsbackend.alias == 'git':
547 547 second_commit_ref = 'refs/test-refs/a'
548 548 source_repo.set_refs(second_commit_ref, second_commit)
549 549
550 550 target_repo.pull(source_repo.path, commit_ids=[second_commit])
551 551 target_repo = vcsbackend.backend(target_repo.path)
552 552 assert 2 == len(target_repo.commit_ids)
553 553 assert second_commit == target_repo.get_commit().raw_id